LukasHug committed
Commit e911fd3 · verified · 1 Parent(s): 2b4048a

Update app.py

Files changed (1):
  app.py (+11 -1)
app.py CHANGED
@@ -199,7 +199,14 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
     inputs = processor(text=text_prompt, images=image, return_tensors="pt")
 
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    target_dtype = getattr(model, "dtype", None)
+    processed_inputs = {}
+    for key, value in inputs.items():
+        if torch.is_floating_point(value) and target_dtype is not None:
+            processed_inputs[key] = value.to(model.device, dtype=target_dtype)
+        else:
+            processed_inputs[key] = value.to(model.device)
+    inputs = processed_inputs
 
     with torch.no_grad():
         generated_ids = model.generate(
@@ -582,6 +589,9 @@ if api_key:
     # Load model at startup
     model_path = DEFAULT_MODEL
     logger.info(f"Loading model: {model_path}")
+
+
+
     # Check if it's a Qwen model
     if "qwenguard" in model_path.lower():
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
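
For context on the first hunk: the old one-liner moved every tensor to model.device but left dtypes untouched, so a model loaded in half precision could receive float32 pixel_values and fail with a dtype mismatch inside generate(). Below is a minimal, self-contained sketch of the same pattern using plain tensors as stand-ins for the processor output; the float16 model dtype and the tensor shapes are assumptions for illustration, not values taken from app.py.

# Minimal sketch of the dtype-cast pattern from this commit.
# Assumptions (not from app.py): the model was loaded in float16,
# and the processor returned integer ids plus float32 pixel values.
import torch

model_dtype = torch.float16  # hypothetical: as if loaded with torch_dtype=torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"

inputs = {
    "input_ids": torch.tensor([[101, 2023, 102]]),          # integer token ids
    "attention_mask": torch.ones(1, 3, dtype=torch.long),   # integer mask
    "pixel_values": torch.rand(1, 3, 224, 224),             # float32 by default
}

# Old behavior: device move only; pixel_values stays float32 and can
# trigger a RuntimeError inside a float16 model's linear/conv layers.
naive = {k: v.to(device) for k, v in inputs.items()}
assert naive["pixel_values"].dtype == torch.float32

# New behavior: cast only floating-point tensors to the model dtype,
# leaving integer tensors (ids, masks) untouched.
fixed = {
    k: v.to(device, dtype=model_dtype) if torch.is_floating_point(v) else v.to(device)
    for k, v in inputs.items()
}
assert fixed["pixel_values"].dtype == torch.float16
assert fixed["input_ids"].dtype == torch.int64

The getattr(model, "dtype", None) guard in the committed version additionally falls back to the old device-only move for model objects that do not expose a dtype attribute.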