Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	added labels as in input
Browse files
    	
        main.py
    CHANGED
    
    | @@ -5,6 +5,7 @@ import torch | |
| 5 | 
             
            from detoxify import Detoxify
         | 
| 6 | 
             
            import asyncio
         | 
| 7 | 
             
            from fastapi.concurrency import run_in_threadpool
         | 
|  | |
| 8 |  | 
| 9 | 
             
            class Guardrail:
         | 
| 10 | 
             
                def __init__(self):
         | 
| @@ -60,17 +61,20 @@ class TopicBannerClassifier: | |
| 60 | 
             
                        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 61 | 
             
                    )
         | 
| 62 | 
             
                    self.hypothesis_template = "This text is about {}"
         | 
| 63 | 
            -
                    self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
         | 
| 64 |  | 
| 65 | 
            -
                async def classify(self, text):
         | 
| 66 | 
             
                    return await run_in_threadpool(
         | 
| 67 | 
             
                        self.classifier,
         | 
| 68 | 
             
                        text,
         | 
| 69 | 
            -
                         | 
| 70 | 
             
                        hypothesis_template=self.hypothesis_template,
         | 
| 71 | 
             
                        multi_label=False
         | 
| 72 | 
             
                    )
         | 
| 73 |  | 
|  | |
|  | |
|  | |
|  | |
| 74 | 
             
            class TopicBannerResult(BaseModel):
         | 
| 75 | 
             
                sequence: str
         | 
| 76 | 
             
                labels: list
         | 
| @@ -108,9 +112,9 @@ async def classify_text(text_prompt: TextPrompt): | |
| 108 | 
             
                    raise HTTPException(status_code=500, detail=str(e))
         | 
| 109 |  | 
| 110 | 
             
            @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
         | 
| 111 | 
            -
            async def classify_topic_banner( | 
| 112 | 
             
                try:
         | 
| 113 | 
            -
                    result = await topic_banner_classifier.classify( | 
| 114 | 
             
                    return {
         | 
| 115 | 
             
                        "sequence": result["sequence"],
         | 
| 116 | 
             
                        "labels": result["labels"],
         | 
|  | |
| 5 | 
             
            from detoxify import Detoxify
         | 
| 6 | 
             
            import asyncio
         | 
| 7 | 
             
            from fastapi.concurrency import run_in_threadpool
         | 
| 8 | 
            +
            from typing import List
         | 
| 9 |  | 
| 10 | 
             
            class Guardrail:
         | 
| 11 | 
             
                def __init__(self):
         | 
|  | |
| 61 | 
             
                        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 62 | 
             
                    )
         | 
| 63 | 
             
                    self.hypothesis_template = "This text is about {}"
         | 
|  | |
| 64 |  | 
| 65 | 
            +
                async def classify(self, text, labels):
         | 
| 66 | 
             
                    return await run_in_threadpool(
         | 
| 67 | 
             
                        self.classifier,
         | 
| 68 | 
             
                        text,
         | 
| 69 | 
            +
                        labels,
         | 
| 70 | 
             
                        hypothesis_template=self.hypothesis_template,
         | 
| 71 | 
             
                        multi_label=False
         | 
| 72 | 
             
                    )
         | 
| 73 |  | 
| 74 | 
            +
            class TopicBannerRequest(BaseModel):
         | 
| 75 | 
            +
                prompt: str
         | 
| 76 | 
            +
                labels: List[str]
         | 
| 77 | 
            +
             | 
| 78 | 
             
            class TopicBannerResult(BaseModel):
         | 
| 79 | 
             
                sequence: str
         | 
| 80 | 
             
                labels: list
         | 
|  | |
| 112 | 
             
                    raise HTTPException(status_code=500, detail=str(e))
         | 
| 113 |  | 
| 114 | 
             
            @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
         | 
| 115 | 
            +
            async def classify_topic_banner(request: TopicBannerRequest):
         | 
| 116 | 
             
                try:
         | 
| 117 | 
            +
                    result = await topic_banner_classifier.classify(request.prompt, request.labels)
         | 
| 118 | 
             
                    return {
         | 
| 119 | 
             
                        "sequence": result["sequence"],
         | 
| 120 | 
             
                        "labels": result["labels"],
         |