Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,6 +46,7 @@ STYLE = """
|
|
| 46 |
height: auto;
|
| 47 |
text-align: center;
|
| 48 |
display:inline-block;
|
|
|
|
| 49 |
}
|
| 50 |
#root {
|
| 51 |
display: inline-grid!important;
|
|
@@ -417,6 +418,7 @@ def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequ
|
|
| 417 |
|
| 418 |
return original_tree
|
| 419 |
|
|
|
|
| 420 |
@spaces.GPU
|
| 421 |
def get_beam_search_html(
|
| 422 |
input_text, number_steps, number_beams, length_penalty, num_return_sequences
|
|
@@ -441,9 +443,12 @@ def get_beam_search_html(
|
|
| 441 |
# Sequences are padded anyway so you can batch decode them
|
| 442 |
decoded_sequences = tokenizer.batch_decode(outputs.sequences)
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
if number_beams > 1:
|
| 449 |
original_tree = generate_beams(
|
|
@@ -455,9 +460,9 @@ def get_beam_search_html(
|
|
| 455 |
)
|
| 456 |
else:
|
| 457 |
original_tree = generate_beams(
|
| 458 |
-
|
| 459 |
input_text,
|
| 460 |
-
outputs.
|
| 461 |
0,
|
| 462 |
decoded_sequences,
|
| 463 |
)
|
|
|
|
| 46 |
height: auto;
|
| 47 |
text-align: center;
|
| 48 |
display:inline-block;
|
| 49 |
+
padding-bottom: 10px!important;
|
| 50 |
}
|
| 51 |
#root {
|
| 52 |
display: inline-grid!important;
|
|
|
|
| 418 |
|
| 419 |
return original_tree
|
| 420 |
|
| 421 |
+
|
| 422 |
@spaces.GPU
|
| 423 |
def get_beam_search_html(
|
| 424 |
input_text, number_steps, number_beams, length_penalty, num_return_sequences
|
|
|
|
| 443 |
# Sequences are padded anyway so you can batch decode them
|
| 444 |
decoded_sequences = tokenizer.batch_decode(outputs.sequences)
|
| 445 |
|
| 446 |
+
if number_beams > 1:
|
| 447 |
+
for i, sequence in enumerate(decoded_sequences):
|
| 448 |
+
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
| 449 |
+
else:
|
| 450 |
+
markdown += f"\n- `{clean(decoded_sequences[0].replace('<s> ', ''))}`"
|
| 451 |
+
print(outputs.logits)
|
| 452 |
|
| 453 |
if number_beams > 1:
|
| 454 |
original_tree = generate_beams(
|
|
|
|
| 460 |
)
|
| 461 |
else:
|
| 462 |
original_tree = generate_beams(
|
| 463 |
+
number_beams,
|
| 464 |
input_text,
|
| 465 |
+
outputs.scores,
|
| 466 |
0,
|
| 467 |
decoded_sequences,
|
| 468 |
)
|