# app.py — last commit 38eeb5b ("Update app.py", verified) by AmdKamel; 873 bytes.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from fastapi.middleware.cors import CORSMiddleware
# Identifier of the Arabic written-dialect classifier passed to transformers.pipeline.
model_name = "IbrahimAmin/marbertv2-arabic-written-dialect-classifier"

# The classification pipeline is built once, at import time, and the same
# instance is reused by every request handled below.
dialect = pipeline(task="text-classification", model=model_name)

# ASGI application object.
app = FastAPI(title="Arabic Dialect Detector API")

# Fully permissive CORS (any origin, method and header) so browser-based
# clients can call the API. CORS is enforced only by browsers; server-side
# callers such as an n8n HTTP node are not affected by these settings.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Input schema
class InputText(BaseModel):
text: str
# API endpoint
@app.post("/predict")
def predict(input: InputText):
    """Classify the Arabic dialect of the submitted text.

    Args:
        input: Request body carrying the text to classify. The name shadows the
            ``input`` builtin but is kept so existing callers keep working.

    Returns:
        ``{"label": <predicted label>, "score": <confidence rounded to 3 places>}``
        on success, or ``{"label": None, "score": None, "error": "No input text"}``
        when the text is empty or whitespace-only. The error case is reported in
        the response body, not through an HTTP error status.
    """
    text = input.text.strip()
    # Whitespace-only input has nothing to classify; treat it like an empty string.
    if not text:
        return {"label": None, "score": None, "error": "No input text"}
    # truncation=True clips inputs to the model's maximum sequence length;
    # without it an over-long text can make the model raise (surfacing as a 500).
    # For a single string the pipeline returns a one-element list of
    # {"label", "score"} dicts.
    result = dialect(text, truncation=True)[0]
    return {"label": result["label"], "score": round(result["score"], 3)}