KRISH09bha committed
Commit 3722851 · verified · 1 Parent(s): 80d1ffd

Update app.py

Files changed (1):
  1. app.py +13 -19
app.py CHANGED
@@ -1,26 +1,22 @@
+
 from fastapi import FastAPI, File, UploadFile
 from fastapi.responses import JSONResponse
-from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoModelForCausalLM
 import torch
 from PIL import Image
 import io
+import os
 
 app = FastAPI()
 
-model_path = "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct"
-os.environ["HF_HOME"] = "./huggingface_cache"
-
-model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True)
-processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-def process_vision_info(messages):
-    # Dummy implementation, replace with actual from qwen_vl_utils
-    image_inputs = [msg['content'][0]['image'] for msg in messages]
-    video_inputs = None
-    return image_inputs, video_inputs
+MODEL_NAME = os.getenv("MODEL_NAME", "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME, torch_dtype="auto", device_map="auto", trust_remote_code=True
+)
+processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 @app.post("/analyze-image")
-async def analyze_image(file: UploadFile = File(...)):
+async def analyze_image(file: UploadFile = File(...), prompt: str = "Describe this image."):
     image_bytes = await file.read()
     image = Image.open(io.BytesIO(image_bytes))
     messages = [
@@ -28,20 +24,18 @@ async def analyze_image(file: UploadFile = File(...)):
             "role": "user",
             "content": [
                 {"type": "image", "image": image},
-                {"type": "text", "text": "Describe this image."},
+                {"type": "text", "text": prompt},
             ],
         }
     ]
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
         text=[text],
-        images=image_inputs,
-        videos=video_inputs,
+        images=[image],
         padding=True,
         return_tensors="pt",
     )
-    inputs = inputs.to("cuda")
+    inputs = inputs.to(model.device)
     generated_ids = model.generate(**inputs, max_new_tokens=1024)
     generated_ids_trimmed = [
         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -49,4 +43,4 @@ async def analyze_image(file: UploadFile = File(...)):
     output_text = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
-    return JSONResponse(content={"result": output_text})
+    return JSONResponse(content={"result": output_text[0]})
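
After this change, `prompt` is a plain scalar parameter on the endpoint, so FastAPI reads it from the query string rather than the multipart body. A minimal client sketch, assuming the app is served locally (e.g. `uvicorn app:app` on the default port 8000); the URL, file name, and prompt text below are placeholders:

import requests

url = "http://localhost:8000/analyze-image"  # hypothetical local deployment
with open("example.jpg", "rb") as f:  # placeholder image file
    resp = requests.post(
        url,
        files={"file": ("example.jpg", f, "image/jpeg")},
        # `prompt` is a scalar endpoint parameter, so it travels as a query param
        params={"prompt": "List every object visible in this image."},
    )
print(resp.json())  # e.g. {"result": "..."}

Two design notes on the diff itself: moving `inputs` to `model.device` instead of a hard-coded "cuda" keeps the endpoint working wherever `device_map="auto"` placed the model, and `output_text[0]` unwraps the one-element list that `batch_decode` returns for this single-image batch, so the JSON carries a string rather than a list.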