File size: 13,501 Bytes
75197d8
3dec2d1
 
1c8febf
0cc2838
a183160
 
5e3b1a8
145938e
 
 
 
 
 
75197d8
bba0271
1588190
5e3b1a8
 
 
1588190
7c1c1eb
 
 
 
1c8febf
7c1c1eb
 
 
 
 
1588190
1c8febf
68394f9
1c8febf
9ef03c3
145938e
4cd5b9e
5f283d9
68394f9
 
759a131
3346e5e
759a131
e74817f
 
571d832
 
 
68394f9
 
 
cbabcb5
6779ce8
cbabcb5
68394f9
e9b9296
cbabcb5
 
1c8febf
a183160
 
 
 
 
 
 
 
 
 
cbabcb5
a183160
 
 
 
 
 
 
 
 
 
 
62a9753
cbabcb5
3dec2d1
6779ce8
 
 
 
 
 
 
120d7a8
cbabcb5
117e325
 
 
 
 
cbabcb5
6779ce8
 
d69b53f
145938e
5e06490
759a131
7216e2e
 
a183160
145938e
 
 
 
 
6954f5d
145938e
6954f5d
 
 
 
145938e
759a131
 
eabeaac
117e325
 
 
 
ab2b147
cbabcb5
e7e8039
c249a9b
171b9d0
9ba6add
145938e
cbabcb5
 
117e325
cbabcb5
117e325
94b1007
 
 
 
cbabcb5
9ba6add
94b1007
117e325
d28ccc2
60b60b7
 
 
 
 
 
cbabcb5
a183160
552084d
117e325
 
a183160
 
cbabcb5
759a131
 
 
117e325
 
 
 
 
 
 
cbabcb5
117e325
 
 
eabeaac
759a131
e74817f
145938e
 
 
 
3dec2d1
cbabcb5
182c51e
 
103f0d1
f25f764
6ac99e7
a71047c
3a988df
4d0bba1
cbabcb5
45d20de
be94533
 
6779ce8
 
145938e
6ac99e7
 
9663b06
 
6779ce8
cbabcb5
 
 
 
 
 
62a9753
3dec2d1
bbef3ac
cbabcb5
6095690
9bf563d
145938e
02037ab
9bf563d
d69b53f
 
02037ab
145938e
02037ab
 
9bf563d
85f4499
 
02037ab
d5076c2
 
 
 
 
f5f0157
461eb5a
9bf563d
 
 
 
145938e
9bf563d
 
 
 
 
 
461eb5a
725503b
461eb5a
02037ab
461eb5a
145938e
461eb5a
 
 
9bf563d
461eb5a
 
3296004
5e06490
45d20de
dd62ca3
4c97ef8
1c8febf
 
 
759a131
145938e
759a131
3dec2d1
eabeaac
cbabcb5
7216e2e
cbabcb5
a71047c
1588190
0cc2838
95c5c80
1588190
3e6cabd
95c5c80
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287

from random import randint
from transformers import pipeline, set_seed
import requests
import gradio as gr
import json

# # from transformers import AutoModelForCausalLM, AutoTokenizer
def get():
    """Placeholder stub that does nothing and returns None.

    This was previously defined twice in a row; the second definition
    silently replaced the first, so a single definition is kept.
    NOTE(review): nothing visible in this file calls get(); consider removing.
    """



# stage, commit, push

# # prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
# #          "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
# #          "researchers was the fact that the unicorns spoke perfect English."

# ex=None
# try:
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

#     # "EluttherAI" on this line and for the next occurence only
#     # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
#     # model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
# except Exception as e:
#     ex = e

# Sampling-control sliders shared by the interface below.
# f() treats a value of 0 for top_p / top_k as "disabled".
temperature = gr.inputs.Slider(minimum=0, maximum=1.5, default=0.8, label="temperature")
top_p = gr.inputs.Slider(minimum=0, maximum=1.0, default=0.9, label="top_p")
top_k = gr.inputs.Slider(minimum=0, maximum=100, default=40, label="top_k")

# Local GPT-2 text-generation pipeline, used by f() when "GPT2" (model index 1)
# is selected in the model dropdown.
generator = pipeline(task='text-generation', model='gpt2')


title = "text generator based on GPT models"
# TODO: support fine-tuned models, or models trained for other text-generation purposes.

# Example rows for the Gradio interface. Each row must supply one value per
# interface input, in order:
#   prompt, temperature, top_p, top_k, max length, model, verification key.
# The verification key is left blank so users still have to enter their own.
examples = [
    ["For today's homework assignment, please describe the reasons for the US Civil War.",
     0.8, 0.9, 50, 30, "GPT2", ""],
    ["In a shocking discovery, scientists have found a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
     0.8, 0.9, 50, 30, "GPT2", ""],
    ["The first step in the process of developing a new language is to invent a new word.",
     0.8, 0.9, 50, 30, "GPT2", ""],
]

            


