Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import json | |
| import requests | |
| from mtranslate import translate | |
| from prompts import PROMPT_LIST | |
| import streamlit as st | |
| import random | |
| import fasttext | |
| import SessionState | |
# Passed as the HTTP headers of every request made in this module (empty here).
headers = {}

# Image file shown at the top of the sidebar.
LOGO = "huggingwayang.png"

# Display name -> {"url": Inference API endpoint} for every selectable model.
# Built from (display name, Hugging Face repo id) pairs; insertion order is kept.
MODELS = {
    display_name: {"url": f"https://api-inference.huggingface.co/models/{repo_id}"}
    for display_name, repo_id in (
        ("GPT-2 Small", "flax-community/gpt2-small-indonesian"),
        ("GPT-2 Medium", "flax-community/gpt2-medium-indonesian"),
        ("GPT-2 Small finetuned on Indonesian academic journals", "Galuh/id-journal-gpt2"),
        ("GPT-2 Medium finetuned on Indonesian stories", "cahya/gpt2-medium-indonesian-story"),
    )
}
def get_image(text: str, timeout: float = 10.0):
    """Ask the wikisearch service for an image URL matching *text*.

    Best-effort: network failures, undecodable/non-JSON bodies and responses
    without a top-level ``"url"`` field all yield ``""`` instead of raising.

    Args:
        text: Sentence to search an image for.
        timeout: Seconds to wait for the service (optional; without it the
            request could block the script indefinitely).

    Returns:
        The value of the response's JSON ``"url"`` field, or ``""`` on failure.
    """
    url = "https://wikisearch.uncool.ai/get_image/"
    payload = {
        "text": text,
        "image_width": 400,
    }
    try:
        response = requests.request(
            "POST", url, headers=headers, data=json.dumps(payload), timeout=timeout
        )
        image = json.loads(response.content.decode("utf-8"))["url"]
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # ValueError covers UnicodeDecodeError and json.JSONDecodeError;
        # KeyError/TypeError cover a JSON body that is not an object with "url".
        # Unlike a bare `except:`, this no longer swallows KeyboardInterrupt etc.
        image = ""
    return image
def query(payload, model_name, timeout=60):
    """POST *payload* to the Inference API endpoint of *model_name*.

    Args:
        payload: JSON-serialisable request body.
        model_name: Key into ``MODELS``; an unknown name raises ``KeyError``.
        timeout: Seconds to wait for the API (optional; previously unbounded).

    Returns:
        The decoded JSON response. A failed request or a body that is not valid
        UTF-8 JSON is reported as ``{"error": "<message>"}``, the same shape the
        result-rendering code already checks for, instead of crashing the app.
    """
    data = json.dumps(payload)
    try:
        response = requests.request(
            "POST", MODELS[model_name]["url"], headers=headers, data=data, timeout=timeout
        )
    except requests.RequestException as err:
        return {"error": f"Request to the inference API failed: {err}"}
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # e.g. an HTML error page from a proxy instead of a JSON body.
        return {"error": f"Unexpected response from the inference API (HTTP {response.status_code})"}
def process(text: str,
            model_name: str,
            max_len: int,
            temp: float,
            top_k: int,
            top_p: float):
    """Request a continuation of *text* from the model called *model_name*.

    The sampling settings are packed into a text-generation request body (with a
    fixed ``repetition_penalty`` of 2.0 and ``options.use_cache=True``), which is
    handed to ``query``; its return value is passed through unchanged.
    """
    generation_params = dict(
        max_new_tokens=max_len,
        top_k=top_k,
        top_p=top_p,
        temperature=temp,
        repetition_penalty=2.0,
    )
    request_body = dict(
        inputs=text,
        parameters=generation_params,
        options=dict(use_cache=True),
    )
    return query(request_body, model_name)
st.set_page_config(page_title="Indonesian GPT-2 Demo")
st.title("Indonesian GPT-2")

# fastText language-identification model; its predictions decide whether the
# prompt is translated to Indonesian before generation.
ft_model = fasttext.load_model('lid.176.ftz')

