un-index committed on
Commit
6779ce8
·
1 Parent(s): be94533
Files changed (1) hide show
  1. app.py +36 -22
app.py CHANGED
@@ -42,9 +42,9 @@ title = "text generator based on GPT models"
42
 
43
  examples = [
44
  # another machine learning example
45
- [["For today's homework assignment, please describe the reasons for the US Civil War."], 0.8, 0.9, 50, "GPT-2"],
46
- [["In a shocking discovery, scientists have found a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."], 0.8, 0.9, 50, "GPT-2"],
47
- [["The first step in the process of developing a new language is to invent a new word."], 0.8, 0.9, 50, "GPT-2"],
48
  ]
49
 
50
 
@@ -75,6 +75,16 @@ headers = {"Authorization": f"Bearer {os.environ['API_TOKEN']}"}
75
  # NOTE see build logs here: https://huggingface.co/spaces/un-index/textgen6b/logs/build
76
 
77
 
 
 
 
 
 
 
 
 
 
 
78
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
79
  try:
80
 
@@ -87,17 +97,20 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
87
  if main_gpt_j_api_up:
88
  # for this api, a length of > 250 instantly errors, so use a while loop or something
89
  # that would fetch results in chunks of 250
90
- generated_text = ""
91
  while (max_length > 0):
