un-index committed on
Commit
3346e5e
·
1 Parent(s): 0c91274
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -33,6 +33,8 @@ top_p = gr.inputs.Slider(minimum=0, maximum=1.0,
33
 
34
  generator = pipeline('text-generation', model='gpt2')
35
 
 
 
36
 
37
  title = "GPT-J-6B"
38
 
@@ -99,24 +101,30 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
99
  # maybe try "0" instead or 1, or "1"
100
  # use GPT-J-6B
101
  if model_idx == 0:
102
- if main_gpt_j_api_up:
103
- # for this api, a length of > 250 instantly errors, so use a while loop or something
104
- # that would fetch results in chunks of 250
105
- # NOTE change so it uses previous generated input every time
106
- generated_text = context #""
107
- while (max_length > 0):
108
- payload = {"inputs": generated_text, "parameters": {"max_new_tokens": max_length>250 and 250 or max_length, "temperature": temperature, "top_p": top_p}}
109
- response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
110
- context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
111
- # context = get_generated_text(generated_context)
112
- # handle inconsistent inference API
113
- if 'generated_text' in context[0]:
114
- context = context[0]['generated_text']
115
- else:
116
- context = context[0][0]['generated_text']
 
 
 
 
 
 
117
 
118
- generated_text += context
119
- max_length -= 250
120
 
121
  # payload = {"inputs": context, "parameters":{
122
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
 
33
 
34
  generator = pipeline('text-generation', model='gpt2')
35
 
36
+ gpt_j_generator = pipeline('text-generation', model='GPT-J 6B')
37
+
38
 
39
  title = "GPT-J-6B"
40
 
 
101
  # maybe try "0" instead or 1, or "1"
102
  # use GPT-J-6B
103
  if model_idx == 0:
104
+ # just use regular pipeline models man leave APIs
105
+
106
+ set_seed(2**31)
107
+ generated_text = gpt_j_generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
108
+
109
+ return get_generated_text(generated_text)
110
+ # if main_gpt_j_api_up:
111
+ # # for this api, a length of > 250 instantly errors, so use a while loop or something
112
+ # # that would fetch results in chunks of 250
113
+ # # NOTE change so it uses previous generated input every time
114
+ # generated_text = context #""
115
+ # while (max_length > 0):
116
+ # payload = {"inputs": generated_text, "parameters": {"max_new_tokens": max_length>250 and 250 or max_length, "temperature": temperature, "top_p": top_p}}
117
+ # response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
118
+ # context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
119
+ # # context = get_generated_text(generated_context)
120
+ # # handle inconsistent inference API
121
+ # if 'generated_text' in context[0]:
122
+ # context = context[0]['generated_text']
123
+ # else:
124
+ # context = context[0][0]['generated_text']
125
 
126
+ # generated_text += context
127
+ # max_length -= 250
128
 
129
  # payload = {"inputs": context, "parameters":{
130
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}