Spaces:
Sleeping
Sleeping
| import os | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| import gradio as gr | |
| from transformers import pipeline | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from Ashaar.utils import get_output_df, get_highlighted_patterns_html | |
| from Ashaar.bait_analysis import BaitAnalysis | |
| from langs import * | |
| import sys | |
| import json | |
| import argparse | |
# Command-line configuration. `--lang` picks which set of UI strings is shown:
# 'ar' selects the `*_ar` names, any other value selects the `*_en` names.
# The `*_ar` / `*_en` names are presumably supplied by `from langs import *`
# (not visible here) -- TODO confirm.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--lang', type=str, default='ar')
args = arg_parser.parse_args()
lang = args.lang

if lang != 'ar':
    # English strings; no extra styling is applied to the text boxes.
    css = ""
    (TITLE, DESCRIPTION,
     textbox_trg_text, textbox_inp_text,
     btn_trg_text, btn_inp_text) = (
        TITLE_en, DESCRIPTION_en,
        textbox_trg_text_en, textbox_inp_text_en,
        btn_trg_text_en, btn_inp_text_en,
    )
else:
    # Arabic strings; elements with id "textbox" are rendered right-to-left.
    css = """ #textbox{ direction: RTL;}"""
    (TITLE, DESCRIPTION,
     textbox_trg_text, textbox_inp_text,
     btn_trg_text, btn_inp_text) = (
        TITLE_ar, DESCRIPTION_ar,
        textbox_trg_text_ar, textbox_inp_text_ar,
        btn_trg_text_ar, btn_inp_text_ar,
    )
