Spaces:
Sleeping
Sleeping
| from random import randint | |
| from transformers import pipeline, set_seed | |
| import requests | |
| import gradio as gr | |
| import json | |
| # # from transformers import AutoModelForCausalLM, AutoTokenizer | |
def get():
    """Do nothing and return None.

    NOTE(review): nothing in the visible file calls get(); presumably a
    leftover stub — confirm before removing it.
    """
    return None
| # stage, commit, push | |
| # # prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ | |
| # # "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ | |
| # # "researchers was the fact that the unicorns spoke perfect English." | |
| # ex=None | |
| # try: | |
| # from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") | |
| # # "EluttherAI" on this line and for the next occurrence only | |
| # # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") | |
| # # model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") | |
| # except Exception as e: | |
| # ex = e | |
# Sampling controls shown in the UI; f() receives their values positionally
# right after the prompt text.
temperature = gr.inputs.Slider(minimum=0, maximum=1.5, default=0.8, label="temperature")
top_p = gr.inputs.Slider(minimum=0, maximum=1.0, default=0.9, label="top_p")
top_k = gr.inputs.Slider(minimum=0, maximum=100, default=40, label="top_k")

# Local GPT-2 pipeline, used when the "GPT2" model (dropdown index 1) is selected.
generator = pipeline('text-generation', model='gpt2')

# Page title of the Gradio interface.
title = "text generator based on GPT models"
# TODO support fine tuned models or models for text generation for different purposes
#
# Example rows for the Gradio interface. Each row must supply one value per
# input component, in the same order as the Interface inputs:
#   prompt, temperature, top_p, top_k, max length, model, verification key.
# The verification key is left empty; users have to type the real one.
examples = [
    ["For today's homework assignment, please describe the reasons for the US Civil War.",
     0.8, 0.9, 40, 50, "GPT2", ""],
    ["In a shocking discovery, scientists have found a herd of unicorns living in a remote, "
     "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
     "researchers was the fact that the unicorns spoke perfect English.",
     0.8, 0.9, 40, 50, "GPT2", ""],
    ["The first step in the process of developing a new language is to invent a new word.",
     0.8, 0.9, 40, 50, "GPT2", ""],
]
def is_up(url):
    """Report whether *url* answers a HEAD request within 10 seconds.

    Any completed response counts as "up" — HTTP error statuses included.
    Only an exception (connection failure, timeout, bad URL, ...) makes this
    return False.
    """
    try:
        requests.head(url, timeout=10)
    except Exception:
        return False
    return True
import os

# Primary GPT-J-6B endpoint: the Hugging Face Inference API.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"
main_gpt_j_api_up = is_up(API_URL)
secondary_gpt_j_api_up = False
if not main_gpt_j_api_up:
    # Check the secondary GPT-J API instead. It is served over plain HTTP on
    # port 5000 (the URL f() posts to and the one its docs use), so probe that
    # URL rather than an https:// variant of it.
    API_URL = "http://api.vicgalle.net:5000/generate"
    secondary_gpt_j_api_up = is_up(API_URL)

# Credentials for the Hugging Face Inference API. When the API_TOKEN secret is
# missing, the app still starts and sends requests without an Authorization
# header, instead of crashing with KeyError at import time.
_api_token = os.environ.get("API_TOKEN")
headers = {"Authorization": f"Bearer {_api_token}"} if _api_token else {}
| # NOTE see build logs here: https://huggingface.co/spaces/un-index/textgen6b/logs/build | |
def get_generated_text(generated_text):
    """Return the first ``'generated_text'`` string found in a generation result.

    The callers have observed both ``[{'generated_text': ...}]`` and
    ``[[{'generated_text': ...}]]`` shapes, so lists/tuples are searched
    depth-first, in order, and a bare dict is accepted as well.

    Args:
        generated_text: decoded JSON from the Inference API, or the output of
            the local ``transformers`` pipeline.

    Returns:
        The value stored under the first ``'generated_text'`` key found.

    Raises:
        ValueError: if no ``'generated_text'`` entry exists — for example when
            the API answered with an ``{'error': ...}`` payload.
    """
    if isinstance(generated_text, dict):
        if 'generated_text' in generated_text:
            return generated_text['generated_text']
        if 'error' in generated_text:
            raise ValueError(f"text generation failed: {generated_text['error']}")
        raise ValueError(f"no 'generated_text' key in response: {generated_text!r}")
    if isinstance(generated_text, (list, tuple)):
        for item in generated_text:
            try:
                return get_generated_text(item)
            except ValueError:
                # This element holds no text; try the next one.
                continue
        raise ValueError(f"no 'generated_text' found in response: {generated_text!r}")
    # Strings, numbers, None, ...: never descend into these (doing so used to
    # recurse character by character without end).
    raise ValueError(
        f"unexpected response type {type(generated_text).__name__}: {generated_text!r}"
    )
