import os
import shutil
import time
import re
import chainlit as cl
from chainlit.input_widget import Slider
import tokeniser
from tavily import TavilyClient
from ddgs import DDGS
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, PromptTemplate
from llama_index.core.callbacks import CallbackManager
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.llms import ChatMessage
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from transformers import pipeline
from collections import deque
from anyascii import anyascii
MAX_CONTEXT_WINDOW_TOKENS = 32000
MODEL = "speakleash/bielik-11b-v2.6-instruct"
GUARD_MODEL = "speakleash/Bielik-Guard-0.1B-v1.0"
EMBEDDING_MODEL = "baai/bge-m3"
TOP_K = 5
GUARD_THRESHOLD = 0.5
COMMANDS = [
    {"id": "Wyszukaj", "icon": "globe", "description": "Wyszukaj w Internecie", "button": True, "persistent": True},
    {"id": "Rozumuj", "icon": "brain", "description": "Rozumuj przed odpowiedzią", "button": True, "persistent": True},
    {"id": "W+R", "icon": "badge-plus", "description": "Wyszukaj i Rozumuj", "button": True, "persistent": True},
    {"id": "Chroń", "icon": "shield", "description": "Chroń przed szkodliwymi treściami", "button": True, "persistent": True}
]
SYSTEM_PROMPT = "Jesteś pomocnym asystentem. Odpowiadaj wyczerpująco w języku polskim na pytania użytkownika. Dzisiaj jest " + time.strftime("%d.%m.%Y") + " r. "
SEARCH_SYSTEM_PROMPT = "Wykorzystaj poniższe wyniki wyszukiwania w Internecie. Na końcu odpowiedzi umieść odnośniki do stron źródłowych, na podstawie których udzieliłeś odpowiedzi. "
THINK_SYSTEM_PROMPT = "Dokładnie przeanalizuj pytania krok po kroku, wyraźnie pokazując swój proces rozumowania przed udzieleniem ostatecznej odpowiedzi. Ustrukturyzuj swoją odpowiedź w następujący sposób: użyj znaczników  i , aby pokazać szczegółowe etapy rozumowania, w tym analizę, podsumowanie, burzę mózgów, weryfikację dokładności, poprawę błędów i ponowne przeanalizowanie wcześniejszych punktów, np. {myśli krok po kroku}. Użyj maksymalnie 500 słów na rozumowanie. Następnie wyraźnie przedstaw ostateczną, dokładną, wyczerpującą odpowiedź w oparciu o swoje rozumowanie. Odpowiadaj w języku polskim. "
TEXT_QA_SYSTEM_PROMPT = "Poniżej znajduje się kontekst.\n---------------------\n{context_str}\n---------------------\nBiorąc pod uwagę kontekst, a nie wcześniejszą wiedzę, odpowiedz na pytanie.\nPytanie: {query_str}\nOdpowiedź: "
REFINE_SYSTEM_PROMPT = "Oryginalne zapytanie wygląda następująco: {query_str}\nPodaliśmy taką odpowiedź: {existing_answer}\nMamy możliwość doprecyzowania istniejącej odpowiedzi (tylko w razie potrzeby) o dodatkowy kontekst poniżej.\n------------\n{context_msg}\n------------\nBiorąc pod uwagę nowy kontekst, doprecyzuj oryginalną odpowiedź, aby lepiej odpowiedzieć na zapytanie. Jeśli kontekst jest nieprzydatny, zwróć oryginalną odpowiedź.\nDoprecyzowana odpowiedź: "
CB_IGNORE = [
    CBEventType.CHUNKING,
    CBEventType.SYNTHESIZE,
    CBEventType.EMBEDDING,
    CBEventType.NODE_PARSING,
    CBEventType.TREE,
    CBEventType.LLM
]
tavilyClient = TavilyClient(os.getenv("TAVILY_API_KEY"))
classifier = pipeline("text-classification", model=GUARD_MODEL, return_all_scores=True)
class BielikCallbackHandler(cl.LlamaIndexCallbackHandler):
    def __init__(self):
        super().__init__(event_starts_to_ignore=CB_IGNORE, event_ends_to_ignore=CB_IGNORE)
    def on_event_start(self, event_type, payload = None, event_id = "", parent_id = "", **kwargs):
        id = super().on_event_start(event_type, payload, event_id, parent_id, **kwargs)
        if id in self.steps:
            self.steps[id].show_input = False
        return id
def truncate_messages(messages, max_tokens = MAX_CONTEXT_WINDOW_TOKENS):
    if not messages:
        return []
    truncated = messages.copy()
    if truncated and truncated[-1]["role"] == "assistant":
        truncated = truncated[:-1]
    total_tokens = 0
    for i in range(len(truncated) - 1, -1, -1):
        message_tokens = tokeniser.estimate_tokens(truncated[i]["content"])
        total_tokens += message_tokens
        if total_tokens > max_tokens:
            truncated = truncated[i + 1:]
            break
    if truncated and truncated[-1]["role"] == "assistant":
        truncated = truncated[:-1]
    return truncated
@cl.step(name="wyszukiwanie", type="tool", show_input=False)
async def search_web(query):
    try:
        search_results = tavilyClient.search(query=query[:400], country="poland")["results"]
    except Exception as e:
        print(f"Tavily search failed: {e}. Falling back to DDGS.")
        try:
            search = DDGS().text(query, region="pl-pl", backend="duckduckgo, brave, google, mullvad_google, mullvad_brave", max_results=5)
            search_results = [{"title": r["title"], "url": r["href"], "content": r["body"]} for r in search]
        except Exception as e:
            print(f"DDGS search failed: {e}")
            return f"Błąd wyszukiwania: {str(e)}"
    formatted_text = "Wyniki wyszukiwania:\n"
    for i, result in enumerate(search_results, 1):
        formatted_text += f"{i}. [{result['title']}]({result['url']})\n   {result['content']}\n"
    return formatted_text
@cl.step(name="rozumowanie", type="tool", show_input=False)
async def think(messages, llm):
    current_step = cl.context.current_step
    current_step.output = ""
    stream = await infer(messages, llm)
    think_content = ""
    async for chunk in stream:
        if chunk.delta:
            think_content += chunk.delta
            await current_step.stream_token(chunk.delta)
            if think_content.endswith("") or think_content.endswith("\n \n"):
                break
    return stream
    
async def infer(messages, llm):
    return await llm.astream_chat([ChatMessage(role=m["role"], content=m["content"]) for m in messages])
async def ask_files(message, files):
    dir = os.path.dirname(files[0].path)
    cl.user_session.set("dir", dir)
    documents = SimpleDirectoryReader(dir, exclude_hidden=False).load_data(show_progress=True)
    index = VectorStoreIndex.from_documents(documents)
    # index.storage_context.persist()
    query_engine = index.as_query_engine(
        streaming=True,
        similarity_top_k=TOP_K,
        service_context=Settings.callback_manager,
        text_qa_template=PromptTemplate(TEXT_QA_SYSTEM_PROMPT),
        refine_template=PromptTemplate(REFINE_SYSTEM_PROMPT)
    )
    return await query_engine.aquery(message)
@cl.step(name="klasyfikowanie", type="tool", show_input=False)
async def classify(message):
    return classifier(re.sub(r"(?<=[A-Za-z])[\.,_-](?=[A-Za-z])", " ", anyascii(message)))[0]
def update_llm_settings(llm, settings):
    llm.temperature = settings["Temp"]
    llm.max_tokens = settings["MaxTokens"]
    llm.additional_kwargs = {
      "top_p": settings["TopP"],
      "frequency_penalty": settings["FreqPenalty"],
      "presence_penalty": settings["PresPenalty"]
    }
    return llm
async def run_chat(messages, files = None, search_enabled = False, think_enabled = False, guard_enabled = False):
    llm = update_llm_settings(Settings.llm, cl.user_session.get("settings"))
    msg = cl.Message(content="", author="Bielik")
    response_content = ""
    system_prompt = SYSTEM_PROMPT
    curr_message = messages[-1]["content"].strip()
    if not curr_message:
        return
    try:
        if files:
            stream = await ask_files(curr_message, files)
            async for chunk in stream.response_gen:
                response_content += chunk
                await msg.stream_token(chunk)
            await msg.send()
            return response_content
        if guard_enabled:
            guard_results = await classify(curr_message)
            if any(r["score"] > GUARD_THRESHOLD for r in guard_results):
                msg.content = response_content = "Wykryłem szkodliwe treści! Koniec rozmowy!"
                await msg.send()
                return response_content
        if think_enabled:
            system_prompt += THINK_SYSTEM_PROMPT
        if search_enabled:
            search_result = await search_web(curr_message)
            system_prompt += SEARCH_SYSTEM_PROMPT + "\n\n" + search_result
        context_messages = truncate_messages(messages)
        messages = [{"role": "system", "content": system_prompt}, *context_messages]
        print(messages)
        stream = await think(messages, llm) if think_enabled else await infer(messages, llm)
        async for chunk in stream:
            if chunk.delta:
                response_content += chunk.delta
                await msg.stream_token(chunk.delta)
        await msg.send()
        response_content = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', re.sub(r'\\\((.*?)\\\)', r'$\1$', response_content, flags=re.S), flags=re.S) # LaTeX format
        if guard_enabled:
            guard_results = await classify(response_content)
            if any(r["score"] > GUARD_THRESHOLD for r in guard_results):
                response_content = "W mojej odpowiedzi wykryłem szkodliwe treści. Gryzę się w język!"
        msg.content = response_content
        await msg.update()
        return response_content
    except Exception as e:
        print(f"Response failed: {e}")
        error_msg = f"Błąd generowania odpowiedzi: {str(e)}"
        await cl.Message(content=error_msg).send()
        return error_msg
@cl.on_chat_start
async def start_chat():
    settings = await cl.ChatSettings([
        Slider(id="Temp", label="Temperatura", initial=0.2, min=0, max=1, step=0.1),
        Slider(id="TopP", label="Top P", initial=0.7, min=0.01, max=1, step=0.01),
        Slider(id="FreqPenalty", label="Frequency Penalty", initial=0, min=-2, max=2, step=0.1),
        Slider(id="PresPenalty", label="Presence Penalty", initial=0, min=-2, max=2, step=0.1),
        Slider(id="MaxTokens", label="Max Tokenów", initial=4096, min=1, max=4096, step=64)
    ]).send()
    await cl.context.emitter.set_commands(COMMANDS)
    cl.user_session.set("chat_messages", [])
    cl.user_session.set("settings", settings)
    Settings.llm = NVIDIA(
        model=MODEL,
        is_chat_model=True,
        context_window=32768,
        temperature=settings["Temp"],
        top_p=settings["TopP"],
        max_tokens=settings["MaxTokens"],
        frequency_penalty=settings["FreqPenalty"],
        presence_penalty=settings["PresPenalty"],
        streaming=True
    )
    Settings.embed_model = NVIDIAEmbedding(
        model=EMBEDDING_MODEL
    )
    Settings.callback_manager = CallbackManager([BielikCallbackHandler()])
@cl.on_chat_end
def end_chat():
    dir = cl.user_session.get("dir")
    if dir and os.path.exists(dir):
        shutil.rmtree(dir)
@cl.on_settings_update
async def setup_agent(settings):
    cl.user_session.set("settings", settings)
@cl.on_message
async def on_message(msg):
    chat_messages = cl.user_session.get("chat_messages", [])
    chat_messages.append({"role": "user", "content": msg.content})
    files = [el for el in msg.elements if el.path] or None
    search_enabled = msg.command in ["Wyszukaj", "W+R"]
    think_enabled = msg.command in ["Rozumuj", "W+R"]
    guard_enabled = msg.command == "Chroń"
    response = await run_chat(chat_messages, files, search_enabled, think_enabled, guard_enabled)
    chat_messages.append({"role": "assistant", "content": response})
    cl.user_session.set("chat_messages", chat_messages)