# Sidebar: sampling parameters forwarded to `process` when "Run" is pressed.
st.sidebar.image(LOGO)
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,  # a non-positive generation length is meaningless
    help="The maximum length of the sequence to be generated."
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.0,
    max_value=100.0,
    help="The value used to modulate the next token probabilities."
)
top_k = st.sidebar.number_input(
    "Top k",
    value=50,
    min_value=0,  # a token count cannot be negative
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    min_value=0.0,
    max_value=1.0,  # top_p is a cumulative probability threshold
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)
# Introductory text shown under the page title.
st.markdown(
    """
This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
[medium](https://huggingface.co/flax-community/gpt2-medium-indonesian) Indonesian GPT-2 models
trained on the Indonesian [Oscar](https://huggingface.co/datasets/oscar), [MC4](https://huggingface.co/datasets/mc4)
and [Wikipedia](https://huggingface.co/datasets/wikipedia) datasets. We created them as part of the
[Hugging Face JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
The demo supports "multi language" ;-), so feel free to try a prompt in your language. We are also experimenting with
sentence-based image search: Wikipedia passages are encoded with DistilBERT, and the encoded sentence is searched
among the encoded passages using Facebook's Faiss (temporarily disabled).
"""
)
# Model picker: the options are taken from MODELS so the two lists cannot drift apart.
model_name = st.selectbox('Model', list(MODELS))

# Example-prompt group in PROMPT_LIST used for each model. An unlisted model
# fails loudly with KeyError instead of leaving prompt_group_name undefined.
PROMPT_GROUPS = {
    "GPT-2 Small": "GPT-2",
    "GPT-2 Medium": "GPT-2",
    "GPT-2 Small finetuned on Indonesian academic journals": "Indonesian Journals",
    "GPT-2 Medium finetuned on Indonesian stories": "Indonesian Stories",
}
prompt_group_name = PROMPT_GROUPS[model_name]

session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name]) + ["Custom"]
# "Custom" (the last option) is preselected.
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# When the user picks a different prompt, drop the example text chosen for the
# previous one so a new example is drawn below.
if session_state.prompt is not None and prompt != session_state.prompt:
    session_state.prompt_box = None
    session_state.text = None
session_state.prompt = prompt

# Text-area default: a placeholder for "Custom", otherwise a random example of the
# selected prompt, kept until the prompt changes or the state is reset.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
session_state.text = st.text_area("Enter text", session_state.prompt_box)
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        # Detect the prompt language. Non-Indonesian prompts are translated to
        # Indonesian for generation and the output is translated back below.
        if model_name == "GPT-2 Medium finetuned on Indonesian stories":
            lang = "id"
            text = session_state.text
        else:
            # fastText predicts one line at a time, so newlines are flattened.
            lang_predictions, _ = ft_model.predict(session_state.text.replace("\n", " "), k=3)
            if "__label__id" in lang_predictions:
                lang = "id"
                text = session_state.text
            else:
                lang = lang_predictions[0].replace("__label__", "")
                text = translate(session_state.text, "id", lang)

        st.subheader("Result")
        result = process(text=text,
                         model_name=model_name,
                         max_len=int(max_len),
                         temp=temp,
                         top_k=int(top_k),
                         top_p=float(top_p))
        if "error" in result:
            error = result["error"]
            if isinstance(error, str):
                # st.write has no print-style `end` argument, so the optional
                # retry hint is appended to the same message explicitly.
                message = f'{error}.'
                if "estimated_time" in result:
                    message += f' Please try it again in about {result["estimated_time"]:.0f} seconds'
                st.write(message)
            elif isinstance(error, list):
                for item in error:
                    st.write(f'{item}')
        else:
            result = result[0]["generated_text"]
            # Markdown needs two trailing spaces for a hard line break.
            st.write(result.replace("\n", "  \n"))
            st.text("Translation")
            if lang == "id":
                # Indonesian prompt: show an English translation of the output.
                translation = translate(result, "en", "id")
            else:
                # Foreign prompt: translate the Indonesian output back to its language.
                translation = translate(result, lang, "id")
            st.write(translation.replace("\n", "  \n"))

    # Reset state so the next rerun starts from a freshly drawn example prompt.
    session_state.prompt = None
    session_state.prompt_box = None
    session_state.text = None