92
- payload = {"inputs": context, "parameters": {"max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
93
  response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
94
- context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
 
 
95
  # handle inconsistent inference API
96
- if 'generated_text' in context[0]:
97
- context = context[0]['generated_text']
98
- else:
99
- context = context[0][0]['generated_text']
100
- generated_text += context
 
101
  max_length -= 250
102
 
103
  # payload = {"inputs": context, "parameters":{
@@ -105,7 +118,9 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
105
  # data = json.dumps(payload)
106
  # response = requests.request("POST", API_URL, data=data, headers=headers)
107
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
108
- return generated_text
 
 
109
 
110
  # use secondary gpt-j-6B api, as the main one is down
111
  if not secondary_gpt_j_api_up:
@@ -127,7 +142,7 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
127
  response = requests.post(
128
  "http://api.vicgalle.net:5000/generate", params=payload).json()
129
  return response['text']
130
- else:
131
  # use GPT-2
132
  #
133
  try:
@@ -148,18 +163,16 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
148
  # TODO if yes, then make max_length infinite because it seems to be counted as max input length, not output
149
  # NOTE max_new_tokens does not seem to generate that many tokens
150
  # however in the source that's what's used
 
 
151
  generated_text = generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
152
  except Exception as e:
153
  return "Exception while generating text: " + str(e)
154
  # [0][0]['generated_text']
155
 
156
- try:
157
- if 'generated_text' in generated_text[0]:
158
- return generated_text[0]['generated_text']
159
- else:
160
- return generated_text[0][0]['generated_text']
161
- except:
162
- return generated_text # was error due to timeout because of not enabling queue in gradio interface?
163
  # if it works right now, then that was the reason for the JSON parsing error
164
  # except:
165
  # generated_text = generator(context, max_length=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)[0]
@@ -169,7 +182,8 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
169
 
170
  # TODO use fallback gpt-2 inference api for this as well
171
  # TODO or just make it an option in the menu "GPT-2 inference"
172
-
 
173
 
174
  except Exception as e:
175
  return f"error with idx{model_idx}: "+str(e)
@@ -181,7 +195,7 @@ iface = gr.Interface(f, [
181
  top_p,
182
  gr.inputs.Slider(
183
  minimum=20, maximum=512, default=30, label="max length"),
184
- gr.inputs.Dropdown(["GPT-J-6B", "GPT-2"], type="index", label="model"),
185
  gr.inputs.Textbox(lines=1, placeholder="xxxxxxxx", label="space verification key")
186
 
187
 ], outputs="text", title=title, examples=examples, enable_queue = True) # deprecated within iface.launch: https://discuss.huggingface.co/t/is-there-a-timeout-max-runtime-for-spaces/12979/3?u=un-index
 
42
 
43
  examples = [
44
  # another machine learning example
45
+ [["For today's homework assignment, please describe the reasons for the US Civil War."], 0.8, 0.9, 50, "GPT2"],
46
+ [["In a shocking discovery, scientists have found a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."], 0.8, 0.9, 50, "GPT2"],
47
+ [["The first step in the process of developing a new language is to invent a new word."], 0.8, 0.9, 50, "GPT2"],
48
  ]
49
 
50
 
 
75
  # NOTE see build logs here: https://huggingface.co/spaces/un-index/textgen6b/logs/build
76
 
77
 
78
+ def get_generated_text(generated_text):
79
+ try:
80
+ if 'generated_text' in generated_text[0]:
81
+ return generated_text[0]['generated_text']
82
+ else:
83
+ return generated_text[0][0]['generated_text']
84
+ except:
85
+ return generated_text
86
+
87
+
88
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
89
  try:
90
 
 
97
  if main_gpt_j_api_up:
98
  # for this api, a length of > 250 instantly errors, so use a while loop or something
99
  # that would fetch results in chunks of 250
100
+ generated_total = context
101
  while (max_length > 0):
102
+ payload = {"inputs": generated_total, "parameters": {"max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
103
  response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
104
+ generated_text = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
105
+
106
+ generated_text = get_generated_text(generated_text)
107
  # handle inconsistent inference API
108
+ # if 'generated_text' in context[0]:
109
+ # context = context[0]['generated_text']
110
+ # else:
111
+ # context = context[0][0]['generated_text']
112
+
113
+ generated_total += context
114
  max_length -= 250
115
 
116
  # payload = {"inputs": context, "parameters":{
 
118
  # data = json.dumps(payload)
119
  # response = requests.request("POST", API_URL, data=data, headers=headers)
120
  # generated_text = json.loads(response.content.decode("utf-8"))[0]['generated_text']
121
+ # remove first n characters of generated total where n = len(context)
122
+ generated_total = generated_total[len(context):]
123
+ return generated_total
124
 
125
  # use secondary gpt-j-6B api, as the main one is down
126
  if not secondary_gpt_j_api_up:
 
142
  response = requests.post(
143
  "http://api.vicgalle.net:5000/generate", params=payload).json()
144
  return response['text']
145
+ elif model_idx == 1:
146
  # use GPT-2
147
  #
148
  try:
 
163
  # TODO if yes, then make max_length infinite because it seems to be counted as max input length, not output
164
  # NOTE max_new_tokens does not seem to generate that many tokens
165
  # however in the source that's what's used
166
+ # NOTE I think max_new_tokens is working now and punctuation characters count too
167
+ # NOTE set max_length to max_length to allow input text of any size
168
  generated_text = generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
169
  except Exception as e:
170
  return "Exception while generating text: " + str(e)
171
  # [0][0]['generated_text']
172
 
173
+ return get_generated_text(generated_text)
174
+
175
+ # was error due to timeout because of not enabling queue in gradio interface?
 
 
 
 
176
  # if it works right now, then that was the reason for the JSON parsing error
177
  # except:
178
  # generated_text = generator(context, max_length=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)[0]
 
182
 
183
  # TODO use fallback gpt-2 inference api for this as well
184
  # TODO or just make it an option in the menu "GPT-2 inference"
185
+ else:
186
+ url = "https://api-inference.huggingface.co/models/distilgpt2"
187
 
188
  except Exception as e:
189
  return f"error with idx{model_idx}: "+str(e)
 
195
  top_p,
196
  gr.inputs.Slider(
197
  minimum=20, maximum=512, default=30, label="max length"),
198
+ gr.inputs.Dropdown(["GPT-J-6B", "GPT2", "DistilGPT2"], type="index", label="model"),
199
  gr.inputs.Textbox(lines=1, placeholder="xxxxxxxx", label="space verification key")
200
 
201
 ], outputs="text", title=title, examples=examples, enable_queue = True) # deprecated within iface.launch: https://discuss.huggingface.co/t/is-there-a-timeout-max-runtime-for-spaces/12979/3?u=un-index