Spaces:
Runtime error
Runtime error
Younes Belkada
commited on
Commit
·
7e700b0
1
Parent(s):
45de332
Add no repeat n gram
Browse files
app.py
CHANGED
|
@@ -35,11 +35,11 @@ def query(payload):
|
|
| 35 |
print(response)
|
| 36 |
return json.loads(response.content.decode("utf-8"))
|
| 37 |
|
| 38 |
-
def inference(input_sentence, max_length, temperature,top_k, top_p, greedy_decoding, seed=42):
|
| 39 |
top_k = None if top_k == 0 else top_k
|
| 40 |
payload = {"inputs": input_sentence,
|
| 41 |
"parameters": {"max_new_tokens": max_length, "top_k": top_k, "top_p": top_p, "temperature": temperature,
|
| 42 |
-
"do_sample": not greedy_decoding, "seed": seed}}
|
| 43 |
data = query(
|
| 44 |
payload
|
| 45 |
)
|
|
@@ -52,6 +52,7 @@ gr.Interface(
|
|
| 52 |
[
|
| 53 |
gr.inputs.Textbox(label="Input"),
|
| 54 |
gr.inputs.Slider(1, 64, default=8, label="Tokens to generate"),
|
|
|
|
| 55 |
gr.inputs.Slider(0.0, 1.0, default=0.1, step=0.05, label="Temperature"),
|
| 56 |
gr.inputs.Slider(0, 64, default=0, step=1, label="Top K"),
|
| 57 |
gr.inputs.Slider(0.0, 10, default=0.9, step=0.05, label="Top P"),
|
|
|
|
| 35 |
print(response)
|
| 36 |
return json.loads(response.content.decode("utf-8"))
|
| 37 |
|
| 38 |
+
def inference(input_sentence, no_repeat_ngram_size, max_length, temperature,top_k, top_p, greedy_decoding, seed=42):
|
| 39 |
top_k = None if top_k == 0 else top_k
|
| 40 |
payload = {"inputs": input_sentence,
|
| 41 |
"parameters": {"max_new_tokens": max_length, "top_k": top_k, "top_p": top_p, "temperature": temperature,
|
| 42 |
+
"do_sample": not greedy_decoding, "seed": seed, "early_stopping":no_repeat_ngram_size > 1, "no_repeat_ngram_size":no_repeat_ngram_size}}
|
| 43 |
data = query(
|
| 44 |
payload
|
| 45 |
)
|
|
|
|
| 52 |
[
|
| 53 |
gr.inputs.Textbox(label="Input"),
|
| 54 |
gr.inputs.Slider(1, 64, default=8, label="Tokens to generate"),
|
| 55 |
+
gr.inputs.Slider(1, 10, default=2, label="No repeat N gram"),
|
| 56 |
gr.inputs.Slider(0.0, 1.0, default=0.1, step=0.05, label="Temperature"),
|
| 57 |
gr.inputs.Slider(0, 64, default=0, step=1, label="Top K"),
|
| 58 |
gr.inputs.Slider(0.0, 10, default=0.9, step=0.05, label="Top P"),
|