un-index
Commit 02037ab · 1 Parent(s): 725503b
Files changed (1): app.py (+6 -27)

app.py CHANGED
@@ -195,38 +195,16 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
     # TODO use fallback gpt-2 inference api for this as well
     # TODO or just make it an option in the menu "GPT-2 inference"
     elif model_idx == 2:
-        DISTIL_GPT2_API_URL = "https://api-inference.huggingface.co/models/distilgpt2"
-        # generated_text=""
-        # while (max_length > 0):
-        #     # NOTE see original implementation above for gpt-J-6B
-        #     payload = {"inputs": context, "parameters": {"max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
-        #     response = requests.request("POST", DISTIL_GPT2_API_URL, data=json.dumps(payload), headers=headers)
-        #     context = json.loads(response.content.decode("utf-8"))
-
-        #     context = get_generated_text(context)
-
-        #     generated_text += context
-        #     max_length -= 250
+        url = "https://api-inference.huggingface.co/models/distilgpt2"
         generated_text = ""#context #""
-        while len(generated_text) < max_length:#(max_length > 0): NOTE commented out this line and added new check
-            # context becomes the previous generated context
-            # NOTE I've set return_full_text to false, see how this plays out
-            # change max_length from max_length>250 and 250 or max_length to 250
+        while len(generated_text) < max_length:
             payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
-            response = requests.request("POST", DISTIL_GPT2_API_URL, data=json.dumps(payload), headers=headers)
-            context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
-            # context = get_generated_text(generated_context)
-
-            # handle inconsistent inference API
-            # if 'generated_text' in context[0]:
-            #     context = context[0]['generated_text']
-            # else:
-            #     context = context[0][0]['generated_text']
-
+            response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
+            context = json.loads(response.content.decode("utf-8"))
             context = get_generated_text(context).strip()

             generated_text += context
-
+        return generated_text
         # payload = {"inputs": context, "parameters":{
         #     "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
         # data = json.dumps(payload)
@@ -248,6 +226,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
     else:
         url = "https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B"
         generated_text = ""#context #""
+        # NOTE we're actually using max_new_tokens and min_new_tokens
         while len(generated_text) < max_length:
             payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
             response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
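
For reference, below is a minimal standalone sketch of what the refactored distilgpt2 branch does after this commit. The headers dict and the get_generated_text helper are defined elsewhere in app.py and do not appear in this diff, so the versions here are assumptions: the auth header uses a placeholder token, and the helper is reconstructed from the removed "handle inconsistent inference API" comments, which unwrapped both [{"generated_text": ...}] and [[{"generated_text": ...}]] response shapes.

    import json
    import requests

    API_URL = "https://api-inference.huggingface.co/models/distilgpt2"
    headers = {"Authorization": "Bearer <HF_API_TOKEN>"}  # assumption: placeholder for the app's real auth headers

    def get_generated_text(response_json):
        # Assumption: reconstructed from the removed "handle inconsistent
        # inference API" comments, which handled both response shapes.
        first = response_json[0]
        if isinstance(first, list):  # [[{"generated_text": ...}]] shape
            first = first[0]
        return first["generated_text"]

    def generate(context, temperature, top_p, max_length):
        generated_text = ""
        # Request up to 250 new tokens per call until the accumulated text
        # reaches max_length; return_full_text=False makes the API return only
        # the newly generated continuation, not the prompt.
        while len(generated_text) < max_length:
            payload = {"inputs": context, "parameters": {"return_full_text": False,
                       "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
            response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
            # Each new chunk becomes the next prompt ("context becomes the
            # previous generated context" in the original comments).
            context = get_generated_text(json.loads(response.content.decode("utf-8"))).strip()
            generated_text += context
        return generated_text

Note that len(generated_text) counts characters, not tokens, so max_length acts here as an approximate character budget rather than a token limit.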