# Hugging Face Space app (running on ZeroGPU)
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import (
    MistralForCausalLM,
    TextIteratorStreamer,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)
from time import sleep
from threading import Thread
from torch import float16
import spaces
import huggingface_hub
from queue import Queue
from os import getenv

# from data_logger import log_data
from datetime import datetime
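

# Background logger: every 60 seconds, drain queued translation inputs and push
# them to a private HF dataset via data_logger.log_data (import commented out
# above; the thread itself is only started from the disabled block further down).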
def check_thread(logging_queue: Queue):
    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )

    while True:
        sleep(60)
        batch = []
        while not logging_queue.empty():
            batch.append(logging_queue.get())

        if len(batch) > 0:
            try:
                logging_callback(batch)
            except Exception:
                print(
                    "Error happened while pushing data to HF. Putting items back in queue..."
                )
                for item in batch:
                    logging_queue.put(item)
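

# Input logging is currently disabled: the guard below is hard-coded to False.
# Restoring the original `getenv("HF_API_TOKEN") is not None` check (and the
# data_logger import) would re-enable pushing user inputs to OUTPUT_DATASET.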
if False:  # getenv("HF_API_TOKEN") is not None:
    # print("Starting logging thread...")
    # log_queue = Queue()
    # t = Thread(target=check_thread, args=(log_queue,))
    # t.start()
    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )
else:
    print("No HF_API_TOKEN found. Logging is disabled.")
config = PeftConfig.from_pretrained("lang-uk/dragoman")
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_use_double_quant=False,
)

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=quant_config
)
# device_map="auto",)

model = PeftModel.from_pretrained(model, "lang-uk/dragoman").to("cuda")

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False
)
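

# ZeroGPU note (assumption, not in the original listing): given the `spaces`
# import above and the Space running on Zero, the GPU-bound entry point is
# expected to be wrapped with `spaces.GPU` so a device is allocated per request.
@spaces.GPU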
def translate(input_text):
    global log_queue

    # generated_text = ""
    input_text = input_text.strip()
    print(f"{datetime.utcnow()} | Translating: {input_text}")

    if False:  # getenv("HF_API_TOKEN") is not None:
        try:
            logging_callback = log_data(
                hf_token=getenv("HF_API_TOKEN"),
                dataset_name=getenv("OUTPUT_DATASET"),
                private=True,
            )
            logging_callback([[input_text]])
        except Exception:
            print("Error happened while pushing data to HF.")

    input_text = f"[INST] {input_text} [/INST]"
    inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        inputs, max_new_tokens=200, num_beams=10, temperature=1, pad_token_id=tokenizer.eos_token_id
    )  # streamer=streamer,
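
    # Beam search with 10 beams; `temperature` has no effect here because
    # sampling is not enabled (do_sample defaults to False).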

    # streaming support
    # streamer = TextIteratorStreamer(
    #     tokenizer, skip_prompt=True, skip_special_tokens=True
    # )

    # thread = Thread(target=model.generate, kwargs=generation_kwargs)
    # thread.start()

    # for new_text in streamer:
    #     generated_text += new_text
    #     yield generated_text

    # generated_text += "\n"
    # yield generated_text
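
    # The streaming path above is left disabled; TextIteratorStreamer cannot be
    # combined with beam search (num_beams > 1), so the full translation is
    # returned once generation completes.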
    output = model.generate(**generation_kwargs)
    output = (
        tokenizer.decode(output[0], skip_special_tokens=True)
        .split("[/INST] ")[-1]
        .strip()
    )
    return output


# download the description of the model
desc_file = huggingface_hub.hf_hub_download("lang-uk/dragoman", "README.md")

with open(desc_file, "r") as f:
    model_description = f.read()

# drop the YAML front matter (skip past the closing "---" of the metadata block)
model_description = model_description[model_description.find("---", 1) + 5 :]

model_description = (
    """### By using this service, you agree that your input may be collected for future research and model improvements. \n\n"""
    + model_description
)
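

# Gradio UI: a single-textbox translation demo with pre-filled example sentences;
# the model card (minus its front matter) is rendered below the interface.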
iface = gr.Interface(
    fn=translate,
    inputs=gr.Textbox(
        value='This demo contains a model from the paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to the UNLP 2024 workshop at the LREC-COLING 2024 conference.',
        label="Source sentence",
    ),
    outputs=gr.Textbox(
        value='Ця демо-версія містить модель із статті "Налаштування принтера даних із покращеним машинним перекладом з англійської на українську", яка була прийнята до семінару UNLP 2024 на конференції LREC-COLING 2024.',
        label="Translated sentence",
    ),
    examples=[
        [
            "The Colosseum in Rome was a symbol of the grandeur and power of the Roman Empire and was a place for the emperor to connect with the people by providing them with entertainment and free food."
        ],
        [
            "How many leaves would it drop in a month of February in a non-leap year?",
        ],
        [
            "ChatGPT (Chat Generative Pre-trained Transformer) is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language. Successive prompts and replies, known as prompt engineering, are considered at each conversation stage as a context.[2] ",
        ],
        [
            "who holds this neighborhood?",
        ],
    ],
    title="Dragoman: SOTA English-Ukrainian translation model",
    description='This demo contains a model from the paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to the UNLP 2024 workshop at the LREC-COLING 2024 conference.',
    article=model_description,
    # thumbnail: str | None = None,
    # css: str | None = None,
    # batch: bool = False,
    # max_batch_size: int = 4,
    # api_name: str | Literal[False] | None = "predict",
    submit_btn="Translate",
)

iface.launch()
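
# Local run sketch (assumes a CUDA GPU, the file saved as app.py, and Hub access
# to mistralai/Mistral-7B-v0.1):
#   pip install gradio transformers peft bitsandbytes accelerate spaces
#   python app.py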