Create profanity_guard.py
tools/profanity_guard.py  ADDED  +63 -0

@@ -0,0 +1,63 @@
# tools/profanity_guard.py
from typing import Any, Dict
from smolagents.tools import Tool
import json

class ProfanityGuardTool(Tool):
    name = "profanity_guard"
    description = "Detects profanity in English text and returns a label and confidence."
    inputs: Dict[str, Dict[str, Any]] = {
        "text": {"type": "string", "description": "English text to check for profanity."}
    }
    output_type = "string"  # return JSON string to match your web_search.py pattern

    def __init__(self, model_name: str = "tarekziade/pardonmyai", device: int | None = None, **kwargs: Any) -> None:
        """
        model_name options:
          - "tarekziade/pardonmyai" (default, DistilBERT-based, binary PROFANE/CLEAN)
          - "tarekziade/pardonmyai-tiny" (smaller, faster)
        """
        super().__init__()
        try:
            import torch  # noqa: F401
            from transformers import pipeline  # type: ignore
        except ImportError as e:
            raise ImportError(
                "You must install `transformers` (and optionally `torch`) to use ProfanityGuardTool.\n"
                "Example: pip install transformers torch --extra-index-url https://download.pytorch.org/whl/cu121"
            ) from e

        self.model_name = model_name
        # Pick device automatically if not specified
        try:
            import torch
            if device is None:
                device = 0 if torch.cuda.is_available() else -1
        except Exception:
            device = -1  # CPU fallback if torch not available/working

        # Build the pipeline once (fast subsequent calls)
        from transformers import pipeline
        self.pipe = pipeline(
            task="sentiment-analysis",  # model card uses this task name
            model=self.model_name,
            device=device,
            truncation=True
        )

    def forward(self, text: str) -> str:
        t = (text or "").strip()
        if not t:
            raise ValueError("`text` must be a non-empty string.")

        # Light normalization so profanity isn't split by odd whitespace
        t = " ".join(t.split())

        out = self.pipe(t)[0]  # e.g. {'label': 'PROFANE'|'CLEAN', 'score': 0.xx}

        payload = {
            "model": self.model_name,
            "label": str(out.get("label", "")),
            "score": float(out.get("score", 0.0)),
        }
        return json.dumps(payload)
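For local testing, a minimal usage sketch (not part of the commit), assuming the layout above keeps tools/profanity_guard.py importable and that transformers and torch are installed; it calls forward() directly to exercise the classifier and parse the JSON payload, rather than going through an agent:

import json

from tools.profanity_guard import ProfanityGuardTool

# Instantiate once; the Hugging Face model is downloaded and the pipeline is built here.
guard = ProfanityGuardTool()  # or ProfanityGuardTool(model_name="tarekziade/pardonmyai-tiny")

# forward() returns a JSON string such as {"model": ..., "label": "CLEAN", "score": 0.99}
raw = guard.forward("This sentence is perfectly polite.")
result = json.loads(raw)
print(result["label"], round(result["score"], 3))

Inside a smolagents agent, the tool would instead be passed through the agent's tools list and invoked by its registered name, profanity_guard.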