Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -34,7 +34,15 @@ examples=[
 ]
 
 
-def predict(message, chatbot):
+# Stream text
+def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
+
+    if system_prompt != "":
+        system_message = system_prompt
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
 
     input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
     for interaction in chatbot:
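Two details in this hunk are easy to miss: when the optional system prompt is empty, the handlers fall back to a module-level system_message defined earlier in app.py (outside this diff), and temperature is floored at 1e-2 because sampling with a temperature of zero is rejected by generation backends. A minimal sketch of that sanitization as a standalone helper; the helper name and the default_system_message argument are illustrative, not part of the commit:

def resolve_generation_args(system_prompt, temperature, top_p, default_system_message):
    # Illustrative helper mirroring the checks in predict/predict_batch above.
    system_message = system_prompt if system_prompt != "" else default_system_message
    temperature = max(float(temperature), 1e-2)  # keep strictly positive for do_sample=True
    top_p = float(top_p)
    return system_message, temperature, top_p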
@@ -44,12 +52,14 @@ def predict(message, chatbot):
 
     data = {
         "inputs": input_prompt,
-        "parameters": {
-            (five lines of the old parameters block, not recoverable from this view)
+        "parameters": {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "repetition_penalty": repetition_penalty,
+            "do_sample": True,
+        },
+    }
     response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)
 
     partial_message = ""
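The loop that consumes this streamed response (file lines 65-92) falls between the hunks and is not shown. As a sketch only, assuming the endpoint is a text-generation-inference server emitting server-sent events whose data: payload carries a token.text field, consumption typically looks like:

import json
import requests

def stream_reply(api_url, headers, data, hf_token):
    # Sketch: consume a text-generation-inference-style SSE stream (assumed event format).
    response = requests.post(api_url, headers=headers, data=json.dumps(data),
                             auth=("hf", hf_token), stream=True)
    partial_message = ""
    for line in response.iter_lines():
        if not line:
            continue
        decoded = line.decode("utf-8")
        if decoded.startswith("data:"):
            payload = json.loads(decoded[len("data:"):])
            token = payload.get("token", {})
            if not token.get("special", False):  # skip </s> and other control tokens
                partial_message += token.get("text", "")
                yield partial_message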
@@ -84,8 +94,16 @@ def predict(message, chatbot):
             continue
 
 
-def predict_batch(message, chatbot):
+# No Stream
+def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
 
+    if system_prompt != "":
+        system_message = system_prompt
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
     input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
     for interaction in chatbot:
         input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
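Both handlers assemble the Llama 2 chat template by plain concatenation: each past (user, assistant) pair is closed with [/INST] ... </s><s> [INST], and the new message is presumably appended after the loop (file lines 109-110, not visible here). For a one-turn history the assembled prompt looks like this; all values are illustrative:

system_message = "You are a helpful assistant."
chatbot = [("What is Gradio?", "A Python library for building ML demos.")]
message = "Does it support streaming?"

input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
for interaction in chatbot:
    input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
input_prompt = input_prompt + str(message) + " [/INST] "  # assumed final step, not visible in the diff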
@@ -94,7 +112,13 @@ def predict_batch(message, chatbot):
 
     data = {
         "inputs": input_prompt,
-        "parameters": {
+        "parameters": {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "repetition_penalty": repetition_penalty,
+            "do_sample": True,
+        },
     }
 
     response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))
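Because this request is made without stream=True, the whole completion arrives in one JSON body. A sketch of handling it, assuming the usual Inference API response shape [{"generated_text": ...}]; the actual parsing lines sit between the hunks and are not shown:

if response.status_code == 200:
    generated = response.json()[0]["generated_text"]  # assumed response shape
    # Servers often echo the prompt; strip it so only the new reply is returned.
    if generated.startswith(input_prompt):
        generated = generated[len(input_prompt):]
    return generated
else:
    print(f"Request failed with status code {response.status_code}")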
@@ -114,13 +138,55 @@ def predict_batch(message, chatbot):
         print(f"Request failed with status code {response.status_code}")
 
 
+
+additional_inputs = [
+    gr.Textbox("", label="Optional system prompt"),
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=4096,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.6,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    ),
+]
+
+
 # Gradio Demo
 with gr.Blocks() as demo:
 
     with gr.Tab("Streaming"):
-        gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True)
+        gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs)
 
     with gr.Tab("Batch"):
-        gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True)
+        gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs)
 
 demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
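gr.ChatInterface passes the values of additional_inputs to the handler positionally, after (message, chatbot), which is why both handlers declare system_prompt, temperature, max_new_tokens, top_p, and repetition_penalty in exactly that order. A minimal self-contained sketch of the same wiring, with an illustrative echo handler; note that queue(concurrency_count=...) is the Gradio 3.x signature this Space uses:

import gradio as gr

def echo(message, chatbot, system_prompt="", temperature=0.9):
    # additional_inputs values arrive here positionally, after (message, chatbot).
    return f"(T={temperature}) {system_prompt} {message}"

demo = gr.ChatInterface(
    echo,
    additional_inputs=[
        gr.Textbox("", label="Optional system prompt"),
        gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05),
    ],
)
demo.queue(concurrency_count=75, max_size=100).launch()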