Update app.py
app.py CHANGED
@@ -199,13 +199,9 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
     inputs = processor(text=text_prompt, images=image, return_tensors="pt")
 
-    target_dtype = getattr(model, "dtype", None)
     processed_inputs = {}
     for key, value in inputs.items():
-        if …:
-            processed_inputs[key] = value.to(model.device, dtype=target_dtype)
-        else:
-            processed_inputs[key] = value.to(model.device)
+        processed_inputs[key] = value.to(model.device, dtype=torch.bfloat16)
     inputs = processed_inputs
 
     with torch.no_grad():
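For reference, the new casting step in isolation: a minimal sketch assuming the processor returns a dict of PyTorch tensors. The helper name cast_inputs is illustrative and not part of app.py.

import torch

def cast_inputs(inputs, device, dtype=torch.bfloat16):
    # Move every processor output to the model's device and cast it to the
    # target dtype, mirroring the single loop body added at new line 204.
    return {key: value.to(device, dtype=dtype) for key, value in inputs.items()}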
@@ -596,7 +592,7 @@ logger.info(f"Loading model: {model_path}")
 if "qwenguard" in model_path.lower():
     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         model_path,
-        torch_dtype=
+        torch_dtype=torch.bfloat16,
         device_map="auto"
     )
     processor = AutoProcessor.from_pretrained(model_path)
@@ -606,7 +602,7 @@ if "qwenguard" in model_path.lower():
 else:
     model = LlavaOnevisionForConditionalGeneration.from_pretrained(
         model_path,
-        torch_dtype=
+        torch_dtype=torch.bfloat16,
         device_map="auto",
         trust_remote_code=True
     )
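Taken together, the two loader hunks leave the model-selection branch looking roughly like the condensed sketch below. Logging, error handling, and the exact placement of the processor line are simplified, and model_path is assumed to be defined earlier in app.py.

import torch
from transformers import (
    AutoProcessor,
    LlavaOnevisionForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

if "qwenguard" in model_path.lower():
    # QwenGuard checkpoints load through the Qwen2.5-VL class.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,  # load weights directly in bfloat16
        device_map="auto",           # let accelerate place layers automatically
    )
else:
    # Everything else is treated as a LLaVA-OneVision checkpoint.
    model = LlavaOnevisionForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
processor = AutoProcessor.from_pretrained(model_path)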