def _normalize_sampling(top_p, top_k):
    """Turn the UI's top_p/top_k slider values into generation arguments.

    A slider value of 0 means "not set" and becomes None. top_k is cast to
    int because generation libraries expect an integer and the slider may
    deliver a float. When neither is set, top_p falls back to 0.8; when both
    are set, both are passed through unchanged.
    """
    top_p = top_p if top_p else None
    top_k = int(top_k) if top_k else None
    if top_p is None and top_k is None:
        top_p = 0.8
    return top_p, top_k


def _generate_via_inference_api(url, context, max_length, temperature, top_p, top_k,
                                **extra_parameters):
    """Query a Hugging Face Inference API model until *max_length* characters exist.

    Text is requested in chunks of at most 250 new tokens. Each request
    continues from the previous chunk only, not the whole text generated so
    far. Generation stops early if a request returns no new text.

    Raises:
        RuntimeError: if the API answers with an ``{'error': ...}`` payload.
    """
    parameters = {
        "return_full_text": False,
        # Requests for more than 250 tokens at once errored for GPT-J, hence
        # the fixed chunk size.
        "max_new_tokens": 250,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        **extra_parameters,
    }
    # Leave unset sampling options out of the request instead of sending null.
    parameters = {name: value for name, value in parameters.items() if value is not None}

    chunks = []
    produced = 0
    while produced < max_length:
        payload = {"inputs": context, "parameters": parameters}
        response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=120)
        data = json.loads(response.content.decode("utf-8"))
        if isinstance(data, dict) and "error" in data:
            raise RuntimeError(f"inference API error: {data['error']}")
        chunk = get_generated_text(data)
        if not chunk or not chunk.strip():
            # Nothing new came back; stop instead of re-sending forever.
            break
        # Keep the chunk's own whitespace so consecutive chunks don't glue
        # words together.
        chunks.append(chunk)
        produced += len(chunk)
        context = chunk
    return "".join(chunks).strip()


def _generate_via_secondary_gpt_j(context, max_length, temperature, top_p):
    """Generate with the fallback GPT-J-6B API at api.vicgalle.net.

    API docs: http://api.vicgalle.net:5000/docs#/default/generate_generate_post
    """
    payload = {
        "context": context,
        "token_max_length": max_length,
        "temperature": temperature,
        "top_p": top_p,  # requests omits None-valued params, so "unset" is simply not sent
        "max_time": 120.0,
    }
    # The client timeout is above the server-side max_time of 120 s.
    response = requests.post("http://api.vicgalle.net:5000/generate", params=payload, timeout=180)
    return response.json()['text']


