yetessam committed
Commit 9ee432e · verified · Parent: fb0c953

Create profanity_guard.py

Files changed (1)
  1. tools/profanity_guard.py +63 -0
tools/profanity_guard.py ADDED
# tools/profanity_guard.py
import json
from typing import Any, Dict

from smolagents.tools import Tool


class ProfanityGuardTool(Tool):
    name = "profanity_guard"
    description = "Detects profanity in English text and returns a label and confidence."
    inputs: Dict[str, Dict[str, Any]] = {
        "text": {"type": "string", "description": "English text to check for profanity."}
    }
    output_type = "string"  # return a JSON string to match the web_search.py pattern

    def __init__(self, model_name: str = "tarekziade/pardonmyai", device: int | None = None, **kwargs: Any) -> None:
        """
        model_name options:
        - "tarekziade/pardonmyai" (default, DistilBERT-based, binary PROFANE/CLEAN)
        - "tarekziade/pardonmyai-tiny" (smaller, faster)
        """
        super().__init__()
        try:
            from transformers import pipeline  # type: ignore
        except ImportError as e:
            raise ImportError(
                "You must install `transformers` (and optionally `torch`) to use ProfanityGuardTool.\n"
                "Example: pip install transformers torch --extra-index-url https://download.pytorch.org/whl/cu121"
            ) from e

        self.model_name = model_name

        # Pick a device automatically if not specified; fall back to CPU when torch is unavailable.
        if device is None:
            try:
                import torch
                device = 0 if torch.cuda.is_available() else -1
            except Exception:
                device = -1  # CPU fallback if torch is not available/working

        # Build the pipeline once so subsequent calls are fast.
        self.pipe = pipeline(
            task="sentiment-analysis",  # the model card uses this task name
            model=self.model_name,
            device=device,
            truncation=True,
        )

    def forward(self, text: str) -> str:
        t = (text or "").strip()
        if not t:
            raise ValueError("`text` must be a non-empty string.")

        # Light normalization so profanity isn't split by odd whitespace
        t = " ".join(t.split())

        out = self.pipe(t)[0]  # e.g. {'label': 'PROFANE' | 'CLEAN', 'score': 0.xx}

        payload = {
            "model": self.model_name,
            "label": str(out.get("label", "")),
            "score": float(out.get("score", 0.0)),
        }
        return json.dumps(payload)
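
For a quick sanity check outside of an agent run, the tool can be instantiated and its forward method called directly. The snippet below is a minimal sketch, not part of the commit; it assumes `transformers` and `torch` are installed, the repository root is on PYTHONPATH, and the model weights can be downloaded from the Hugging Face Hub.

# example_usage.py (illustrative only, not included in this commit)
import json

from tools.profanity_guard import ProfanityGuardTool

guard = ProfanityGuardTool()  # defaults to "tarekziade/pardonmyai"

# forward() returns a JSON string; parse it to get the label and confidence.
result = json.loads(guard.forward("You are a wonderful person."))
print(result["label"], result["score"])  # e.g. CLEAN with a high score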