LukasHug committed on
Commit
3871a0b
·
verified ·
1 Parent(s): 9853b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -183,6 +183,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
183
  padding=True,
184
  return_tensors="pt",
185
  )
 
186
 
187
 
188
  # Otherwise assume it's a LlavaGuard model
@@ -198,11 +199,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
198
  ]
199
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
200
  inputs = processor(text=text_prompt, images=image, return_tensors="pt")
201
-
202
- processed_inputs = {}
203
- for key, value in inputs.items():
204
- processed_inputs[key] = value.to(model.device, dtype=torch.bfloat16)
205
- inputs = processed_inputs
206
 
207
  with torch.no_grad():
208
  generated_ids = model.generate(
 
183
  padding=True,
184
  return_tensors="pt",
185
  )
186
+ inputs.to(model.device)
187
 
188
 
189
  # Otherwise assume it's a LlavaGuard model
 
199
  ]
200
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
201
  inputs = processor(text=text_prompt, images=image, return_tensors="pt")
202
+ inputs.to(model.device)
 
 
 
 
203
 
204
  with torch.no_grad():
205
  generated_ids = model.generate(