Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,6 +21,7 @@ def generate(
|
|
| 21 |
temperature,
|
| 22 |
top_p,
|
| 23 |
top_k,
|
|
|
|
| 24 |
seed,
|
| 25 |
model_path="roborovski/superprompt-v1",
|
| 26 |
dtype="fp16",
|
|
@@ -40,7 +41,9 @@ def generate(
|
|
| 40 |
input_text = f"{prompt}, {history}"
|
| 41 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
| 44 |
outputs = model.generate(
|
| 45 |
input_ids,
|
| 46 |
max_new_tokens=max_new_tokens,
|
|
@@ -100,6 +103,11 @@ additional_inputs = [
|
|
| 100 |
label="Top K",
|
| 101 |
info="Higher k means more diverse outputs by considering a range of tokens",
|
| 102 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
gr.Number(
|
| 104 |
value=42,
|
| 105 |
interactive=True,
|
|
@@ -123,6 +131,7 @@ examples = [
|
|
| 123 |
None,
|
| 124 |
None,
|
| 125 |
None,
|
|
|
|
| 126 |
None,
|
| 127 |
"roborovski/superprompt-v1",
|
| 128 |
"fp16",
|
|
|
|
| 21 |
temperature,
|
| 22 |
top_p,
|
| 23 |
top_k,
|
| 24 |
+
seed_checkbox,
|
| 25 |
seed,
|
| 26 |
model_path="roborovski/superprompt-v1",
|
| 27 |
dtype="fp16",
|
|
|
|
| 41 |
input_text = f"{prompt}, {history}"
|
| 42 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
| 43 |
|
| 44 |
+
if seed_checkbox:
|
| 45 |
+
torch.manual_seed(seed)
|
| 46 |
+
|
| 47 |
outputs = model.generate(
|
| 48 |
input_ids,
|
| 49 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 103 |
label="Top K",
|
| 104 |
info="Higher k means more diverse outputs by considering a range of tokens",
|
| 105 |
),
|
| 106 |
+
gr.Checkbox(
|
| 107 |
+
value=False,
|
| 108 |
+
label="Use Random Seed",
|
| 109 |
+
info="Check to use a random seed for the generation process",
|
| 110 |
+
),
|
| 111 |
gr.Number(
|
| 112 |
value=42,
|
| 113 |
interactive=True,
|
|
|
|
| 131 |
None,
|
| 132 |
None,
|
| 133 |
None,
|
| 134 |
+
False,
|
| 135 |
None,
|
| 136 |
"roborovski/superprompt-v1",
|
| 137 |
"fp16",
|