Spaces: Running on Zero
Commit · fd247b7
1 Parent(s): 7f28f16

add gen prompt and kwargs dicts

utils/models.py CHANGED (+23 -4)

@@ -29,6 +29,8 @@ models = {

}

+tokenizer_cache = {}
+
# List of model names for easy access
model_names = list(models.keys())

@@ -101,13 +103,29 @@ def run_inference(model_name, context, question):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""
-
+    tokenizer_kwargs = {
+        "add_generation_prompt": True,
+    } # make sure qwen3 doesn't use thinking
+    generation_kwargs = {
+        "max_new_tokens": 512,
+    }
    if "qwen3" in model_name.lower():
        print(f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False.")
-
+        tokenizer_kwargs["enable_thinking"] = False
+        generation_kwargs["enable_thinking"] = False

    try:
-
+        if model_name in tokenizer_cache:
+            tokenizer = tokenizer_cache[model_name]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                padding_side="left",
+                token=True,
+                kwargs=tokenizer_kwargs
+            )
+            tokenizer_cache[model_name] = tokenizer
+
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False # Handle missing chat_template
@@ -126,6 +144,7 @@ def run_inference(model_name, context, question):
            tokenizer=tokenizer,
            device_map='auto',
            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
        )

        text_input = format_rag_prompt(question, context, accepts_sys)
@@ -134,7 +153,7 @@ def run_inference(model_name, context, question):
        if generation_interrupt.is_set():
            return ""

-        outputs = pipe(text_input, max_new_tokens=512)
+        outputs = pipe(text_input, max_new_tokens=512, generate_kwargs=generation_kwargs)
        result = outputs[0]['generated_text'][-1]['content']

    except Exception as e:
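The module-level tokenizer_cache avoids re-loading a tokenizer on every call to run_inference. A minimal sketch of the same pattern in isolation, assuming transformers is installed; the helper name get_cached_tokenizer and the example model id are illustrative, not part of the commit:

from transformers import AutoTokenizer

tokenizer_cache = {}

def get_cached_tokenizer(model_name: str):
    # Load each tokenizer once and reuse it on later inference calls.
    if model_name not in tokenizer_cache:
        tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",  # same padding side as in the diff above
        )
    return tokenizer_cache[model_name]

# Illustrative usage:
# tok = get_cached_tokenizer("Qwen/Qwen3-0.6B")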
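For Qwen3, the Qwen model cards document enable_thinking as an argument of tokenizer.apply_chat_template rather than of the tokenizer constructor or of model.generate, so the flag normally takes effect where the chat template is applied; whether the dicts above reach that call depends on how format_rag_prompt and the pipeline forward them. A hedged sketch of that documented path, with message contents that are illustrative only (the Space's actual prompt comes from format_rag_prompt):

messages = [
    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
]

# enable_thinking=False asks the Qwen3 chat template to skip the <think> block.
text_input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)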
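One way a generation_kwargs dict like the one above is commonly consumed is by unpacking it into the pipeline call, since the text-generation pipeline forwards unrecognized call-time keyword arguments to model.generate. A sketch under that assumption, keeping only max_new_tokens in the dict and assuming text_input is a list of chat messages, as implied by the [-1]['content'] indexing in the diff:

generation_kwargs = {"max_new_tokens": 512}

# Extra keyword arguments on the call are passed through to model.generate().
outputs = pipe(text_input, **generation_kwargs)

# With chat-style input, generated_text holds the whole conversation and the
# last entry is the model's reply.
result = outputs[0]["generated_text"][-1]["content"]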