import io
import os

# Point the Hugging Face caches at a writable directory inside the container.
# These must be set before transformers is imported, because the library
# resolves its cache paths at import time.
os.environ["HF_HOME"] = "/app/.cache"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

app = FastAPI()
MODEL_NAME = os.getenv("MODEL_NAME", "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct")

# Load the vision-language model once at startup. torch_dtype="auto" keeps the
# checkpoint's native precision, and device_map="auto" places the weights on
# the available GPU(s), falling back to CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...), prompt: str = "Describe this image."):
    # Decode the upload and normalize to RGB so palette/RGBA/grayscale images
    # don't break the image processor.
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Build a single-turn chat message containing the image and the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Generate, then strip the prompt tokens from each sequence so that only
    # the newly generated tokens are decoded.
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return JSONResponse(content={"result": output_text[0]})
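

# A minimal sketch of how to run and exercise the endpoint locally, assuming
# the app is served by uvicorn on port 7860 (the port and the sample filename
# are assumptions, not taken from the Space's configuration):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST "http://localhost:7860/analyze-image?prompt=Describe%20this%20image." \
#        -F "file=@photo.jpg"
#
if __name__ == "__main__":
    # Convenience entry point for running outside the Space's own launcher.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))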