import gradio as gr
import os

from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
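
# Free-tier model IDs served through OpenRouter. QWEN_MODEL is the one the app
# actually calls; the other two are kept as drop-in alternatives.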
QWEN_MODEL = "qwen/qwen3-235b-a22b:free"
DEEPSEEK_MODEL = "deepseek/deepseek-chat-v3.1:free"
GPT_OSS_MODEL = "openai/gpt-oss-20b:free"

api_key = os.getenv("OPENROUTER_API_KEY")
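
# OpenRouter exposes an OpenAI-compatible API, so the stock OpenAI SDK client
# works against its base_url. Failing fast here when the key is missing avoids
# a less obvious error on the first API call.
if not api_key:
    raise RuntimeError("OPENROUTER_API_KEY is not set")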

client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
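
# Load the retrieval stack once at startup: the vector collection, the query
# encoder, and a cross-encoder used for second-stage reranking.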
collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder("BAAI/bge-reranker-large")
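

# Ask the LLM to rewrite the user question into a retrieval-friendly query.
# The Chinese prompt says, roughly: "You are a BangDream knowledge-retrieval
# assistant. Expand or rewrite the user's question into a query suited to
# semantic search over the knowledge base, covering all likely phrasings and
# synonymous keywords."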
def reformulate_query(user_question, model_name=QWEN_MODEL):
    prompt = f"""你是一个BangDream知识检索助手。请把用户的问题扩写或转写为适合知识库语义检索的检索语句,涵盖所有可能的提问方式或同义关键词。
用户问题:{user_question}
"""
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        max_tokens=4096,
    )
    return resp.choices[0].message.content.strip()
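

# Build the final RAG prompt: system message, retrieved snippets, then the
# question. The Chinese rules tell the model to consult every snippet, answer
# concisely and accurately, list multiple answers where relevant, say which
# snippets an answer came from, and admit uncertainty instead of fabricating.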
def build_rag_prompt(query, context, system_message):
    prompt = f"""{system_message}
你将获得多个独立的资料片段,请充分查阅每一条资料。
已知资料如下:
{context}

用户提问:{query}

规则:
1. 请参考所有已知资料,并结合资料内容,简明、准确地回答问题。
2. 如果有多个相关答案或不同观点,可以考虑是否全部分点列出。
3. 如果只能在部分资料里找到答案,也请说明是参考哪些资料内容。
4. 如果不能确定答案,请如实说明理由,不要凭空编造。
"""
    return prompt
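

# End-to-end RAG pipeline for one chat turn:
#   1. reformulate the question for retrieval,
#   2. retrieve + rerank with both the raw and reformulated queries,
#   3. dedup, expand hits with neighboring chunks, and build the prompt,
#   4. stream the answer back to the UI.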
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    message: the current user input
    history: [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
    system_message: custom system prompt
    """
    # history is accepted for ChatInterface compatibility but is not replayed
    # to the model; each turn is answered from the retrieved context alone.
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()

    print("Reformulating...")
    reformulated_query_text = reformulate_query(message)
    print(f"[DEBUG] reformulated query: {reformulated_query_text}")
    print("Thinking...\n...")

    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=20)
    reranked = query_rerank(reranker, message, results, top_n=10)

    reformulated_query_vec = encode_query(encoder, reformulated_query_text)
    reformulated_results = retrieve_docs(collection, reformulated_query_vec, top_k=20)
    reformulated_reranked = query_rerank(reranker, reformulated_query_text, reformulated_results, top_n=10)

    # Merge both passes, keep at most one hit per chapter/event, and expand
    # the top hits with neighboring chunks for fuller context.
    total_reranked = reranked + reformulated_reranked
    deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)
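
    # expand_with_neighbors is assumed to return (text, score, meta) tuples,
    # matching the debug loop below; only the text goes into the prompt.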
    context = "\n\n".join(item[0] for item in expanded_results if item)

    rag_prompt = build_rag_prompt(message, context, system_msg)

    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    # Stream the completion, yielding the accumulated answer so the Gradio
    # chat window updates as tokens arrive.
    response = ""
    stream = client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response
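
    # Server-side debug log: the final answer plus the evidence that backed it.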
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)


chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答\n\nGitHub项目地址: [GitHub repo](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)",
    title="Dr-Bang RAG QA Chatbot",
)
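
# Wrap the chat UI in Blocks so a sidebar with the project link and a Hugging
# Face login button can sit beside it.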
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
            "[GitHub project](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)\n\n"
        )
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()