un-index commited on
Commit
145938e
·
1 Parent(s): d69b53f
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -6,6 +6,12 @@ import gradio as gr
6
  import json
7
 
8
  # # from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
9
 
10
  # stage, commit, push
11
 
@@ -28,6 +34,8 @@ temperature = gr.inputs.Slider(
28
  minimum=0, maximum=1.5, default=0.8, label="temperature")
29
  top_p = gr.inputs.Slider(minimum=0, maximum=1.0,
30
  default=0.9, label="top_p")
 
 
31
 
32
  # gradio checkbutton
33
 
@@ -93,12 +101,19 @@ def get_generated_text(generated_text):
93
 
94
 
95
 
96
- def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
97
  try:
98
 
99
  if os.environ['SPACE_VERIFICATION_KEY'] != SPACE_VERIFICATION_KEY:
100
  return "invalid SPACE_VERIFICATION_KEY; see project secrets to view key"
101
 
 
 
 
 
 
 
 
102
  # maybe try "0" instead or 1, or "1"
103
  # use GPT-J-6B
104
  if model_idx == 0:
@@ -112,7 +127,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
112
  # context becomes the previous generated context
113
  # NOTE I've set return_full_text to false, see how this plays out
114
  # change max_length from max_length>250 and 250 or max_length to 250
115
- payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
116
  response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
117
  context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
118
  # context = get_generated_text(generated_context)
@@ -158,10 +173,10 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
158
  elif model_idx == 1:
159
  # use GPT-2
160
  #
161
- try:
162
- set_seed(randint(1, 2**31))
163
- except Exception as e:
164
- return "Exception while setting seed: " + str(e)
165
  # return sequences specifies how many to return
166
 
167
  # for some reson indexing with 'generated-text' doesn't work
@@ -178,7 +193,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
178
  # however in the source that's what's used
179
  # NOTE I think max_new_tokens is working now and punctuation characters count too
180
  # NOTE set max_length to max_length to allow input text of any size
181
- generated_text = generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
182
  except Exception as e:
183
  return "Exception while generating text: " + str(e)
184
  # [0][0]['generated_text']
@@ -196,12 +211,13 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
196
  # TODO use fallback gpt-2 inference api for this as well
197
  # TODO or just make it an option in the menu "GPT-2 inference"
198
  elif model_idx == 2:
 
199
  url = "https://api-inference.huggingface.co/models/distilgpt2"
200
  generated_text = ""#context #""
201
  # NOTE adding repetition penalty parameter
202
  # NOTE maybe leave tha parameter and just write a function to remove repetitions
203
  while len(generated_text) < max_length:
