Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| import torch | |
| from detoxify import Detoxify | |
| import asyncio | |
| from fastapi.concurrency import run_in_threadpool | |
| from typing import List | |
| class Guardrail: | |
| def __init__(self): | |
| tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") | |
| model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") | |
| self.classifier = pipeline( | |
| "text-classification", | |
| model=model, | |
| tokenizer=tokenizer, | |
| truncation=True, | |
| max_length=512, | |
| device=torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| ) | |
| async def guard(self, prompt): | |
| return await run_in_threadpool(self.classifier, prompt) | |
| def determine_level(self, label, score): | |
| if label == "SAFE": | |
| return 0, "safe" | |
| else: | |
| if score > 0.9: | |
| return 4, "high" | |
| elif score > 0.75: | |
| return 3, "medium" | |
| elif score > 0.5: | |
| return 2, "low" | |
| else: | |
| return 1, "very low" | |
| class TextPrompt(BaseModel): | |
| prompt: str | |
| class ClassificationResult(BaseModel): | |
| label: str | |
| score: float | |
| level: int | |
| severity_label: str | |
| class ToxicityResult(BaseModel): | |
| toxicity: float | |
| severe_toxicity: float | |
| obscene: float | |
| threat: float | |
| insult: float | |
| identity_attack: float | |
| class TopicBannerClassifier: | |
| def __init__(self): | |
| self.classifier = pipeline( | |
| "zero-shot-classification", | |
| model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", | |
| device=torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| ) | |
| self.hypothesis_template = "This text is about {}" | |
| async def classify(self, text, labels): | |
| return await run_in_threadpool( | |
| self.classifier, | |
| text, | |
| labels, | |
| hypothesis_template=self.hypothesis_template, | |
| multi_label=False | |
| ) | |
| class TopicBannerRequest(BaseModel): | |
| prompt: str | |
| labels: List[str] | |
| class TopicBannerResult(BaseModel): | |
| sequence: str | |
| labels: list | |
| scores: list | |
| app = FastAPI() | |
| guardrail = Guardrail() | |
| toxicity_classifier = Detoxify('original') | |
| topic_banner_classifier = TopicBannerClassifier() | |
| async def classify_toxicity(text_prompt: TextPrompt): | |
| try: | |
| result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt) | |
| return { | |
| "toxicity": result['toxicity'], | |
| "severe_toxicity": result['severe_toxicity'], | |
| "obscene": result['obscene'], | |
| "threat": result['threat'], | |
| "insult": result['insult'], | |
| "identity_attack": result['identity_attack'] | |
| } | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def classify_text(text_prompt: TextPrompt): | |
| try: | |
| result = await guardrail.guard(text_prompt.prompt) | |
| label = result[0]['label'] | |
| score = result[0]['score'] | |
| level, severity_label = guardrail.determine_level(label, score) | |
| return {"label": label, "score": score, "level": level, "severity_label": severity_label} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def classify_topic_banner(request: TopicBannerRequest): | |
| try: | |
| result = await topic_banner_classifier.classify(request.prompt, request.labels) | |
| return { | |
| "sequence": result["sequence"], | |
| "labels": result["labels"], | |
| "scores": result["scores"] | |
| } | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |