# bielik-czat / app.py
import os
import shutil
import time
import re
import chainlit as cl
from chainlit.input_widget import Slider
import tokeniser
from tavily import TavilyClient
from ddgs import DDGS
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, PromptTemplate
from llama_index.core.callbacks import CallbackManager
from llama_index.core.callbacks.schema import CBEventType
from llama_index.core.llms import ChatMessage
from llama_index.llms.nvidia import NVIDIA
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from transformers import pipeline
from collections import deque
from anyascii import anyascii
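
# --- Configuration: NVIDIA-hosted Bielik chat model, SpeakLeash guard classifier, BGE-M3 embeddings ---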
MAX_CONTEXT_WINDOW_TOKENS = 32000  # history budget, kept slightly below the model's 32768-token window
MODEL = "speakleash/bielik-11b-v2.6-instruct"
GUARD_MODEL = "speakleash/Bielik-Guard-0.1B-v1.0"
EMBEDDING_MODEL = "baai/bge-m3"
TOP_K = 5  # document chunks retrieved per RAG query
GUARD_THRESHOLD = 0.5  # classifier score above which content is treated as harmful
# Chat commands (labels in Polish): Wyszukaj = web search, Rozumuj = reason before answering,
# W+R = search and reason, Chroń = guard against harmful content.
COMMANDS = [
    {"id": "Wyszukaj", "icon": "globe", "description": "Wyszukaj w Internecie", "button": True, "persistent": True},
    {"id": "Rozumuj", "icon": "brain", "description": "Rozumuj przed odpowiedzią", "button": True, "persistent": True},
    {"id": "W+R", "icon": "badge-plus", "description": "Wyszukaj i Rozumuj", "button": True, "persistent": True},
    {"id": "Chroń", "icon": "shield", "description": "Chroń przed szkodliwymi treściami", "button": True, "persistent": True}
]
# System prompt (PL): "You are a helpful assistant. Answer the user's questions thoroughly in Polish. Today is <date>."
SYSTEM_PROMPT = "Jesteś pomocnym asystentem. Odpowiadaj wyczerpująco w języku polskim na pytania użytkownika. Dzisiaj jest " + time.strftime("%d.%m.%Y") + " r. "
# Search prompt (PL): "Use the web search results below; end the answer with links to the source pages used."
SEARCH_SYSTEM_PROMPT = "Wykorzystaj poniższe wyniki wyszukiwania w Internecie. Na końcu odpowiedzi umieść odnośniki do stron źródłowych, na podstawie których udzieliłeś odpowiedzi. "
# Reasoning prompt (PL): analyze step by step inside <think></think> tags (max 500 words), then give the final answer in Polish.
THINK_SYSTEM_PROMPT = "Dokładnie przeanalizuj pytania krok po kroku, wyraźnie pokazując swój proces rozumowania przed udzieleniem ostatecznej odpowiedzi. Ustrukturyzuj swoją odpowiedź w następujący sposób: użyj znaczników <think> i </think>, aby pokazać szczegółowe etapy rozumowania, w tym analizę, podsumowanie, burzę mózgów, weryfikację dokładności, poprawę błędów i ponowne przeanalizowanie wcześniejszych punktów, np. <think>{myśli krok po kroku}</think>. Użyj maksymalnie 500 słów na rozumowanie. Następnie wyraźnie przedstaw ostateczną, dokładną, wyczerpującą odpowiedź w oparciu o swoje rozumowanie. Odpowiadaj w języku polskim. "
# RAG QA prompt (PL): answer from the given context only, not from prior knowledge.
TEXT_QA_SYSTEM_PROMPT = "Poniżej znajduje się kontekst.\n---------------------\n{context_str}\n---------------------\nBiorąc pod uwagę kontekst, a nie wcześniejszą wiedzę, odpowiedz na pytanie.\nPytanie: {query_str}\nOdpowiedź: "
# RAG refine prompt (PL): refine the existing answer with the new context, or return it unchanged if the context is unhelpful.
REFINE_SYSTEM_PROMPT = "Oryginalne zapytanie wygląda następująco: {query_str}\nPodaliśmy taką odpowiedź: {existing_answer}\nMamy możliwość doprecyzowania istniejącej odpowiedzi (tylko w razie potrzeby) o dodatkowy kontekst poniżej.\n------------\n{context_msg}\n------------\nBiorąc pod uwagę nowy kontekst, doprecyzuj oryginalną odpowiedź, aby lepiej odpowiedzieć na zapytanie. Jeśli kontekst jest nieprzydatny, zwróć oryginalną odpowiedź.\nDoprecyzowana odpowiedź: "
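
# LlamaIndex callback event types hidden from the Chainlit step UI.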
CB_IGNORE = [
CBEventType.CHUNKING,
CBEventType.SYNTHESIZE,
CBEventType.EMBEDDING,
CBEventType.NODE_PARSING,
CBEventType.TREE,
CBEventType.LLM
]
tavily_client = TavilyClient(os.getenv("TAVILY_API_KEY"))
# Note: return_all_scores=True is deprecated in newer transformers releases; top_k=None is the
# replacement, but it flattens the output shape and would change the indexing in classify().
classifier = pipeline("text-classification", model=GUARD_MODEL, return_all_scores=True)
class BielikCallbackHandler(cl.LlamaIndexCallbackHandler):
    """Hides ignored LlamaIndex events and raw step inputs in the Chainlit UI."""

    def __init__(self):
        super().__init__(event_starts_to_ignore=CB_IGNORE, event_ends_to_ignore=CB_IGNORE)

    def on_event_start(self, event_type, payload=None, event_id="", parent_id="", **kwargs):
        step_id = super().on_event_start(event_type, payload, event_id, parent_id, **kwargs)
        if step_id in self.steps:
            self.steps[step_id].show_input = False  # keep raw payloads out of the step view
        return step_id
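
# Keeps the newest messages that fit in the token budget (estimated with `tokeniser`)
# and drops a trailing assistant message so the history always ends on a user turn.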
def truncate_messages(messages, max_tokens=MAX_CONTEXT_WINDOW_TOKENS):
    if not messages:
        return []
    truncated = messages.copy()
    # Drop a trailing assistant message so the model always answers a user turn.
    if truncated and truncated[-1]["role"] == "assistant":
        truncated = truncated[:-1]
    # Walk backwards from the newest message and cut everything older than the budget.
    # Truncating from the front cannot re-expose a trailing assistant message, so the
    # check above does not need repeating afterwards.
    total_tokens = 0
    for i in range(len(truncated) - 1, -1, -1):
        message_tokens = tokeniser.estimate_tokens(truncated[i]["content"])
        total_tokens += message_tokens
        if total_tokens > max_tokens:
            truncated = truncated[i + 1:]
            break
    return truncated
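
# Web search tool step: Tavily first, DDGS metasearch as fallback; returns a numbered
# Markdown list of results that gets appended to the system prompt.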
@cl.step(name="wyszukiwanie", type="tool", show_input=False)
async def search_web(query):
    try:
        # Tavily caps query length, so the query is trimmed to 400 characters.
        search_results = tavily_client.search(query=query[:400], country="poland")["results"]
    except Exception as e:
        print(f"Tavily search failed: {e}. Falling back to DDGS.")
        try:
            results = DDGS().text(query, region="pl-pl", backend="duckduckgo,brave,google,mullvad_google,mullvad_brave", max_results=5)
            search_results = [{"title": r["title"], "url": r["href"], "content": r["body"]} for r in results]
        except Exception as e:
            print(f"DDGS search failed: {e}")
            return f"Błąd wyszukiwania: {str(e)}"  # "Search error"
    formatted_text = "Wyniki wyszukiwania:\n"
    for i, result in enumerate(search_results, 1):
        formatted_text += f"{i}. [{result['title']}]({result['url']})\n {result['content']}\n"
    return formatted_text
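
# Streams the model's <think>...</think> reasoning into its own Chainlit step, then
# returns the still-open stream so the caller can render the final answer from it.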
@cl.step(name="rozumowanie", type="tool", show_input=False)
async def think(messages, llm):
    current_step = cl.context.current_step
    current_step.output = ""
    stream = await infer(messages, llm)
    think_content = ""
    async for chunk in stream:
        if chunk.delta:
            think_content += chunk.delta
            await current_step.stream_token(chunk.delta)
        # Stop at the closing tag, or at a bare paragraph break as a fallback for
        # completions that omit the tag; the rest of the stream is the final answer.
        if think_content.endswith("</think>") or think_content.endswith("\n \n"):
            break
    return stream
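
# Thin wrapper: converts plain dict messages to ChatMessage objects and starts a streaming chat.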
async def infer(messages, llm):
return await llm.astream_chat([ChatMessage(role=m["role"], content=m["content"]) for m in messages])
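
# Answers a question over user-uploaded files with a LlamaIndex RAG query engine.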
async def ask_files(message, files):
    upload_dir = os.path.dirname(files[0].path)
    cl.user_session.set("dir", upload_dir)  # remembered so end_chat() can clean it up
    documents = SimpleDirectoryReader(upload_dir, exclude_hidden=False).load_data(show_progress=True)
    index = VectorStoreIndex.from_documents(documents)
    # index.storage_context.persist()
    # The global Settings.callback_manager already routes events to Chainlit, so no
    # service_context needs to be passed here (it is not a valid value for that kwarg anyway).
    query_engine = index.as_query_engine(
        streaming=True,
        similarity_top_k=TOP_K,
        text_qa_template=PromptTemplate(TEXT_QA_SYSTEM_PROMPT),
        refine_template=PromptTemplate(REFINE_SYSTEM_PROMPT)
    )
    return await query_engine.aquery(message)
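
# Guard-model step: transliterates to ASCII and collapses punctuation between letters
# (e.g. "s.p.a.m" -> "s p a m") so trivial obfuscation does not evade classification.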
@cl.step(name="klasyfikowanie", type="tool", show_input=False)
async def classify(message):
return classifier(re.sub(r"(?<=[A-Za-z])[\.,_-](?=[A-Za-z])", " ", anyascii(message)))[0]
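
# Applies the values from the Chainlit settings sliders to the shared LLM instance.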
def update_llm_settings(llm, settings):
llm.temperature = settings["Temp"]
llm.max_tokens = settings["MaxTokens"]
llm.additional_kwargs = {
"top_p": settings["TopP"],
"frequency_penalty": settings["FreqPenalty"],
"presence_penalty": settings["PresPenalty"]
}
return llm
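
# Core chat routine: routes to file Q&A, optionally guards input/output, searches the
# web and/or reasons first, then streams the answer to the UI.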
async def run_chat(messages, files=None, search_enabled=False, think_enabled=False, guard_enabled=False):
llm = update_llm_settings(Settings.llm, cl.user_session.get("settings"))
msg = cl.Message(content="", author="Bielik")
response_content = ""
system_prompt = SYSTEM_PROMPT
curr_message = messages[-1]["content"].strip()
if not curr_message:
return
try:
if files:
stream = await ask_files(curr_message, files)
async for chunk in stream.response_gen:
response_content += chunk
await msg.stream_token(chunk)
await msg.send()
return response_content
if guard_enabled:
guard_results = await classify(curr_message)
if any(r["score"] > GUARD_THRESHOLD for r in guard_results):
                msg.content = response_content = "Wykryłem szkodliwe treści! Koniec rozmowy!"  # "Harmful content detected! Conversation over!"
await msg.send()
return response_content
if think_enabled:
system_prompt += THINK_SYSTEM_PROMPT
if search_enabled:
search_result = await search_web(curr_message)
system_prompt += SEARCH_SYSTEM_PROMPT + "\n\n" + search_result
context_messages = truncate_messages(messages)
messages = [{"role": "system", "content": system_prompt}, *context_messages]
        print(messages)  # debug: log the final message list sent to the model
stream = await think(messages, llm) if think_enabled else await infer(messages, llm)
async for chunk in stream:
if chunk.delta:
response_content += chunk.delta
await msg.stream_token(chunk.delta)
await msg.send()
        # Convert \( \) and \[ \] LaTeX delimiters to $ / $$ so the Chainlit UI renders the math.
        response_content = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', re.sub(r'\\\((.*?)\\\)', r'$\1$', response_content, flags=re.S), flags=re.S)
if guard_enabled:
guard_results = await classify(response_content)
if any(r["score"] > GUARD_THRESHOLD for r in guard_results):
                response_content = "W mojej odpowiedzi wykryłem szkodliwe treści. Gryzę się w język!"  # "I detected harmful content in my answer. Biting my tongue!"
msg.content = response_content
await msg.update()
return response_content
except Exception as e:
print(f"Response failed: {e}")
error_msg = f"B艂膮d generowania odpowiedzi: {str(e)}"
await cl.Message(content=error_msg).send()
return error_msg
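
# Session setup: settings sliders, command buttons, the NVIDIA-hosted Bielik LLM,
# the embedding model, and the Chainlit callback wiring.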
@cl.on_chat_start
async def start_chat():
settings = await cl.ChatSettings([
Slider(id="Temp", label="Temperatura", initial=0.2, min=0, max=1, step=0.1),
Slider(id="TopP", label="Top P", initial=0.7, min=0.01, max=1, step=0.01),
Slider(id="FreqPenalty", label="Frequency Penalty", initial=0, min=-2, max=2, step=0.1),
Slider(id="PresPenalty", label="Presence Penalty", initial=0, min=-2, max=2, step=0.1),
        Slider(id="MaxTokens", label="Max Tokenów", initial=4096, min=1, max=4096, step=64)
]).send()
await cl.context.emitter.set_commands(COMMANDS)
cl.user_session.set("chat_messages", [])
cl.user_session.set("settings", settings)
Settings.llm = NVIDIA(
model=MODEL,
is_chat_model=True,
        context_window=32768,  # full model window; truncate_messages() keeps history under MAX_CONTEXT_WINDOW_TOKENS
temperature=settings["Temp"],
top_p=settings["TopP"],
max_tokens=settings["MaxTokens"],
frequency_penalty=settings["FreqPenalty"],
presence_penalty=settings["PresPenalty"],
streaming=True
)
Settings.embed_model = NVIDIAEmbedding(
model=EMBEDDING_MODEL
)
Settings.callback_manager = CallbackManager([BielikCallbackHandler()])
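
# Deletes the uploaded-files directory when the session ends.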
@cl.on_chat_end
def end_chat():
    upload_dir = cl.user_session.get("dir")
    if upload_dir and os.path.exists(upload_dir):
        shutil.rmtree(upload_dir)
@cl.on_settings_update
async def setup_agent(settings):
cl.user_session.set("settings", settings)
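
# Per-message entry point: tracks history, maps the selected command to feature flags,
# and stores the assistant's reply.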
@cl.on_message
async def on_message(msg):
    chat_messages = cl.user_session.get("chat_messages", [])
    chat_messages.append({"role": "user", "content": msg.content})
    files = [el for el in msg.elements if el.path] or None
    search_enabled = msg.command in ["Wyszukaj", "W+R"]
    think_enabled = msg.command in ["Rozumuj", "W+R"]
    guard_enabled = msg.command == "Chroń"
    response = await run_chat(chat_messages, files, search_enabled, think_enabled, guard_enabled)
    if response:  # run_chat() returns None for empty input; don't store a null turn
        chat_messages.append({"role": "assistant", "content": response})
    cl.user_session.set("chat_messages", chat_messages)
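
# Run locally with Chainlit (assumes NVIDIA_API_KEY is set for the NVIDIA endpoints;
# TAVILY_API_KEY is optional, since DDGS is the search fallback):
#   chainlit run app.py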