LukasHug committed on
Commit
9853b95
·
verified ·
1 Parent(s): e911fd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -199,13 +199,9 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
199
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
200
  inputs = processor(text=text_prompt, images=image, return_tensors="pt")
201
 
202
- target_dtype = getattr(model, "dtype", None)
203
  processed_inputs = {}
204
  for key, value in inputs.items():
205
- if torch.is_floating_point(value) and target_dtype is not None:
206
- processed_inputs[key] = value.to(model.device, dtype=target_dtype)
207
- else:
208
- processed_inputs[key] = value.to(model.device)
209
  inputs = processed_inputs
210
 
211
  with torch.no_grad():
@@ -596,7 +592,7 @@ logger.info(f"Loading model: {model_path}")
596
  if "qwenguard" in model_path.lower():
597
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
598
  model_path,
599
- torch_dtype="auto",
600
  device_map="auto"
601
  )
602
  processor = AutoProcessor.from_pretrained(model_path)
@@ -606,7 +602,7 @@ if "qwenguard" in model_path.lower():
606
  else:
607
  model = LlavaOnevisionForConditionalGeneration.from_pretrained(
608
  model_path,
609
- torch_dtype="auto",
610
  device_map="auto",
611
  trust_remote_code=True
612
  )
 
199
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
200
  inputs = processor(text=text_prompt, images=image, return_tensors="pt")
201
 
 
202
  processed_inputs = {}
203
  for key, value in inputs.items():
204
+ processed_inputs[key] = value.to(model.device, dtype=torch.bfloat16)
 
 
 
205
  inputs = processed_inputs
206
 
207
  with torch.no_grad():
 
592
  if "qwenguard" in model_path.lower():
593
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
594
  model_path,
595
+ torch_dtype=torch.bfloat16,
596
  device_map="auto"
597
  )
598
  processor = AutoProcessor.from_pretrained(model_path)
 
602
  else:
603
  model = LlavaOnevisionForConditionalGeneration.from_pretrained(
604
  model_path,
605
+ torch_dtype=torch.bfloat16,
606
  device_map="auto",
607
  trust_remote_code=True
608
  )