skin-advisor-api / main.py
lingling707's picture
Update main.py
6d92e94 verified
import os
# Thư mục cache có quyền ghi
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
# ==== Hugging Face token & model ====
model_name = "lingling707/vit5-skinbot"
hf_token = os.getenv("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, local_files_only=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token, local_files_only=False)
# ==== Chatbot pipeline ====
chatbot = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
device=0 # GPU: 0, CPU: -1
)
# ==== Model phân loại ảnh da ====
image_model = pipeline("image-classification", model="dima806/skin_types_image_detection")
# ==== FastAPI setup ====
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class RequestData(BaseModel):
userMessage: str
imageUrl: Optional[str] = None
def map_labels_to_skin_type(label: str):
label = label.lower()
return {
"oily": "da dầu, dễ nổi mụn",
"dry": "da khô, có thể bong tróc hoặc lão hóa sớm",
"normal": "da thường, cân bằng"
}.get(label, "không xác định rõ loại da")
@app.get("/")
async def root():
return {"message": "Skin Advisor API (Vit5 Fine-Tune) is running 🚀"}
@app.post("/skinAdvisor")
async def skin_advisor(data: RequestData):
skin_analysis = ""
if data.imageUrl and data.imageUrl.startswith(("http://","https://")):
try:
results = image_model(data.imageUrl)
top = max(results, key=lambda x: x['score'])
skin_type = map_labels_to_skin_type(top['label'])
skin_analysis = f"Ảnh phân tích cho thấy: {skin_type}."
except Exception as e:
print("Image analysis error:", e)
skin_analysis = "Không thể phân tích ảnh da."
prompt = (
f"Bạn là chuyên gia tư vấn chăm sóc da. {skin_analysis} Người dùng hỏi: {data.userMessage}. "
"Hãy trả lời ngắn gọn, thân thiện, bằng tiếng Việt, "
"chỉ tư vấn chăm sóc da, tuyệt đối không gợi ý nguy hiểm."
)
try:
result = chatbot(prompt, max_new_tokens=120, truncation=True)
reply = result[0]["generated_text"].strip() or "Xin lỗi, mình chưa có câu trả lời phù hợp."
except Exception as e:
print("Error:", e)
reply = "Xin lỗi, bot chưa trả lời được. Vui lòng thử lại."
if skin_analysis and skin_analysis.lower() not in reply.lower():
reply = f"{skin_analysis} {reply}"
return {"reply": reply}