Luigi committed on
Commit
a73d8f4
·
1 Parent(s): 4af617b

Fix cancel generation to gracefully stop ongoing response generation

Browse files

- Add StoppingCriteria import and CancelStoppingCriteria class
- Integrate stopping criteria into generation pipeline to halt token generation when cancel event is set
- Ensures generation stops at the model level, preventing unnecessary GPU usage

Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
7
  import re # for parsing <think> blocks
8
  import gradio as gr
9
  import torch
10
- from transformers import pipeline, TextIteratorStreamer
11
  from transformers import AutoTokenizer
12
  from ddgs import DDGS
13
  import spaces # Import spaces early to enable ZeroGPU support
@@ -23,6 +23,13 @@ access_token=os.environ['HF_TOKEN']
23
  # ------------------------------
24
  cancel_event = threading.Event()
25
 
 
 
 
 
 
 
 
26
  # ------------------------------
27
  # Torch-Compatible Model Definitions with Adjusted Descriptions
28
  # ------------------------------
@@ -499,6 +506,7 @@ def chat_response(user_msg, chat_history, system_prompt,
499
  'top_p': top_p,
500
  'repetition_penalty': repeat_penalty,
501
  'streamer': streamer,
 
502
  'return_full_text': False,
503
  }
504
  )
 
7
  import re # for parsing <think> blocks
8
  import gradio as gr
9
  import torch
10
+ from transformers import pipeline, TextIteratorStreamer, StoppingCriteria
11
  from transformers import AutoTokenizer
12
  from ddgs import DDGS
13
  import spaces # Import spaces early to enable ZeroGPU support
 
23
  # ------------------------------
24
  cancel_event = threading.Event()
25
 
26
+ # ------------------------------
27
+ # Stopping Criteria for Cancellation
28
+ # ------------------------------
29
+ class CancelStoppingCriteria(StoppingCriteria):
30
+ def __call__(self, input_ids, scores, **kwargs):
31
+ return cancel_event.is_set()
32
+
33
  # ------------------------------
34
  # Torch-Compatible Model Definitions with Adjusted Descriptions
35
  # ------------------------------
 
506
  'top_p': top_p,
507
  'repetition_penalty': repeat_penalty,
508
  'streamer': streamer,
509
+ 'stopping_criteria': [CancelStoppingCriteria()],
510
  'return_full_text': False,
511
  }
512
  )