un-index committed on
Commit
9bf563d
·
1 Parent(s): 9ba6add
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -194,19 +194,38 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
194
 
195
  # TODO use fallback gpt-2 inference api for this as well
196
  # TODO or just make it an option in the menu "GPT-2 inference"
197
- else:
198
  DISTIL_GPT2_API_URL = "https://api-inference.huggingface.co/models/distilgpt2"
199
- generated_text=""
200
- while (max_length > 0):
201
- # NOTE see original implementation above for gpt-J-6B
202
- payload = {"inputs": context, "parameters": {"max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  response = requests.request("POST", DISTIL_GPT2_API_URL, data=json.dumps(payload), headers=headers)
204
- context = json.loads(response.content.decode("utf-8"))
 
 
 
 
 
 
 
205
 
206
- context = get_generated_text(context)
207
 
208
  generated_text += context
209
- max_length -= 250
210
 
211
  # payload = {"inputs": context, "parameters":{
212
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
@@ -214,6 +233,19 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
214
  # response = requests.request("POST", API_URL, data=data, headers=headers)
215
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
216
  return generated_text#context #_context+generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  except Exception as e:
219
  return f"error with idx{model_idx}: "+str(e)
 
194
 
195
  # TODO use fallback gpt-2 inference api for this as well
196
  # TODO or just make it an option in the menu "GPT-2 inference"
197
+ elif model_idx == 2:
198
  DISTIL_GPT2_API_URL = "https://api-inference.huggingface.co/models/distilgpt2"
199
+ # generated_text=""
200
+ # while (max_length > 0):
201
+ # # NOTE see original implementation above for gpt-J-6B
202
+ # payload = {"inputs": context, "parameters": {"max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
203
+ # response = requests.request("POST", DISTIL_GPT2_API_URL, data=json.dumps(payload), headers=headers)
204
+ # context = json.loads(response.content.decode("utf-8"))
205
+
206
+ # context = get_generated_text(context)
207
+
208
+ # generated_text += context
209
+ # max_length -= 250
210
+ generated_text = ""#context #""
211
+ while len(generated_text) < max_length:#(max_length > 0): NOTE NOTE commented out this line and added new check
212
+ # context becomes the previous generated context
213
+ # NOTE I've set return_full_text to false, see how this plays out
214
+ # change max_length from max_length>250 and 250 or max_length to 250
215
+ payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
216
  response = requests.request("POST", DISTIL_GPT2_API_URL, data=json.dumps(payload), headers=headers)
217
+ context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
218
+ # context = get_generated_text(generated_context)
219
+
220
+ # handle inconsistent inference API
221
+ # if 'generated_text' in context[0]:
222
+ # context = context[0]['generated_text']
223
+ # else:
224
+ # context = context[0][0]['generated_text']
225
 
226
+ context = get_generated_text(context).strip()
227
 
228
  generated_text += context
 
229
 
230
  # payload = {"inputs": context, "parameters":{
231
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
 
233
  # response = requests.request("POST", API_URL, data=data, headers=headers)
234
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
235
  return generated_text#context #_context+generated_text
236
+ else:
237
+ url = "https://api-inference.huggingface.co/models/gpt2-large"
238
+
239
+ generated_text = ""#context #""
240
+ while len(generated_text) < max_length:
241
+ payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
242
+ response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
243
+ context = json.loads(response.content.decode("utf-8"))
244
+ context = get_generated_text(context).strip()
245
+
246
+ generated_text += context
247
+ return generated_text
248
+
249
 
250
  except Exception as e:
251
  return f"error with idx{model_idx}: "+str(e)