import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import os
from threading import Thread
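# Per-device memory caps; with device_map="auto", layers that do not fit
# within these limits are offloaded to CPU.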
max_memory = {
    0: "30GiB",
    "cpu": "64GiB",
}
MODEL_LIST = ["THUDM/GLM-4-Z1-32B-0414"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4-Z1-32B-0414"
TITLE = "<h1>3ML-bot (Text Only)</h1>"
DESCRIPTION = f"""
😊 A Multi-Lingual Analytical Chatbot.
🚀 Current model: {MODEL_NAME}
"""
CSS = """
h1 {
    text-align: center;
    display: block;
}
"""
# Configure BitsAndBytes for 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
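# ZeroGPU: spaces.GPU() allocates a GPU for the duration of each call, so the
# quantized model is loaded inside the handler rather than at import time.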
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        token=HF_TOKEN,
    )

    print(f'message is - {message}')
    print(f'history is - {history}')
    
    # ChatInterface supplies history as [user, assistant] pairs.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    
    conversation.append({"role": "user", "content": message})
    
    print(f"Conversation is -\n{conversation}")
    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
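    # Stream decoded text back as it is generated; skip_prompt drops the echoed input.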
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,  # budget for generated tokens only, excluding the prompt
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],  # GLM-4 stop tokens: <|endoftext|>, <|user|>, <|observation|>
    )
    gen_kwargs = {**inputs, **generate_kwargs}
    # torch.no_grad() is thread-local and would not apply to the generation
    # thread; model.generate() already runs without gradients internally.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
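# UI components are created unrendered here and handed to ChatInterface below.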
chatbot = gr.Chatbot()
chat_input = gr.Textbox(
    interactive=True,
    placeholder="Enter your message here...",
    show_label=False,
)
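# Multilingual example prompts matching the bot's analytical focus.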
EXAMPLES = [
    ["Analyze the geopolitical implications of recent technological advancements in AI ."],
    ["¿Cuáles son los desafíos éticos más importantes en el desarrollo de la inteligencia artificial general?"],
    ["从经济学和社会学角度分析,人工智能将如何改变未来的就业市场?"],
    ["ما هي التحديات الرئيسية التي تواجه تطوير الذكاء الاصطناعي في العالم العربي؟"],
    ["नैतिक कृत्रिम बुद्धिमत्ता विकास में सबसे बड़ी चुनौतियाँ क्या हैं? विस्तार से समझाइए।"],
    ["Кои са основните предизвикателства пред разработването на изкуствен интелект в България и Източна Европа?"],
    ["Explain the potential risks and benefits of quantum computing in national security contexts."],
    ["分析气候变化对全球经济不平等的影响,并提出可能的解决方案。"],
]
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0.1,  # temperature must be > 0 when do_sample=True
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.1,  # repetition penalty must be > 0
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=EXAMPLES,
    )
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)