Dr-Bang / app.py
Cudd1es's picture
switched to QWEN
7405997 verified
import gradio as gr
import os
from openai import OpenAI
from retriever import (
load_collection, load_encoder, encode_query, retrieve_docs,
query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
# OpenRouter model identifiers. QWEN_MODEL is the default for reformulate_query
# and the model respond() streams its answer from; the other two are not
# referenced in the visible code.
QWEN_MODEL="qwen/qwen3-235b-a22b:free"
DEEPSEEK_MODEL="deepseek/deepseek-chat-v3.1:free"
GPT_OSS_MODEL="openai/gpt-oss-20b:free"
# The API key comes from the environment; os.getenv returns None when it is unset.
api_key = os.getenv("OPENROUTER_API_KEY")
#deepseek_key = os.getenv("DEEPSEEK_API_KEY")
# OpenAI SDK client pointed at the OpenRouter base URL.
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
# Alternative setup that talks to the OpenAI API directly (disabled).
#api_key = os.getenv("OPENAI_API_KEY")
#client = OpenAI(api_key=api_key)
#GPT_MODEL="gpt-4o"
# Retrieval resources are created once at import time and reused by respond().
collection = load_collection()
encoder = load_encoder()
# Cross-encoder handed to query_rerank() to rescore retrieved candidates.
reranker = CrossEncoder("BAAI/bge-reranker-large")
def reformulate_query(user_question, model_name=QWEN_MODEL):
    """Ask the chat model to rewrite a question into a retrieval-friendly query.

    Args:
        user_question: The raw question typed by the user.
        model_name: Model identifier sent to the chat-completions endpoint.

    Returns:
        The model's rewritten query with surrounding whitespace stripped.
        Falls back to ``user_question`` when the response carries no choices
        or its content is None/blank, so callers always receive a usable
        query string instead of an exception or an empty string.
    """
    prompt = f"""你是一个BangDream知识检索助手。请把用户的问题扩写或转写为适合知识库语义检索的检索语句,涵盖所有可能的提问方式或同义关键词。
用户问题:{user_question}
"""
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        max_tokens=4096,
    )
    # The SDK types `message.content` as optional and `choices` may be empty,
    # so guard both before calling `.strip()`.
    choices = getattr(resp, "choices", None) or []
    content = choices[0].message.content if choices else None
    rewritten = (content or "").strip()
    return rewritten or user_question
def build_rag_prompt(query, context, system_message):
    """Assemble the final RAG prompt sent to the answering model.

    Args:
        query: The user's question, inserted verbatim.
        context: Retrieved material. A list or tuple is rendered as numbered
            snippets separated by blank lines (falsy items are skipped); any
            other value is interpolated as-is, exactly as before.
        system_message: Text placed on the first line of the prompt.

    Returns:
        The prompt string: system message, the retrieved material, the user
        question, and the answering rules.
    """
    if isinstance(context, (list, tuple)):
        # Interpolating a list directly would embed its Python repr
        # (brackets, quotes, escaped "\n") in the prompt. Render each
        # snippet as its own labelled section instead so the model sees
        # clean, separately citable passages.
        snippets = [str(item) for item in context if item]
        context_text = "\n\n".join(
            f"[资料{number}]\n{snippet}"
            for number, snippet in enumerate(snippets, 1)
        )
    else:
        context_text = context
    prompt = f"""{system_message}
你将获得多个独立的资料片段,请充分查阅每一条资料.
已知资料如下:
{context_text}
用户提问:{query}
规则:
1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。
2. 如果有多个相关答案或不同观点,可以考虑是否全部分点列出
3. 如果只能在部分资料里找到答案,也请说明是参考哪些资料内容
4. 如果不能确定答案,请如实说明理由,不要凭空编造。
"""
    return prompt
def _retrieve_and_rerank(query_text, top_k=20, top_n=10):
    """Embed *query_text*, fetch ``top_k`` candidates, and rerank them down to ``top_n``."""
    query_vec = encode_query(encoder, query_text)
    candidates = retrieve_docs(collection, query_vec, top_k=top_k)
    return query_rerank(reranker, query_text, candidates, top_n=top_n)


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Answer one chat message with retrieval-augmented generation.

    The question is searched twice (as typed and as rewritten by
    ``reformulate_query``); both reranked result lists are concatenated,
    deduplicated, truncated to five entries and expanded with neighbours
    before being placed into the prompt.

    Args:
        message: The current user input.
        history: Prior turns as ``[{"role": ..., "content": ...}, ...]``.
            Accepted for the chat-interface callback signature; it is not
            used when building the prompt.
        system_message: Custom system prompt; a built-in default is used when
            it is empty.
        max_tokens: Forwarded to the chat-completions call.
        temperature: Forwarded to the chat-completions call.
        top_p: Forwarded to the chat-completions call.

    Yields:
        The accumulated answer text after each streamed content delta.
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()

    # Rewrite the question so the second retrieval pass can match on
    # different wording than the user's own.
    print("Reformulating...")
    reformulated_query_text = reformulate_query(message)
    print(f"[DEBUG] reformulated query: {reformulated_query_text}")
    print("Thinking...\n...")

    # Retrieve for the original question first, then for the rewritten one.
    reranked = _retrieve_and_rerank(message)
    reformulated_reranked = _retrieve_and_rerank(reformulated_query_text)

    total_reranked = reranked + reformulated_reranked
    deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    # Drop falsy entries once so the prompt and the debug dump below work
    # from the same list; each kept entry is unpacked as (text, score, meta).
    documents = [doc for doc in expanded_results if doc]
    context = [doc[0] for doc in documents]

    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    response = ""
    stream = client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        # Some streamed chunks carry no choices; indexing them would raise
        # IndexError and abort the answer mid-stream.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    # Console dump of the final answer and the documents that backed it.
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(documents, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)
# ========== Gradio ChatInterface with extra sidebar inputs ==========
# respond() is the chat callback. The additional inputs are listed in the same
# order as respond()'s parameters after `history`:
# system_message, max_tokens, temperature, top_p.
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        # System prompt; respond() falls back to its own default when this is empty.
        gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
        # Passed to respond() as max_tokens.
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        # Passed to respond() as temperature.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        # Passed to respond() as top_p.
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Example questions (currently disabled):
    #examples=[
    # ["在水族馆里爱音和灯发生了什么?"],
    # ["RAS的目标是什么?"],
    #],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答\n\nGitHub项目地址: [GitHub repo](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)",
    title="Dr-Bang RAG QA Chatbot"
)
# Page layout: a sidebar with the project link and a login button, followed by
# the chat interface defined above.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            value="## Dr-Bang QA\n\n[GitHub project](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)\n\n"
        )
        gr.LoginButton()
    # Place the pre-built ChatInterface into this Blocks page.
    chatbot.render()

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()