ContentAgent / tools /profanity_guard.py
yetessam's picture
Create profanity_guard.py
9ee432e verified
raw
history blame
2.42 kB
# tools/profanity_guard.py
from typing import Any, Dict
from smolagents.tools import Tool
import json
class ProfanityGuardTool(Tool):
    """smolagents tool that classifies English text as profane or clean.

    A Hugging Face ``transformers`` pipeline is built once in ``__init__`` and
    reused by every ``forward`` call. Results are returned as a JSON string
    with ``model``, ``label`` and ``score`` keys.
    """

    name = "profanity_guard"
    description = "Detects profanity in English text and returns a label and confidence."
    inputs: Dict[str, Dict[str, Any]] = {
        "text": {"type": "string", "description": "English text to check for profanity."}
    }
    output_type = "string"  # return JSON string to match your web_search.py pattern

    def __init__(self, model_name: str = "tarekziade/pardonmyai", device: int | None = None, **kwargs: Any) -> None:
        """
        Build the classification pipeline for ``model_name``.

        model_name options:
          - "tarekziade/pardonmyai" (default, DistilBERT-based, binary PROFANE/CLEAN)
          - "tarekziade/pardonmyai-tiny" (smaller, faster)

        Args:
            model_name: Hugging Face model id passed to ``transformers.pipeline``.
            device: Pipeline device index (-1 = CPU, 0 = first CUDA GPU). When
                None, device 0 is used if torch reports CUDA as available,
                otherwise the CPU (-1).
            **kwargs: Accepted for call-site compatibility; not used.

        Raises:
            ImportError: if ``transformers`` is not installed.
        """
        super().__init__()
        # transformers is the only hard requirement. torch is optional and is
        # only probed below for automatic device selection; importing it here
        # would make a missing torch fatal, contradicting the message below.
        try:
            from transformers import pipeline  # type: ignore
        except ImportError as e:
            raise ImportError(
                "You must install `transformers` (and optionally `torch`) to use ProfanityGuardTool.\n"
                "Example: pip install transformers torch --extra-index-url https://download.pytorch.org/whl/cu121"
            ) from e

        self.model_name = model_name

        # Pick a device automatically only when the caller did not choose one;
        # an explicitly passed device is always honoured.
        if device is None:
            try:
                import torch

                device = 0 if torch.cuda.is_available() else -1
            except Exception:
                # Deliberate best-effort: a missing or broken torch install
                # should not stop the tool, just fall back to CPU.
                device = -1

        # Build the pipeline once so subsequent forward() calls are fast.
        self.pipe = pipeline(
            task="sentiment-analysis",  # model card uses this task name
            model=self.model_name,
            device=device,
            truncation=True,  # ask the tokenizer to truncate over-long inputs
        )

    def forward(self, text: str) -> str:
        """
        Classify ``text`` and return the result as a JSON string.

        Args:
            text: English text to check. Leading/trailing whitespace is removed
                and internal runs of whitespace are collapsed to single spaces
                before classification.

        Returns:
            A JSON object string of the form
            ``{"model": <model id>, "label": <label>, "score": <float>}``.

        Raises:
            ValueError: if ``text`` is None, empty, or only whitespace.
        """
        # Light normalization so profanity isn't split by odd whitespace;
        # str.split() with no argument also drops leading/trailing whitespace.
        t = " ".join((text or "").split())
        if not t:
            raise ValueError("`text` must be a non-empty string.")
        out = self.pipe(t)[0]  # e.g. {'label': 'PROFANE'|'CLEAN', 'score': 0.xx}
        payload = {
            "model": self.model_name,
            "label": str(out.get("label", "")),
            "score": float(out.get("score", 0.0)),
        }
        return json.dumps(payload)