Update app.py
Browse files
app.py
CHANGED
|
@@ -202,24 +202,26 @@ elif selected_model == "Генерация текста GPT-моделью по
|
|
| 202 |
model = GPT2LMHeadModel.from_pretrained(path).to(device)
|
| 203 |
tokenizer = GPT2Tokenizer.from_pretrained(path)
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
|
| 209 |
-
model.eval()
|
| 210 |
-
with torch.no_grad():
|
| 211 |
-
out = model.generate(
|
| 212 |
-
input_ids,
|
| 213 |
-
do_sample=True,
|
| 214 |
-
num_beams=2,
|
| 215 |
-
temperature=1.1,
|
| 216 |
-
top_p=0.9,
|
| 217 |
-
max_length=50,
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
| 221 |
-
end_time = time.time()
|
| 222 |
-
prediction_time = end_time - start_time
|
| 223 |
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
st.write(f'{generated_text}')
|
|
|
|
# --- GPT-2 horoscope generation branch (final "after" state of this hunk) ---
# `path`, `device`, and `user_text_input` are defined earlier in app.py,
# outside this hunk — presumably a model directory, a torch device, and the
# user's prompt text; verify against the surrounding file.
model = GPT2LMHeadModel.from_pretrained(path).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(path)

# Generation hyper-parameters exposed in the UI.
temperature = st.slider('Temperature', 0.1, 2.0, 1.1, step=0.1)
max_gen_length = st.slider('Максимальная длина генерации', 10, 500, 100, step=10)
# NOTE(review): this value is passed to `num_beams`, which sets the beam
# width, not how many texts come back — the label suggests
# `num_return_sequences` was intended. Confirm before renaming.
num_generations = st.slider('Количество генераций', 1, 10, 2, step=1)

if st.button('Сделать гороскоп'):
    with st.spinner('Генерация текста...'):
        start_time = time.time()
        input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
        model.eval()  # inference mode: disable dropout etc.
        with torch.no_grad():  # no gradients needed for generation
            out = model.generate(
                input_ids,
                do_sample=True,
                num_beams=num_generations,
                temperature=temperature,
                top_p=0.9,
                max_length=max_gen_length,
            )
        generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
        end_time = time.time()
        # NOTE(review): measured but never displayed — either show it
        # (e.g. st.write(f'{prediction_time:.2f} c')) or drop the timing.
        prediction_time = end_time - start_time

    # BUG FIX: the diff left this call at the branch's base indentation as an
    # unchanged context line. Outside the button branch it runs on every
    # Streamlit rerun and raises NameError before `generated_text` exists;
    # it must stay inside `if st.button(...)`.
    st.write(f'{generated_text}')