def _load_json(path):
    """Parse the JSON file at *path* as UTF-8 and close it afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Tokenizer and causal language model used by `generate`.
gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')

# Lookup tables between theme / meter names and their prompt tokens, plus the
# inverse mappings (token -> name).
theme_to_token = _load_json("extra/theme_tokens.json")
token_to_theme = {token: name for name, token in theme_to_token.items()}
meter_to_token = _load_json("extra/meter_tokens.json")
token_to_meter = {token: name for name, token in meter_to_token.items()}

analysis = BaitAnalysis()

# Results of the most recent `analyze` call; `generate` reads them to build its
# prompt. They are module-level, so every caller in this process shares them.
meter, theme, qafiyah = "", "", ""
def analyze(poem):
    """Analyze *poem* and remember its meter, qafiyah and theme for `generate`.

    The text is split on newlines, blank lines are discarded, and consecutive
    pairs of lines are joined with ``' # '`` to form the baits passed to
    ``analysis.analyze``. A trailing unpaired line is ignored.

    Side effects:
        Rebinds the module-level ``meter``, ``qafiyah`` and ``theme`` to
        ``output['meter']``, the first item of ``output['qafiyah']`` and the
        last item of ``output['theme']`` respectively.

    Returns:
        A 2-tuple ``(html, update)``: the markup produced by
        ``get_highlighted_patterns_html`` for the analysis result, and a
        Gradio update that makes the wired output button interactive.
    """
    global meter, theme, qafiyah

    # Skip blank lines so a stray empty line cannot shift the pairing of
    # hemistichs into baits.
    shatrs = [line for line in poem.split("\n") if line.strip()]
    baits = [' # '.join(shatrs[i:i + 2]) for i in range(0, len(shatrs) - 1, 2)]

    output = analysis.analyze(baits, override_tashkeel=True)
    meter = output['meter']
    qafiyah = output['qafiyah'][0]
    theme = output['theme'][-1]

    df = get_output_df(output)
    # gr.update() is the generic form of the per-component `.update()` helper
    # and is not tied to the Button class.
    return get_highlighted_patterns_html(df), gr.update(interactive=True)
def _format_prompt_poem(text):
    """Wrap consecutive pairs of non-blank lines of *text* in separator tokens."""
    shatrs = [line for line in text.split('\n') if line.strip()]
    # zip() stops at the shorter slice, so a trailing unpaired line is dropped.
    return ' '.join(
        '<|bsep|> ' + first + ' <|vsep|> ' + second + ' </|bsep|>'
        for first, second in zip(shatrs[::2], shatrs[1::2])
    )


def _decode_continuation(token_rows, max_lines=10):
    """Decode generated token ids into text, emitting at most *max_lines* line breaks."""
    pieces = []
    line_count = 0
    for row in token_rows:
        if line_count >= max_lines:
            break
        for token in row:
            if line_count >= max_lines:
                break
            decoded = gpt_tokenizer.decode(token)
            # A decoded piece mentioning 'meter' or 'theme' ends this row.
            if 'meter' in decoded or 'theme' in decoded:
                break
            if decoded in ("<|vsep|>", "</|bsep|>"):
                # End of a hemistich or of a bait: start a new output line.
                pieces.append("\n")
                line_count += 1
            elif decoded in ('<|bsep|>', '<|psep|>', '</|psep|>'):
                # Structural markers that carry no text.
                continue
            else:
                pieces.append(decoded)
        else:
            # The row was consumed to its end without an early stop: finish.
            break
    return "".join(pieces)


def generate(inputs, top_p=3):
    """Generate a continuation of the poem in *inputs*.

    The prompt is built from the module-level ``meter``, ``qafiyah`` and
    ``theme`` (set by `analyze`), followed by the input lines paired into
    baits. Blank input lines are skipped and a trailing unpaired line is
    dropped.

    Args:
        inputs: Poem text, one hemistich per line.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
            NOTE(review): the default of 3 lies outside the usual (0, 1]
            range for nucleus sampling -- confirm it is intentional.

    Returns:
        A 2-tuple ``(text, update)``: the decoded continuation (at most 10
        line breaks) and a Gradio update that makes the wired output button
        non-interactive.

    Raises:
        KeyError: if ``meter`` or ``theme`` is not a key of
            ``meter_to_token`` / ``theme_to_token`` (e.g. `analyze` has not
            been run yet).
    """
    poem = _format_prompt_poem(inputs)
    prompt = f"""
{meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}
<|psep|>
{poem}
""".strip()
    print(prompt)

    encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
    output = model.generate(**encoded_input, max_length=512, top_p=top_p, do_sample=True)

    # Only the tokens after the prompt are newly generated text.
    prompt_length = len(encoded_input.input_ids[0])
    result = _decode_continuation(output[:, prompt_length:])

    # gr.update() is the generic form of the per-component `.update()` helper
    # and is not tied to the Button class.
    return result, gr.update(interactive=False)
# Sample inputs for the demo's examples widget. Each entry is a one-element
# list whose string holds two lines of verse separated by a newline.
examples = [
    [
        "القلب أعلم يا عذول بدائه\n"
        "وأحق منك بجفنه وبمائه"
    ],
    [
        "رمتِ الفؤادَ مليحة عذراءُ\n"
        "بسهامِ لحظٍ ما لهنَّ دواءُ"
    ],
    [
        "أذَلَّ الحِرْصُ والطَّمَعُ الرِّقابَا\n"
        "وقَد يَعفو الكَريمُ، إذا استَرَابَا"
    ],
]
# Build the demo UI: a title/description header, two text boxes side by side,
# one button under each box, and an HTML area for the analysis output.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(TITLE)
            gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
        with gr.Column():
            inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")
    with gr.Row():
        # Whichever button triggers `generate` starts disabled; `analyze`
        # enables it and `generate` disables it again.
        with gr.Column():
            if lang == 'ar':
                trg_btn = gr.Button(btn_trg_text, interactive=False)
            else:
                trg_btn = gr.Button(btn_trg_text)
        with gr.Column():
            if lang == 'ar':
                inp_btn = gr.Button(btn_inp_text)
            else:
                inp_btn = gr.Button(btn_inp_text, interactive=False)
    with gr.Row():
        html_output = gr.HTML()

    # The wiring uses the same `lang == 'ar'` test as the labels and button
    # states above, so any non-Arabic value gets a consistent layout: the
    # enabled button always runs `analyze`, the disabled one runs `generate`.
    if lang == 'ar':
        gr.Examples(examples, inputs)
        trg_btn.click(generate, inputs=inputs, outputs=[textbox_output, trg_btn])
        inp_btn.click(analyze, inputs=inputs, outputs=[html_output, trg_btn])
    else:
        # In the non-Arabic layout the two columns swap roles: the first text
        # box receives the poem and the second receives the generated text.
        gr.Examples(examples, textbox_output)
        inp_btn.click(generate, inputs=textbox_output, outputs=[inputs, inp_btn])
        trg_btn.click(analyze, inputs=textbox_output, outputs=[html_output, inp_btn])

# Alternative launch that listens on all interfaces and creates a share link:
# demo.launch(server_name='0.0.0.0', share=True)
demo.launch()