Spaces:
Runtime error
Runtime error
| # app.py | |
| import os | |
| from pathlib import Path | |
| import torch | |
| from threading import Event, Thread | |
| from typing import List, Tuple | |
| # Importing necessary packages | |
| from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
| from optimum.intel.openvino import OVModelForCausalLM | |
| import openvino as ov | |
| import openvino.properties as props | |
| import openvino.properties.hint as hints | |
| import openvino.properties.streams as streams | |
| from gradio_helper import make_demo # UI logic import | |
| from llm_config import SUPPORTED_LLM_MODELS | |
# Model configuration setup
# Run-time selections for this demo.
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
device_value = "CPU"
# NOTE(review): these three flags are not referenced elsewhere in the visible code.
prepare_int4_model_value = True
enable_awq_value = False
model_to_run_value = "INT4"

# Registry entry for the selected model (SUPPORTED_LLM_MODELS is imported from llm_config).
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
pt_model_id = model_configuration["model_id"]
model_name = pt_model_id
# Model family prefix: everything before the first "-" (here "qwen2.5").
pt_model_name = model_id_value.partition("-")[0]

# Local directory the INT4 OpenVINO model is loaded from below.
int4_model_dir = Path(model_id_value, "INT4_compressed_weights")
int4_weights = int4_model_dir / "openvino_model.bin"

# Prompt-format settings; optional keys fall back to None (or {} for tokenizer_kwargs).
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
current_message_template = model_configuration.get("current_message_template")
# Without an explicit flag, a model counts as having a chat template exactly when it has no history_template.
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Model loading
core = ov.Core()
# OpenVINO compile settings: latency hint, a single inference stream and an
# empty CACHE_DIR (presumably leaving model caching off -- confirm).
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}
# from_pretrained falls back to treating a non-existent local path as a Hub
# repo id, which produces a confusing download error; fail early and clearly.
if not int4_model_dir.is_dir():
    raise FileNotFoundError(
        f"INT4 OpenVINO model directory '{int4_model_dir}' was not found. "
        f"Export it first, e.g.: optimum-cli export openvino --model {pt_model_id} "
        f"--weight-format int4 {int4_model_dir}"
    )
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)
# Stopping criteria for token generation
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the newest token is one of ``token_ids``.

    Only the last token of the first sequence (``input_ids[0]``) is checked.
    """

    def __init__(self, token_ids):
        # Ids whose appearance as the latest token ends generation.
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
# Functions for chatbot logic
def convert_history_to_token(history: List[Tuple[str, str]]):
    """
    Convert the dialogue history into input token ids using the prompt format of the configured model.

    Three formats are supported:
      * ``pt_model_name == "baichuan2"``: a hand-built id sequence where token 195
        precedes each user turn and token 196 precedes each assistant turn;
      * no ``history_template`` or ``has_chat_template``: the tokenizer's chat template;
      * otherwise: ``history_template`` for past rounds plus ``current_message_template``
        for the last round, tokenized as plain text.

    Params:
      history: dialogue history as (user, assistant) pairs, oldest first; the last
               pair is the turn the model should answer.
    Returns:
      history in token format (a single-row tensor of token ids)
    """
    if pt_model_name == "baichuan2":
        return _baichuan2_history_tokens(history)
    if history_template is None or has_chat_template:
        return _chat_template_history_tokens(history)
    return _templated_history_tokens(history)


def _baichuan2_history_tokens(history):
    """Build Baichuan2 prompt ids: system prompt, past rounds oldest-first, then the current query."""
    user_token_id, assistant_token_id = 195, 196
    input_tokens = list(tok.encode(start_message))
    for old_query, response in history[:-1]:
        # Append rounds in order; prepending here would reverse the chronology.
        input_tokens.append(user_token_id)
        input_tokens.extend(tok.encode(old_query))
        input_tokens.append(assistant_token_id)
        input_tokens.extend(tok.encode(response))
    input_tokens.append(user_token_id)
    input_tokens.extend(tok.encode(history[-1][0]))
    # Trailing assistant marker: the model continues from here.
    input_tokens.append(assistant_token_id)
    return torch.LongTensor([input_tokens])


def _chat_template_history_tokens(history):
    """Render the history (after the system prompt) with the tokenizer's chat template."""
    messages = [{"role": "system", "content": start_message}]
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == last_idx and not model_msg:
            # Pending turn: only the user message; add_generation_prompt opens the reply.
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")


def _templated_history_tokens(history):
    """Render the history with history_template/current_message_template and tokenize the text."""
    past_rounds = "".join(
        history_template.format(num=idx, user=item[0], assistant=item[1])
        for idx, item in enumerate(history[:-1])
    )
    # NOTE(review): past rounds are numbered from 0, but the current round gets
    # len(history) + 1 -- confirm against the templates in llm_config.
    current_round = current_message_template.format(
        num=len(history) + 1,
        user=history[-1][0],
        assistant=history[-1][1],
    )
    text = start_message + past_rounds + current_round
    return tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream a reply for the last turn of ``history`` (callback run on submit).

    Generation runs in a background thread. Text chunks are read from a
    ``TextIteratorStreamer`` and written into ``history[-1][1]`` as they arrive.

    Params:
        history: list of [user_message, assistant_message] pairs; the last pair's
            assistant slot is overwritten with the streamed reply. If the encoded
            prompt is longer than 2000 tokens, only the last pair is kept and the
            yielded history contains just that turn.
        temperature: sampling temperature; values <= 0 disable sampling (greedy).
        top_p, top_k, repetition_penalty: forwarded to ``ov_model.generate``.
        conversation_id: accepted for the UI callback signature; unused here.
    Yields:
        ``history`` after each streamed chunk.
    """
    input_ids = convert_history_to_token(history)
    if input_ids.shape[1] > 2000:
        history = [history[-1]]
        input_ids = convert_history_to_token(history)
    streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if stop_tokens:
        # Configured stop tokens may be token strings or ids; normalize to ids.
        stop_ids = [tok.convert_tokens_to_ids(t) if isinstance(t, str) else t for t in stop_tokens]
        generate_kwargs["stopping_criteria"] = StoppingCriteriaList([StopOnTokens(stop_ids)])

    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except Exception:
            # generate() never reached streamer.end(): close the stream so the
            # loop below finishes instead of waiting for the 3600 s timeout,
            # then re-raise so the thread still reports the error.
            streamer.on_finalized_text("", stream_end=True)
            raise

    Thread(target=generate_and_signal_complete).start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
def request_cancel():
    """Cancel ``ov_model``'s OpenVINO inference request (passed to make_demo as ``stop_fn``)."""
    infer_request = ov_model.request
    infer_request.cancel()
# Gradio setup and launch
demo = make_demo(
    run_fn=bot,
    stop_fn=request_cancel,
    title=f"OpenVINO {model_id_value} Chatbot",
    language=model_language_value,
)

if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with a public share link and debug output.
    demo.launch(
        debug=True,
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )