Spaces: Running on Zero
Update utils/models.py
utils/models.py  +15 -5  CHANGED
@@ -154,6 +154,9 @@ def run_inference(model_name, context, question):
     if generation_interrupt.is_set():
         return ""

+    # Create interrupt criteria for this generation
+    interrupt_criteria = InterruptCriteria(generation_interrupt)
+
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
     if "bitnet" in model_name.lower():
@@ -206,7 +209,10 @@ def run_inference(model_name, context, question):
         result = pipe(
             text_input,
             max_new_tokens=512,
-            generation_kwargs={
+            generation_kwargs={
+                "skip_special_tokens": True,
+                "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+            },
         )[0]["generated_text"]

         result = result[-1]["content"]
@@ -221,7 +227,6 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )

-
         model_inputs = model_inputs.to(model.device)

         input_ids = model_inputs.input_ids
@@ -239,7 +244,8 @@ def run_inference(model_name, context, question):
             attention_mask=attention_mask,
             max_new_tokens=512,
             eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
+            pad_token_id=tokenizer.pad_token_id,
+            stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
         )

         generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -259,6 +265,7 @@ def run_inference(model_name, context, question):
         # output_sequences = bitnet_model.generate(
         #     **formatted,
         #     max_new_tokens=512,
+        #     stopping_criteria=[interrupt_criteria]  # ADD INTERRUPT SUPPORT
         # )

         # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
@@ -275,7 +282,10 @@ def run_inference(model_name, context, question):
         outputs = pipe(
             formatted,
             max_new_tokens=512,
-            generation_kwargs={
+            generation_kwargs={
+                "skip_special_tokens": True,
+                "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
+            },
         )
         # print(outputs[0]['generated_text'])
         result = outputs[0]["generated_text"][input_length:]
@@ -290,4 +300,4 @@ def run_inference(model_name, context, question):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

-    return result
+    return result
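The hunks above construct InterruptCriteria(generation_interrupt), but the class itself is defined elsewhere in utils/models.py and is not part of this diff. As a rough orientation, here is a minimal sketch of such a criterion, assuming generation_interrupt is a threading.Event and that the class follows the transformers StoppingCriteria interface; the names and signature are illustrative, not the repository's actual code:

import threading

import torch
from transformers import StoppingCriteria


class InterruptCriteria(StoppingCriteria):
    """Stops generation as soon as a shared interrupt event is set (sketch)."""

    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # generate() evaluates its stopping criteria after each decoding step;
        # returning True ends the generation loop early.
        return self.interrupt_event.is_set()

With a criterion like this passed via stopping_criteria=[interrupt_criteria], setting generation_interrupt from another thread (for example, a UI stop handler) makes the generate call return after the current token instead of running to max_new_tokens, which is what the "ADD INTERRUPT SUPPORT" lines in this diff rely on.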