un-index commited on
Commit
117e325
·
1 Parent(s): 037c712
Files changed (1) hide show
  1. app.py +36 -58
app.py CHANGED
@@ -32,20 +32,6 @@ top_p = gr.inputs.Slider(minimum=0, maximum=1.0,
32
  # gradio checkbutton
33
 
34
  generator = pipeline('text-generation', model='gpt2')
35
- from transformers import AutoModelForCausalLM, AutoTokenizer
36
- model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
37
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
38
-
39
- # prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
40
- # "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
41
- # "researchers was the fact that the unicorns spoke perfect English."
42
-
43
- # input_ids = tokenizer(prompt, return_tensors="pt").input_ids
44
-
45
- # gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,)
46
- # gen_text = tokenizer.batch_decode(gen_tokens)[0]
47
-
48
- # gpt_j_generator = pipeline(model='EleutherAI/gpt-j-6B')
49
 
50
 
51
  title = "GPT-J-6B"
@@ -98,10 +84,12 @@ def get_generated_text(generated_text):
98
  except:
99
  # recursively loop through generated_text till we get the text
100
  # don't know if this will work
101
-
102
- # for i in
103
-
104
- return generated_text
 
 
105
 
106
 
107
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
@@ -113,62 +101,52 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
113
  # maybe try "0" instead or 1, or "1"
114
  # use GPT-J-6B
115
  if model_idx == 0:
116
- # just use regular pipeline models man leave APIs
117
- input_ids = tokenizer(context, return_tensors="pt").input_ids
118
-
119
- gen_tokens = model.generate(input_ids, temperature=temperature, max_length=2**11, top_p=top_p, max_new_tokens=max_length, num_return_sequences=1)
120
- gen_text = tokenizer.batch_decode(gen_tokens)[0]
121
- return gen_text
122
- # set_seed(2**31)
123
- # generated_text = gpt_j_generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
124
-
125
- # return get_generated_text(generated_text)
126
- # if main_gpt_j_api_up:
127
- # # for this api, a length of > 250 instantly errors, so use a while loop or something
128
- # # that would fetch results in chunks of 250
129
- # # NOTE change so it uses previous generated input every time
130
- # generated_text = context #""
131
- # while (max_length > 0):
132
- # payload = {"inputs": generated_text, "parameters": {"max_new_tokens": max_length>250 and 250 or max_length, "temperature": temperature, "top_p": top_p}}
133
- # response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
134
- # context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
135
- # # context = get_generated_text(generated_context)
136
- # # handle inconsistent inference API
137
- # if 'generated_text' in context[0]:
138
- # context = context[0]['generated_text']
139
- # else:
140
- # context = context[0][0]['generated_text']
141
 
142
- # generated_text += context
143
- # max_length -= 250
144
 
145
  # payload = {"inputs": context, "parameters":{
146
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
147
  # data = json.dumps(payload)
148
  # response = requests.request("POST", API_URL, data=data, headers=headers)
149
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
150
- # return generated_text
151
 
152
  # use secondary gpt-j-6B api, as the main one is down
153
- # if not secondary_gpt_j_api_up:
154
- # return "ERR: both GPT-J-6B APIs are down, please try again later (will use a third fallback in the future)"
155
 
156
  # use fallback API
157
  #
158
  # http://api.vicgalle.net:5000/docs#/default/generate_generate_post
159
  # https://pythonrepo.com/repo/vicgalle-gpt-j-api-python-natural-language-processing
160
 
161
- # payload = {
162
- # "context": context,
163
- # "token_max_length": max_length, # 512,
164
- # "temperature": temperature,
165
- # "top_p": top_p,
166
- # "max_time": 120.0
167
- # }
168
 
169
- # response = requests.post(
170
- # "http://api.vicgalle.net:5000/generate", params=payload).json()
171
- # return response['text']
172
  elif model_idx == 1:
173
  # use GPT-2
174
  #
 
32
  # gradio checkbutton
33
 
34
  generator = pipeline('text-generation', model='gpt2')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  title = "GPT-J-6B"
 
84
  except:
85
  # recursively loop through generated_text till we get the text
86
  # don't know if this will work
87
+ for gt in generated_text:
88
+ if 'generated_text' in gt:
89
+ return gt['generated_text']
90
+ else:
91
+ return get_generated_text(gt)
92
+ # return generated_text
93
 
94
 
95
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
 
101
  # maybe try "0" instead or 1, or "1"
102
  # use GPT-J-6B
103
  if model_idx == 0:
104
+ if main_gpt_j_api_up:
105
+ # for this api, a length of > 250 instantly errors, so use a while loop or something
106
+ # that would fetch results in chunks of 250
107
+ # NOTE change so it uses previous generated input every time
108
+ generated_text = context #""
109
+ while (max_length > 0):
110
+ payload = {"inputs": generated_text, "parameters": {"max_new_tokens": max_length>250 and 250 or max_length, "temperature": temperature, "top_p": top_p}}
111
+ response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
112
+ context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
113
+ # context = get_generated_text(generated_context)
114
+ # handle inconsistent inference API
115
+ if 'generated_text' in context[0]:
116
+ context = context[0]['generated_text']
117
+ else:
118
+ context = context[0][0]['generated_text']
 
 
 
 
 
 
 
 
 
 
119
 
120
+ generated_text += context
121
+ max_length -= 250
122
 
123
  # payload = {"inputs": context, "parameters":{
124
  # "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
125
  # data = json.dumps(payload)
126
  # response = requests.request("POST", API_URL, data=data, headers=headers)
127
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
128
+ return generated_text
129
 
130
  # use secondary gpt-j-6B api, as the main one is down
131
+ if not secondary_gpt_j_api_up:
132
+ return "ERR: both GPT-J-6B APIs are down, please try again later (will use a third fallback in the future)"
133
 
134
  # use fallback API
135
  #
136
  # http://api.vicgalle.net:5000/docs#/default/generate_generate_post
137
  # https://pythonrepo.com/repo/vicgalle-gpt-j-api-python-natural-language-processing
138
 
139
+ payload = {
140
+ "context": context,
141
+ "token_max_length": max_length, # 512,
142
+ "temperature": temperature,
143
+ "top_p": top_p,
144
+ "max_time": 120.0
145
+ }
146
 
147
+ response = requests.post(
148
+ "http://api.vicgalle.net:5000/generate", params=payload).json()
149
+ return response['text']
150
  elif model_idx == 1:
151
  # use GPT-2
152
  #