#!/usr/bin/env python

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

DESCRIPTION = """# Swallow-13B instruct"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        low_cpu_mem_usage=True,
        device_map="auto",
    )
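
# Rough memory arithmetic for the 8-bit load above (a back-of-the-envelope
# estimate, not a measured figure): 13B parameters at 1 byte each is ~13 GB
# of weights, versus ~26 GB in fp16, which is why load_in_8bit=True is used.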

MAX_INPUT_TOKENS = 2048

PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"
    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}
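
# English gloss of the Japanese Alpaca-style templates above (the model expects
# the Japanese text verbatim, so the literals are left untouched):
#   prompt_input:    "Below is an instruction that describes a task, paired with
#                     an input that provides further context. Write a response
#                     that appropriately completes the request." followed by
#                     "### Instruction: / ### Input: / ### Response:" sections.
#   prompt_no_input: the same, without the "### Input:" section.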


def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Generate a prompt based on the given instruction and an optional input.

    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided, it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input_text (str | None): Additional input providing context for the task. Defaults to None.

    Returns:
        str: The generated prompt.
    """
    if input_text:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
    # Use the 'prompt_no_input' template when no additional input is provided
    return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
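
# For example (using an instruction taken from the demo's own examples below),
# create_prompt("暴れん坊将軍って誰のことですか?") returns the 'prompt_no_input'
# template with the instruction filled in, ending in "### 応答:" ("### Response:")
# so that the model's continuation is the answer.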


@spaces.GPU  # Assumed ZeroGPU decorator: `import spaces` is otherwise unused in this file.
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    if input_text == "":
        input_text = None

    prompt = create_prompt(instruction, input_text)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        error_message = f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})"
        raise gr.Error(error_message)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }
    # Run generation on a background thread so this function can consume the
    # streamer and yield progressively longer strings to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def process_example(instruction: str, input_text: str) -> Iterator[str]:
    yield from run(instruction, input_text)


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input (optional)", lines=5)
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    run_button.click(
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        api_name="run",
    )

    gr.Examples(
        examples=[
            [
                # "Provide detailed information on the following topic."
                "以下のトピックに関する詳細な情報を提供してください。",
                # "Tell me about the main campuses of Tokyo Institute of Technology."
                "東京工業大学の主なキャンパスについて教えてください。",
            ],
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                # "Tell me what a yume-ochi ('it was all a dream' ending) is."
                "夢オチとは何かについて教えてください。",
            ],
            # "Who is the Abarenbō Shōgun (the 'unruly shogun')?"
            ["暴れん坊将軍って誰のことですか?", ""],  # noqa: RUF001
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )


if __name__ == "__main__":
    demo.launch()
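
# Minimal client-side sketch for the "run" endpoint exposed above. This is an
# illustration, not part of the app: "<user>/<space>" is a placeholder for the
# deployed Space id, and gradio_client must be installed separately.
#
#     from gradio_client import Client
#
#     client = Client("<user>/<space>")
#     job = client.submit(
#         "以下のトピックに関する詳細な情報を提供してください。",  # instruction
#         "東京工業大学の主なキャンパスについて教えてください。",  # input_text
#         256,    # max_new_tokens
#         0.99,   # temperature
#         0.95,   # top_p
#         api_name="/run",
#     )
#     for partial in job:  # streams partial outputs as they are generated
#         print(partial)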