def is_up(url):
    """Report whether `url` answers a HEAD request within 10 seconds.

    Any response at all (including an HTTP error status, since the status is
    never checked) counts as "up". Only an exception from the request
    (connection failure, timeout, malformed URL, ...) counts as "down".
    """
    reachable = True
    try:
        requests.head(url, timeout=10)
    except Exception:
        reachable = False
    return reachable

# gpt_j_api_down = False

import os

# Prefer the Hugging Face Inference API for GPT-J-6B; if it does not answer,
# probe the community-hosted fallback API instead.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"
main_gpt_j_api_up = is_up(API_URL)
secondary_gpt_j_api_up = False
if not main_gpt_j_api_up:
    # The fallback is served over plain HTTP on port 5000, which is the URL f()
    # POSTs to. An https probe of that port fails the TLS handshake, so the
    # fallback would always be (wrongly) reported as down.
    API_URL = "http://api.vicgalle.net:5000/generate"
    secondary_gpt_j_api_up = is_up(API_URL)

# Bearer token for the Hugging Face Inference API. Raises KeyError at import
# time if the API_TOKEN secret is not configured.
headers = {"Authorization": f"Bearer {os.environ['API_TOKEN']}"}

# NOTE see build logs here: https://huggingface.co/spaces/un-index/textgen6b/logs/build
    

def get_generated_text(generated_text):
    """Extract the generated string from a text-generation response.

    Responses from the transformers pipeline and the remote inference APIs
    arrive in several shapes, e.g. ``[{"generated_text": ...}]`` or
    ``[[{"generated_text": ...}]]``. This walks dicts and (nested)
    lists/tuples depth-first and returns the first ``"generated_text"`` value.

    Args:
        generated_text: the decoded response (dict, list or tuple).

    Returns:
        The first ``"generated_text"`` value found.

    Raises:
        RuntimeError: a dict with an ``"error"`` key (presumably how the hosted
            inference API reports failures; confirm) was found before any
            generated text.
        ValueError: no ``"generated_text"`` entry exists anywhere in the input.
    """
    if isinstance(generated_text, dict):
        if "generated_text" in generated_text:
            return generated_text["generated_text"]
        if "error" in generated_text:
            raise RuntimeError(f"inference API error: {generated_text['error']}")
    elif isinstance(generated_text, (list, tuple)):
        for item in generated_text:
            try:
                return get_generated_text(item)
            except ValueError:
                # This element holds no text; keep searching its siblings.
                continue
    # Strings and other scalars are never descended into, which is what
    # previously caused unbounded recursion.
    raise ValueError(f"no 'generated_text' found in response: {generated_text!r:.200}")



# Seconds to wait for a single HTTP request to a remote inference API.
_REQUEST_TIMEOUT = 120

# New tokens requested per Inference-API round-trip; the hosted GPT-J
# endpoint errors out on requests above roughly 250 tokens.
_CHUNK_TOKENS = 250


def _resolve_sampling_params(top_p, top_k):
    """Turn the raw top_p / top_k slider values into generation arguments.

    A slider value of 0 means "disabled" and maps to None. If both are
    enabled, only top_p is used. If neither is enabled, top_p defaults to 0.8.

    Returns:
        A ``(top_p, top_k)`` tuple; top_k is an int or None.
    """
    top_p = top_p or None
    top_k = int(top_k or 0) or None  # sliders may deliver floats
    if top_p is not None and top_k is not None:
        top_k = None
    elif top_p is None and top_k is None:
        top_p = 0.8
    return top_p, top_k


def _query_inference_api(url, context, max_length, parameters):
    """Generate text from a Hugging Face Inference API model in chunks.

    Each request asks for ``_CHUNK_TOKENS`` new tokens. The returned chunk
    becomes the next request's prompt, until at least `max_length`
    characters have been collected or the model produces nothing new.

    Args:
        url: Inference API model endpoint.
        context: initial prompt.
        max_length: minimum number of characters to collect.
        parameters: extra generation parameters merged into each payload.

    Returns:
        The concatenated generated text, stripped of outer whitespace.

    Raises:
        RuntimeError: the API answered with an ``{"error": ...}`` object.
    """
    generated_text = ""
    while len(generated_text) < max_length:
        payload = {
            "inputs": context,
            "parameters": {
                "return_full_text": False,
                "max_new_tokens": _CHUNK_TOKENS,
                **parameters,
            },
        }
        response = requests.post(url, data=json.dumps(payload), headers=headers,
                                 timeout=_REQUEST_TIMEOUT)
        result = json.loads(response.content.decode("utf-8"))
        if isinstance(result, dict) and "error" in result:
            raise RuntimeError(f"inference API error: {result['error']}")
        chunk = get_generated_text(result)
        if not chunk or not chunk.strip():
            # Without this guard an empty completion would loop forever.
            break
        # Append the chunk unstripped so its leading space is kept and words
        # from consecutive chunks are not glued together.
        generated_text += chunk
        context = chunk.strip()
    return generated_text.strip()


def f(context, temperature, top_p, top_k, max_length, model_idx, SPACE_VERIFICATION_KEY):
    """Gradio callback: generate a continuation of `context` with the chosen model.

    Args:
        context: prompt text.
        temperature: sampling temperature from the slider.
        top_p: nucleus-sampling threshold; 0 disables it.
        top_k: top-k cutoff; 0 disables it. If both top_p and top_k are
            enabled only top_p is used; if neither is, top_p = 0.8.
        max_length: length slider value. For the Inference-API models it is
            the minimum number of generated characters; for the local GPT-2
            pipeline and the fallback GPT-J API it is a token budget.
        model_idx: dropdown index: 0 = GPT-J-6B, 1 = local GPT-2,
            2 = distilgpt2, 3 = gpt2-large, anything else = GPT-Neo-2.7B.
        SPACE_VERIFICATION_KEY: key typed by the user; must match the
            SPACE_VERIFICATION_KEY environment variable.

    Returns:
        The generated text, or a human-readable error string. Exceptions are
        never propagated to Gradio.
    """
    import hmac  # local import: only needed for the constant-time key check

    try:
        # Constant-time comparison of the user-supplied secret. A missing
        # environment variable denies access instead of raising KeyError.
        expected_key = os.environ.get('SPACE_VERIFICATION_KEY')
        supplied_key = str(SPACE_VERIFICATION_KEY or "")
        if expected_key is None or not hmac.compare_digest(
                expected_key.encode("utf-8"), supplied_key.encode("utf-8")):
            return "invalid SPACE_VERIFICATION_KEY; see project secrets to view key"

        try:
            set_seed(randint(1, 256))
        except Exception as e:
            return "Exception while setting seed: " + str(e)

        top_p, top_k = _resolve_sampling_params(top_p, top_k)
        max_length = int(max_length)  # sliders may deliver floats
        sampling = {"temperature": temperature, "top_p": top_p, "top_k": top_k}

        if model_idx == 0:
            # GPT-J-6B: Hugging Face Inference API first, then the fallback API.
            if main_gpt_j_api_up:
                return _query_inference_api(API_URL, context, max_length, sampling)
            if not secondary_gpt_j_api_up:
                return "ERR: both GPT-J-6B APIs are down, please try again later (will use a third fallback in the future)"
            # Fallback API docs: http://api.vicgalle.net:5000/docs#/default/generate_generate_post
            payload = {
                "context": context,
                "token_max_length": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "max_time": 120.0,
            }
            # Allow a margin beyond the server-side max_time before giving up.
            response = requests.post("http://api.vicgalle.net:5000/generate",
                                     params=payload,
                                     timeout=_REQUEST_TIMEOUT + 30).json()
            return response['text']

        if model_idx == 1:
            # Local GPT-2 pipeline; max_new_tokens bounds the continuation length.
            try:
                generated_text = generator(context, max_length=896, max_new_tokens=max_length,
                                           top_p=top_p, top_k=top_k, temperature=temperature,
                                           num_return_sequences=1)
            except Exception as e:
                return "Exception while generating text: " + str(e)
            return get_generated_text(generated_text)

        if model_idx == 2:
            # A strong repetition penalty counters distilgpt2's tendency to loop.
            return _query_inference_api("https://api-inference.huggingface.co/models/distilgpt2",
                                        context, max_length,
                                        {"repetition_penalty": 20.0, **sampling})

        if model_idx == 3:
            return _query_inference_api("https://api-inference.huggingface.co/models/gpt2-large",
                                        context, max_length, sampling)

        return _query_inference_api("https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B",
                                    context, max_length, sampling)

    except Exception as e:
        return f"error with idx{model_idx}: "+str(e)


# Interface inputs, listed in the same order as f()'s parameters. The model
# dropdown uses type="index", so f() receives the position of the selected
# label rather than the label itself.
_interface_inputs = [
    "text",
    temperature,
    top_p,
    top_k,
    gr.inputs.Slider(minimum=20, maximum=512, default=30, label="max length"),
    gr.inputs.Dropdown(
        ["GPT-J-6B", "GPT2", "DistilGPT2", "GPT-Large", "GPT-Neo-2.7B"],
        type="index",
        label="model",
        default="GPT2",
    ),
    gr.inputs.Textbox(lines=1, placeholder="xxxxxxxx", label="space verification key"),
]

# enable_queue is passed to Interface rather than launch() (the launch()
# argument is deprecated); see
# https://discuss.huggingface.co/t/is-there-a-timeout-max-runtime-for-spaces/12979/3?u=un-index
iface = gr.Interface(
    f,
    _interface_inputs,
    outputs="text",
    title=title,
    examples=examples,
    enable_queue=True,
)
iface.launch()

# Everything below is commented out: it works, but is kept only for testing.
# import gradio as gr


# gr.Interface.load("huggingface/EleutherAI/gpt-j-6B",
#     inputs=gr.inputs.Textbox(lines=10, label="Input Text"),
#     title=title, examples=examples).launch();