Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -272,7 +272,7 @@ class BeamNode:
|
|
| 272 |
is_selected_sequence: bool
|
| 273 |
|
| 274 |
|
| 275 |
-
def generate_beams(start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
|
| 276 |
original_tree = BeamNode(
|
| 277 |
cumulative_score=0,
|
| 278 |
current_token_ix=None,
|
|
@@ -284,7 +284,6 @@ def generate_beams(start_sentence, scores, length_penalty, decoded_sequences, be
|
|
| 284 |
is_final=False,
|
| 285 |
is_selected_sequence=False,
|
| 286 |
)
|
| 287 |
-
n_beams = len(scores[0])
|
| 288 |
beam_trees = [original_tree] * n_beams
|
| 289 |
generation_length = len(scores)
|
| 290 |
|
|
@@ -429,7 +428,7 @@ def get_beam_search_html(
|
|
| 429 |
outputs = model.generate(
|
| 430 |
**inputs,
|
| 431 |
max_new_tokens=number_steps,
|
| 432 |
-
num_beams=number_beams,
|
| 433 |
num_return_sequences=num_return_sequences,
|
| 434 |
return_dict_in_generate=True,
|
| 435 |
length_penalty=length_penalty,
|
|
@@ -447,6 +446,7 @@ def get_beam_search_html(
|
|
| 447 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
| 448 |
|
| 449 |
original_tree = generate_beams(
|
|
|
|
| 450 |
input_text,
|
| 451 |
outputs.scores[:],
|
| 452 |
length_penalty,
|
|
@@ -493,7 +493,7 @@ This parameter will not impact the beam search paths, but only influence the cho
|
|
| 493 |
label="Number of steps", minimum=1, maximum=12, step=1, value=5
|
| 494 |
)
|
| 495 |
n_beams = gr.Slider(
|
| 496 |
-
label="Number of beams", minimum=
|
| 497 |
)
|
| 498 |
length_penalty = gr.Slider(
|
| 499 |
label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
|
|
|
|
| 272 |
is_selected_sequence: bool
|
| 273 |
|
| 274 |
|
| 275 |
+
def generate_beams(n_beams, start_sentence, scores, length_penalty, decoded_sequences, beam_indexes_source):
|
| 276 |
original_tree = BeamNode(
|
| 277 |
cumulative_score=0,
|
| 278 |
current_token_ix=None,
|
|
|
|
| 284 |
is_final=False,
|
| 285 |
is_selected_sequence=False,
|
| 286 |
)
|
|
|
|
| 287 |
beam_trees = [original_tree] * n_beams
|
| 288 |
generation_length = len(scores)
|
| 289 |
|
|
|
|
| 428 |
outputs = model.generate(
|
| 429 |
**inputs,
|
| 430 |
max_new_tokens=number_steps,
|
| 431 |
+
num_beams=max(number_beams, 2),
|
| 432 |
num_return_sequences=num_return_sequences,
|
| 433 |
return_dict_in_generate=True,
|
| 434 |
length_penalty=length_penalty,
|
|
|
|
| 446 |
markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence.replace('<s> ', ''))}`"
|
| 447 |
|
| 448 |
original_tree = generate_beams(
|
| 449 |
+
number_beams,
|
| 450 |
input_text,
|
| 451 |
outputs.scores[:],
|
| 452 |
length_penalty,
|
|
|
|
| 493 |
label="Number of steps", minimum=1, maximum=12, step=1, value=5
|
| 494 |
)
|
| 495 |
n_beams = gr.Slider(
|
| 496 |
+
label="Number of beams", minimum=1, maximum=4, step=1, value=4
|
| 497 |
)
|
| 498 |
length_penalty = gr.Slider(
|
| 499 |
label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
|