dx2102 committed on
Commit
feecefa
·
verified ·
1 Parent(s): 0d8d838

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -24,6 +24,7 @@ pipe = transformers.pipeline(
24
  torch_dtype="bfloat16",
25
  device="cuda",
26
  )
 
27
  print('Done')
28
 
29
  example_prefix = '''pitch duration wait velocity instrument
@@ -166,9 +167,9 @@ CPUs will be slower but there is no time limit.
166
 
167
  def model_fn(prefix, history, server):
168
  if server == "Huggingface ZeroGPU":
169
- generator = zerogpu_model_fn(prefix, history)
170
  elif server == "CPU":
171
- generator = cpu_model_fn(prefix, history)
172
  # elif server == "RunPod":
173
  # generator = runpod_model_fn(prefix, history)
174
  else:
@@ -176,7 +177,7 @@ CPUs will be slower but there is no time limit.
176
  for history in generator:
177
  yield history
178
 
179
- def cpu_model_fn(prefix, history):
180
  queue = Queue(maxsize=10)
181
  class MyStreamer:
182
  def put(self, tokens):
@@ -188,15 +189,12 @@ CPUs will be slower but there is no time limit.
188
  def end(self):
189
  queue.put(None)
190
  def background_fn():
191
- try:
192
- result = pipe(
193
- prefix,
194
- streamer=MyStreamer(),
195
- max_new_tokens=500,
196
- top_p=0.9, temperature=0.6,
197
- )
198
- except queue.Full:
199
- print("Queue is full. Exiting.")
200
  print('Generated text:')
201
  print(result[0]['generated_text'])
202
  print()
 
24
  torch_dtype="bfloat16",
25
  device="cuda",
26
  )
27
+ cpu_pipe = pipe.to("cpu")
28
  print('Done')
29
 
30
  example_prefix = '''pitch duration wait velocity instrument
 
167
 
168
  def model_fn(prefix, history, server):
169
  if server == "Huggingface ZeroGPU":
170
+ generator = zerogpu_model_fn(prefix, history, pipe)
171
  elif server == "CPU":
172
+ generator = cpu_model_fn(prefix, history, cpu_pipe)
173
  # elif server == "RunPod":
174
  # generator = runpod_model_fn(prefix, history)
175
  else:
 
177
  for history in generator:
178
  yield history
179
 
180
+ def cpu_model_fn(prefix, history, pipe):
181
  queue = Queue(maxsize=10)
182
  class MyStreamer:
183
  def put(self, tokens):
 
189
  def end(self):
190
  queue.put(None)
191
  def background_fn():
192
+ result = pipe(
193
+ prefix,
194
+ streamer=MyStreamer(),
195
+ max_new_tokens=500,
196
+ top_p=0.9, temperature=0.6,
197
+ )
 
 
 
198
  print('Generated text:')
199
  print(result[0]['generated_text'])
200
  print()