204
- payload = {"inputs": context, "parameters": {"repetition_penalty":20.0,"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
205
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
206
  context = json.loads(response.content.decode("utf-8"))
207
  context = get_generated_text(context).strip()
@@ -219,7 +235,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
219
 
220
  generated_text = ""#context #""
221
  while len(generated_text) < max_length:
222
- payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
223
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
224
  context = json.loads(response.content.decode("utf-8"))
225
  context = get_generated_text(context).strip()
@@ -231,7 +247,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
231
  generated_text = ""#context #""
232
  # NOTE we're actually using max_new_tokens and min_new_tokens
233
  while len(generated_text) < max_length:
234
- payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
235
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
236
  context = json.loads(response.content.decode("utf-8"))
237
  context = get_generated_text(context).strip()
@@ -247,6 +263,7 @@ iface = gr.Interface(f, [
247
  "text",
248
  temperature,
249
  top_p,
 
250
  gr.inputs.Slider(
251
  minimum=20, maximum=512, default=30, label="max length"),
252
  gr.inputs.Dropdown(["GPT-J-6B", "GPT2", "DistilGPT2", "GPT-Large", "GPT-Neo-2.7B"], type="index", label="model", default="GPT2"),
 
6
  import json
7
 
8
  # # from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ def get():
10
+ pass
11
+ def get():
12
+ pass;
13
+
14
+
15
 
16
  # stage, commit, push
17
 
 
34
  minimum=0, maximum=1.5, default=0.8, label="temperature")
35
  top_p = gr.inputs.Slider(minimum=0, maximum=1.0,
36
  default=0.9, label="top_p")
37
+ top_k = gr.inputs.Slider(minimum=0, maximum=100,
38
+ default=40, label="top_p")
39
 
40
  # gradio checkbutton
41
 
 
101
 
102
 
103
 
104
+ def f(context, temperature, top_p, top_k, max_length, model_idx, SPACE_VERIFICATION_KEY):
105
  try:
106
 
107
  if os.environ['SPACE_VERIFICATION_KEY'] != SPACE_VERIFICATION_KEY:
108
  return "invalid SPACE_VERIFICATION_KEY; see project secrets to view key"
109
 
110
+ try:
111
+ set_seed(randint(1, 256))
112
+ except Exception as e:
113
+ return "Exception while setting seed: " + str(e)
114
+
115
+ top_k = (top_k==0 and None) or top_k
116
+ # TODO write a function to generate the payload, it's becoming repetitive
117
  # maybe try "0" instead or 1, or "1"
118
  # use GPT-J-6B
119
  if model_idx == 0:
 
127
  # context becomes the previous generated context
128
  # NOTE I've set return_full_text to false, see how this plays out
129
  # change max_length from max_length>250 and 250 or max_length to 250
130
+ payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p, "top_k": top_k}}
131
  response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
132
  context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
133
  # context = get_generated_text(generated_context)
 
173
  elif model_idx == 1:
174
  # use GPT-2
175
  #
176
+ # try:
177
+ # set_seed(randint(1, 2**31))
178
+ # except Exception as e:
179
+ # return "Exception while setting seed: " + str(e)
180
  # return sequences specifies how many to return
181
 
182
  # for some reson indexing with 'generated-text' doesn't work
 
193
  # however in the source that's what's used
194
  # NOTE I think max_new_tokens is working now and punctuation characters count too
195
  # NOTE set max_length to max_length to allow input text of any size
196
+ generated_text = generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, top_k=top_k, temperature=temperature, num_return_sequences=1)
197
  except Exception as e:
198
  return "Exception while generating text: " + str(e)
199
  # [0][0]['generated_text']
 
211
  # TODO use fallback gpt-2 inference api for this as well
212
  # TODO or just make it an option in the menu "GPT-2 inference"
213
  elif model_idx == 2:
214
+
215
  url = "https://api-inference.huggingface.co/models/distilgpt2"
216
  generated_text = ""#context #""
217
  # NOTE adding repetition penalty parameter
218
  # NOTE maybe leave tha parameter and just write a function to remove repetitions
219
  while len(generated_text) < max_length:
220
+ payload = {"inputs": context, "parameters": {"repetition_penalty":20.0,"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p, "top_k": top_k}}
221
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
222
  context = json.loads(response.content.decode("utf-8"))
223
  context = get_generated_text(context).strip()
 
235
 
236
  generated_text = ""#context #""
237
  while len(generated_text) < max_length:
238
+ payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p, "top_k": top_k}}
239
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
240
  context = json.loads(response.content.decode("utf-8"))
241
  context = get_generated_text(context).strip()
 
247
  generated_text = ""#context #""
248
  # NOTE we're actually using max_new_tokens and min_new_tokens
249
  while len(generated_text) < max_length:
250
+ payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p, "top_k": top_k}}
251
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
252
  context = json.loads(response.content.decode("utf-8"))
253
  context = get_generated_text(context).strip()
 
263
  "text",
264
  temperature,
265
  top_p,
266
+ top_k,
267
  gr.inputs.Slider(
268
  minimum=20, maximum=512, default=30, label="max length"),
269
  gr.inputs.Dropdown(["GPT-J-6B", "GPT2", "DistilGPT2", "GPT-Large", "GPT-Neo-2.7B"], type="index", label="model", default="GPT2"),