Spaces:
Build error
Build error
add repetition penalty
Browse files — app/app.py (+22 −4)
app/app.py
CHANGED
|
@@ -110,12 +110,16 @@ def get_generator(model_name: str):
|
|
| 110 |
# Disable the st.cache for this function due to issue on newer version of streamlit
|
| 111 |
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
|
| 112 |
def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
| 113 | -             temperature: float = 1.0, max_time: float = 120.0, seed=42):
|
| 114 |
# st.write("Cache miss: process")
|
| 115 |
set_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
| 117 |
top_k=top_k, top_p=top_p, temperature=temperature,
|
| 118 | -                             max_time=max_time)
|
| 119 |
return result
|
| 120 |
|
| 121 |
|
|
@@ -164,7 +168,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
| 164 |
"Temperature",
|
| 165 |
value=0.9,
|
| 166 |
min_value=0.0,
|
| 167 | -     max_value=
|
| 168 |
)
|
| 169 |
|
| 170 |
do_sample = st.sidebar.checkbox(
|
|
@@ -194,6 +198,20 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
| 194 |
help="The number used to initialize a pseudorandom number generator"
|
| 195 |
)
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
for group_name in MODELS:
|
| 198 |
if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
| 199 |
MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
|
|
@@ -206,7 +224,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
| 206 |
# text_generator = MODELS[model]["text_generator"]
|
| 207 |
result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
|
| 208 |
temperature=temperature, do_sample=do_sample,
|
| 209 | -                      top_k=int(top_k), top_p=float(top_p), seed=seed)
|
| 210 |
time_end = time.time()
|
| 211 |
time_diff = time_end-time_start
|
| 212 |
result = result[0]["generated_text"]
|
|
|
|
| 110 |
# Disable the st.cache for this function due to issue on newer version of streamlit
|
| 111 |
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
|
| 112 |
def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
| 113 | +             temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
|
| 114 |
# st.write("Cache miss: process")
|
| 115 |
set_seed(seed)
|
| 116 | +     if repetition_penalty == 0.0:
| 117 | +         min_penalty = 1.05
| 118 | +         max_penalty = 1.5
| 119 | +         repetition_penalty = max(min_penalty + (1.0-temperature) * (max_penalty-min_penalty), 0.8)
|
| 120 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
| 121 |
top_k=top_k, top_p=top_p, temperature=temperature,
|
| 122 | +                             max_time=max_time, repetition_penalty=repetition_penalty)
|
| 123 |
return result
|
| 124 |
|
| 125 |
|
|
|
|
| 168 |
"Temperature",
|
| 169 |
value=0.9,
|
| 170 |
min_value=0.0,
|
| 171 | +     max_value=2.0
|
| 172 |
)
|
| 173 |
|
| 174 |
do_sample = st.sidebar.checkbox(
|
|
|
|
| 198 |
help="The number used to initialize a pseudorandom number generator"
|
| 199 |
)
|
| 200 |
|
| 201 | + repetition_penalty = 0.0
| 202 | + automatic_repetition_penalty = st.sidebar.checkbox(
| 203 | +     "Automatic Repetition Penalty",
| 204 | +     value=True
| 205 | + )
| 206 | +
| 207 | + if not automatic_repetition_penalty:
| 208 | +     repetition_penalty = st.sidebar.slider(
| 209 | +         "Repetition Penalty",
| 210 | +         value=1.0,
| 211 | +         min_value=1.0,
| 212 | +         max_value=2.0
| 213 | +     )
| 214 | +
| 215 |
for group_name in MODELS:
|
| 216 |
if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
| 217 |
MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
|
|
|
|
| 224 |
# text_generator = MODELS[model]["text_generator"]
|
| 225 |
result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
|
| 226 |
temperature=temperature, do_sample=do_sample,
|
| 227 | +                      top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
|
| 228 |
time_end = time.time()
|
| 229 |
time_diff = time_end-time_start
|
| 230 |
result = result[0]["generated_text"]
|