def f(context, temperature, top_p, top_k, max_length, model_idx, SPACE_VERIFICATION_KEY):
    """Generate a continuation of *context* with the model chosen in the UI.

    Args:
        context: prompt text.
        temperature: sampling temperature from the slider.
        top_p: nucleus-sampling value; 0 means "not set".
        top_k: top-k value; 0 means "not set".
        max_length: target output size. For the Inference API models it is
            compared against the number of characters generated; for the
            local GPT-2 pipeline it is passed as ``max_new_tokens``; for the
            secondary GPT-J API it is ``token_max_length``.
        model_idx: dropdown index. 0 = GPT-J-6B, 1 = GPT-2 (local pipeline),
            2 = DistilGPT2, 3 = gpt2-large, anything else = GPT-Neo-2.7B.
        SPACE_VERIFICATION_KEY: must equal the SPACE_VERIFICATION_KEY
            environment variable.

    Returns:
        The generated text, or a human-readable error message. Exceptions
        are reported as text in the UI rather than raised.
    """
    try:
        from secrets import compare_digest

        # Constant-time comparison, so the key cannot be probed via response
        # timing. Compare bytes so non-ASCII input doesn't raise TypeError.
        expected_key = os.environ['SPACE_VERIFICATION_KEY']
        if not compare_digest(expected_key.encode("utf-8"),
                              (SPACE_VERIFICATION_KEY or "").encode("utf-8")):
            return "invalid SPACE_VERIFICATION_KEY; see project secrets to view key"
        try:
            set_seed(randint(1, 256))
        except Exception as e:
            return "Exception while setting seed: " + str(e)

        top_p, top_k = _normalize_sampling(top_p, top_k)
        # Token counts must be integers; the slider may deliver a float.
        max_length = int(max_length)

        if model_idx == 0:
            # GPT-J-6B: the Hugging Face Inference API first, the vicgalle API
            # as fallback.
            if main_gpt_j_api_up:
                return _generate_via_inference_api(API_URL, context, max_length,
                                                   temperature, top_p, top_k)
            if not secondary_gpt_j_api_up:
                return "ERR: both GPT-J-6B APIs are down, please try again later (will use a third fallback in the future)"
            return _generate_via_secondary_gpt_j(context, max_length, temperature, top_p)

        if model_idx == 1:
            # GPT-2 via the local transformers pipeline. max_length=896 caps
            # prompt + output; max_new_tokens bounds the generated part.
            try:
                generated = generator(context, max_length=896, max_new_tokens=max_length,
                                      top_p=top_p, top_k=top_k, temperature=temperature,
                                      num_return_sequences=1)
            except Exception as e:
                return "Exception while generating text: " + str(e)
            return get_generated_text(generated)

        if model_idx == 2:
            return _generate_via_inference_api(
                "https://api-inference.huggingface.co/models/distilgpt2",
                context, max_length, temperature, top_p, top_k,
                # Repetition penalty, intended to reduce repeated text from DistilGPT2.
                repetition_penalty=20.0,
            )

        if model_idx == 3:
            return _generate_via_inference_api(
                "https://api-inference.huggingface.co/models/gpt2-large",
                context, max_length, temperature, top_p, top_k,
            )

        return _generate_via_inference_api(
            "https://api-inference.huggingface.co/models/EleutherAI/gpt-neo-2.7B",
            context, max_length, temperature, top_p, top_k,
        )
    except Exception as e:
        return f"error with idx{model_idx}: "+str(e)
# Build the Gradio UI. The input order must match f()'s positional parameters:
# prompt, temperature, top_p, top_k, max length, model index, verification key.
# enable_queue is given to the Interface instead of launch(); see
# https://discuss.huggingface.co/t/is-there-a-timeout-max-runtime-for-spaces/12979/3?u=un-index
iface = gr.Interface(
    fn=f,
    inputs=[
        "text",
        temperature,
        top_p,
        top_k,
        gr.inputs.Slider(minimum=20, maximum=512, default=30, label="max length"),
        gr.inputs.Dropdown(
            ["GPT-J-6B", "GPT2", "DistilGPT2", "GPT-Large", "GPT-Neo-2.7B"],
            type="index",  # f() receives the position in this list, not the label
            label="model",
            default="GPT2",
        ),
        gr.inputs.Textbox(lines=1, placeholder="xxxxxxxx", label="space verification key"),
    ],
    outputs="text",
    title=title,
    examples=examples,
    enable_queue=True,
)
iface.launch()
| # all below works but testing | |
| # import gradio as gr | |
| # gr.Interface.load("huggingface/EleutherAI/gpt-j-6B", | |
| # inputs=gr.inputs.Textbox(lines=10, label="Input Text"), | |
| # title=title, examples=examples).launch(); | |