Spaces:
Runtime error
Runtime error
Ubuntu
commited on
Commit
·
c5556d8
1
Parent(s):
fa25719
update
Browse files
app.py
CHANGED
|
@@ -17,6 +17,7 @@ def infer(
|
|
| 17 |
top_p=1.0,
|
| 18 |
top_k=40,
|
| 19 |
num_completions=1,
|
|
|
|
| 20 |
seed=42,
|
| 21 |
stop="\n"
|
| 22 |
):
|
|
@@ -28,6 +29,7 @@ def infer(
|
|
| 28 |
temperature = float(temperature)
|
| 29 |
top_p = float(top_p)
|
| 30 |
top_k = int(top_k)
|
|
|
|
| 31 |
stop = stop.split(";")
|
| 32 |
seed = seed
|
| 33 |
|
|
@@ -36,6 +38,7 @@ def infer(
|
|
| 36 |
assert 0.0 <= temperature <= 10.0
|
| 37 |
assert 0.0 <= top_p <= 1.0
|
| 38 |
assert 1 <= top_k <= 1000
|
|
|
|
| 39 |
|
| 40 |
if temperature == 0.0:
|
| 41 |
temperature = 0.01
|
|
@@ -48,6 +51,7 @@ def infer(
|
|
| 48 |
"top_k": top_k,
|
| 49 |
"temperature": temperature,
|
| 50 |
"max_tokens": max_new_tokens,
|
|
|
|
| 51 |
"stop": stop,
|
| 52 |
}
|
| 53 |
print(f"send: {datetime.now()}")
|
|
@@ -223,6 +227,7 @@ def main():
|
|
| 223 |
if 'preset' not in st.session_state:
|
| 224 |
st.session_state.preset = "Sentiment Analysis"
|
| 225 |
st.session_state.top_k = "40"
|
|
|
|
| 226 |
st.session_state.stop = r'\n'
|
| 227 |
set_preset()
|
| 228 |
|
|
@@ -252,6 +257,7 @@ def main():
|
|
| 252 |
top_p = st.text_input('top_p', st.session_state.top_p)
|
| 253 |
# num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
|
| 254 |
num_completions = "1"
|
|
|
|
| 255 |
stop = st.text_input('stop, split by;', st.session_state.stop)
|
| 256 |
# seed = st.text_input('seed', "42")
|
| 257 |
seed = "42"
|
|
@@ -275,7 +281,8 @@ def main():
|
|
| 275 |
generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
|
| 276 |
report_text = infer(
|
| 277 |
prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
|
| 278 |
-
num_completions=num_completions,
|
|
|
|
| 279 |
)
|
| 280 |
generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
|
| 281 |
|
|
|
|
| 17 |
top_p=1.0,
|
| 18 |
top_k=40,
|
| 19 |
num_completions=1,
|
| 20 |
+
repetition_penalty=1.0,
|
| 21 |
seed=42,
|
| 22 |
stop="\n"
|
| 23 |
):
|
|
|
|
| 29 |
temperature = float(temperature)
|
| 30 |
top_p = float(top_p)
|
| 31 |
top_k = int(top_k)
|
| 32 |
+
repetition_penalty = float(repetition_penalty)
|
| 33 |
stop = stop.split(";")
|
| 34 |
seed = seed
|
| 35 |
|
|
|
|
| 38 |
assert 0.0 <= temperature <= 10.0
|
| 39 |
assert 0.0 <= top_p <= 1.0
|
| 40 |
assert 1 <= top_k <= 1000
|
| 41 |
+
assert 0.9 <= repetition_penalty <= 3.0
|
| 42 |
|
| 43 |
if temperature == 0.0:
|
| 44 |
temperature = 0.01
|
|
|
|
| 51 |
"top_k": top_k,
|
| 52 |
"temperature": temperature,
|
| 53 |
"max_tokens": max_new_tokens,
|
| 54 |
+
"repetition_penalty": repetition_penalty,
|
| 55 |
"stop": stop,
|
| 56 |
}
|
| 57 |
print(f"send: {datetime.now()}")
|
|
|
|
| 227 |
if 'preset' not in st.session_state:
|
| 228 |
st.session_state.preset = "Sentiment Analysis"
|
| 229 |
st.session_state.top_k = "40"
|
| 230 |
+
st.session_state.repetition_penalty = "1.0"
|
| 231 |
st.session_state.stop = r'\n'
|
| 232 |
set_preset()
|
| 233 |
|
|
|
|
| 257 |
top_p = st.text_input('top_p', st.session_state.top_p)
|
| 258 |
# num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
|
| 259 |
num_completions = "1"
|
| 260 |
+
repetition_penalty = st.text_input('repetition_penalty', st.session_state.repetition_penalty)
|
| 261 |
stop = st.text_input('stop, split by;', st.session_state.stop)
|
| 262 |
# seed = st.text_input('seed', "42")
|
| 263 |
seed = "42"
|
|
|
|
| 281 |
generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
|
| 282 |
report_text = infer(
|
| 283 |
prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
|
| 284 |
+
num_completions=num_completions, repetition_penalty=repetition_penalty,
|
| 285 |
+
seed=seed, stop=literal_eval("'''"+stop+"'''"),
|
| 286 |
)
|
| 287 |
generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
|
| 288 |
|