un-index
Commit 3346e5e · Parent: 0c91274

app.py CHANGED
@@ -33,6 +33,8 @@ top_p = gr.inputs.Slider(minimum=0, maximum=1.0,
 
 generator = pipeline('text-generation', model='gpt2')
 
+gpt_j_generator = pipeline('text-generation', model='GPT-J 6B')
+
 
 title = "GPT-J-6B"
 
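The first hunk adds a second pipeline next to the existing gpt2 one. As written, the added line will fail at startup: 'GPT-J 6B' is not a valid Hub model id (repo ids contain no spaces), so pipeline() cannot resolve it. A minimal sketch of what the line presumably intends, assuming the public EleutherAI checkpoint is the target:

    from transformers import pipeline

    # 'gpt2' resolves fine; GPT-J has to be referenced by its full repo id.
    generator = pipeline('text-generation', model='gpt2')
    # Assumption: the intended checkpoint is EleutherAI's public GPT-J repo.
    # Note GPT-J-6B weighs roughly 24 GB in fp32 and will not fit a free CPU Space.
    gpt_j_generator = pipeline('text-generation', model='EleutherAI/gpt-j-6B')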
@@ -99,24 +101,30 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
     # maybe try "0" instead or 1, or "1"
     # use GPT-J-6B
     if model_idx == 0:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # just use regular pipeline models man leave APIs
+
+        set_seed(2**31)
+        generated_text = gpt_j_generator(context, max_length=896, max_new_tokens=max_length, top_p=top_p, temperature=temperature, num_return_sequences=1)
+
+        return get_generated_text(generated_text)
+        # if main_gpt_j_api_up:
+        #     # for this api, a length of > 250 instantly errors, so use a while loop or something
+        #     # that would fetch results in chunks of 250
+        #     # NOTE change so it uses previous generated input every time
+        #     generated_text = context #""
+        #     while (max_length > 0):
+        #         payload = {"inputs": generated_text, "parameters": {"max_new_tokens": max_length>250 and 250 or max_length, "temperature": temperature, "top_p": top_p}}
+        #         response = requests.request("POST", API_URL, data=json.dumps(payload), headers=headers)
+        #         context = json.loads(response.content.decode("utf-8"))#[0]['generated_text']
+        #         # context = get_generated_text(generated_context)
+        #         # handle inconsistent inference API
+        #         if 'generated_text' in context[0]:
+        #             context = context[0]['generated_text']
+        #         else:
+        #             context = context[0][0]['generated_text']
 
-
-
+        #     generated_text += context
+        #     max_length -= 250
 
     # payload = {"inputs": context, "parameters":{
     #     "max_new_tokens":max_length, "temperature":temperature, "top_p":top_p}}
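The new branch generates locally instead of calling the Inference API. Three details are worth flagging: set_seed (imported from transformers) with a constant seed makes every request return the same continuation; passing max_length=896 together with max_new_tokens sets two conflicting length limits, of which transformers honors max_new_tokens; and temperature/top_p only take effect when sampling is enabled. A hedged sketch of the same branch with those fixes, assuming get_generated_text merely unwraps the pipeline output:

    from transformers import set_seed

    def run_local_gpt_j(gpt_j_generator, context, temperature, top_p, max_length):
        # Hypothetical helper name; mirrors the model_idx == 0 branch above.
        set_seed(2**31)  # constant seed => deterministic output on every call
        out = gpt_j_generator(
            context,
            max_new_tokens=max_length,  # one length limit; drop max_length=896
            do_sample=True,             # required for temperature/top_p to apply
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
        )
        # Assumption: get_generated_text in app.py does the equivalent of this.
        return out[0]['generated_text']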
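The commented-out fallback documents the old API path: the hosted endpoint rejects requests asking for more than 250 new tokens, so generation proceeds in chunks, feeding each chunk's output back in as the next prompt. A runnable sketch of that loop, with two hedges: API_URL and headers are assumptions (any Inference API text-generation endpoint plus a bearer token), and each chunk result is assigned rather than appended, since the API's generated_text already includes the prompt by default:

    import json
    import requests

    API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"  # assumed endpoint
    headers = {"Authorization": "Bearer hf_..."}  # assumed; substitute a real token

    def generate_via_api(context, temperature, top_p, max_length):
        generated_text = context
        while max_length > 0:
            payload = {
                "inputs": generated_text,  # feed the text so far back in
                "parameters": {
                    # min() replaces the old `max_length>250 and 250 or max_length` idiom
                    "max_new_tokens": min(max_length, 250),
                    "temperature": temperature,
                    "top_p": top_p,
                },
            }
            response = requests.post(API_URL, data=json.dumps(payload), headers=headers)
            result = json.loads(response.content.decode("utf-8"))
            # The API nests results inconsistently: [{...}] or [[{...}]]
            if 'generated_text' in result[0]:
                generated_text = result[0]['generated_text']
            else:
                generated_text = result[0][0]['generated_text']
            max_length -= 250
        return generated_text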