Commit ac63b1e
Parent(s): 7a7ff47

Update with h2oGPT hash cf3886c550581e34d9f05d69d2e3438b2a46d7b2

Files changed: generate.py (+46 -38)
generate.py CHANGED

@@ -5,6 +5,8 @@ import traceback
 import typing
 from threading import Thread
 
+import filelock
+
 from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
 
 SEED = 1236

@@ -809,46 +811,52 @@ def evaluate(
         )
 
         with torch.no_grad():
-            # decoded tokenized prompt can deviate from prompt due to special characters
-            inputs_decoded = decoder(input_ids[0])
-            inputs_decoded_raw = decoder_raw(input_ids[0])
-            if inputs_decoded == prompt:
-                # normal
-                pass
-            elif inputs_decoded.lstrip() == prompt.lstrip():
-                # sometimes extra space in front, make prompt same for prompt removal
-                prompt = inputs_decoded
-            elif inputs_decoded_raw == prompt:
-                # some models specify special tokens that are part of normal prompt, so can't skip them
-                inputs_decoded_raw = inputs_decoded
-                decoder = decoder_raw
-            else:
-                print("WARNING: Special characters in prompt", flush=True)
-            decoded_output = None
-            if stream_output:
-                skip_prompt = False
-                streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
-                gen_kwargs.update(dict(streamer=streamer))
-                target_func = generate_with_exceptions
-                target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
-                                         raise_generate_gpu_exceptions, **gen_kwargs)
-                thread = Thread(target=target)
-                thread.start()
-                outputs = ""
-                for new_text in streamer:
-                    outputs += new_text
+            # protection for gradio not keeping track of closed users,
+            # else hit bitsandbytes lack of thread safety:
+            # https://github.com/h2oai/h2ogpt/issues/104
+            # but only makes sense if concurrency_count == 1
+            context_class = NullContext if concurrency_count > 1 else filelock.FileLock
+            with context_class("generate.lock"):
+                # decoded tokenized prompt can deviate from prompt due to special characters
+                inputs_decoded = decoder(input_ids[0])
+                inputs_decoded_raw = decoder_raw(input_ids[0])
+                if inputs_decoded == prompt:
+                    # normal
+                    pass
+                elif inputs_decoded.lstrip() == prompt.lstrip():
+                    # sometimes extra space in front, make prompt same for prompt removal
+                    prompt = inputs_decoded
+                elif inputs_decoded_raw == prompt:
+                    # some models specify special tokens that are part of normal prompt, so can't skip them
+                    inputs_decoded_raw = inputs_decoded
+                    decoder = decoder_raw
+                else:
+                    print("WARNING: Special characters in prompt", flush=True)
+                decoded_output = None
+                if stream_output:
+                    skip_prompt = False
+                    streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
+                    gen_kwargs.update(dict(streamer=streamer))
+                    target_func = generate_with_exceptions
+                    target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
+                                             raise_generate_gpu_exceptions, **gen_kwargs)
+                    thread = Thread(target=target)
+                    thread.start()
+                    outputs = ""
+                    for new_text in streamer:
+                        outputs += new_text
+                        yield prompter.get_response(outputs, prompt=inputs_decoded,
+                                                    sanitize_bot_response=sanitize_bot_response)
+                    decoded_output = outputs
+                else:
+                    outputs = model.generate(**gen_kwargs)
+                    outputs = [decoder(s) for s in outputs.sequences]
                     yield prompter.get_response(outputs, prompt=inputs_decoded,
                                                 sanitize_bot_response=sanitize_bot_response)
-                decoded_output = outputs
-            else:
-                outputs = model.generate(**gen_kwargs)
-                outputs = [decoder(s) for s in outputs.sequences]
-                yield prompter.get_response(outputs, prompt=inputs_decoded,
-                                            sanitize_bot_response=sanitize_bot_response)
-                if outputs and len(outputs) >= 1:
-                    decoded_output = prompt + outputs[0]
-                    if save_dir and decoded_output:
-                        save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
+                    if outputs and len(outputs) >= 1:
+                        decoded_output = prompt + outputs[0]
+                        if save_dir and decoded_output:
+                            save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
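Inside the lock, the code first checks whether the tokenize-decode round trip reproduces the prompt, because decoding can add a leading space or materialize special tokens. A self-contained illustration of why the diff keeps both a special-token-skipping decoder and a raw decoder_raw; the gpt2 checkpoint is just a convenient example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint

prompt = "Hello, world"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Rough analogues of the diff's decoder / decoder_raw pair:
decoded = tokenizer.decode(input_ids[0], skip_special_tokens=True)
decoded_raw = tokenizer.decode(input_ids[0], skip_special_tokens=False)

# Models that add BOS/EOS or sentencepiece spacing make one or both differ
# from the original prompt, which is why the diff repoints `prompt` or
# `decoder` to whichever form round-trips before stripping the prompt.
print(decoded == prompt, decoded_raw == prompt)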
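In the streaming branch, model.generate runs on a background Thread while the generator iterates a TextIteratorStreamer and yields progressively longer responses. A minimal version of that producer/consumer pattern, without the prompt handling and the generate_with_exceptions wrapper; the model choice and max_new_tokens value are illustrative:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
# skip_prompt=False mirrors the diff: the echoed prompt is stripped later,
# there by prompter.get_response(..., prompt=inputs_decoded).
streamer = TextIteratorStreamer(tokenizer, skip_prompt=False)

thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
thread.start()

outputs = ""
for new_text in streamer:   # blocks until generate() emits more text
    outputs += new_text     # accumulate; a caller could yield each partial
thread.join()
print(outputs)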