import gradio as gr
import os

from openai import OpenAI
from sentence_transformers import CrossEncoder

from retriever import (
    load_collection,
    load_encoder,
    encode_query,
    retrieve_docs,
    query_rerank,
    expand_with_neighbors,
    dedup_by_chapter_event,
)

# Free-tier OpenRouter model identifiers.
QWEN_MODEL = "qwen/qwen3-235b-a22b:free"
DEEPSEEK_MODEL = "deepseek/deepseek-chat-v3.1:free"
GPT_OSS_MODEL = "openai/gpt-oss-20b:free"

api_key = os.getenv("OPENROUTER_API_KEY")
# deepseek_key = os.getenv("DEEPSEEK_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# Alternative: direct OpenAI access.
# api_key = os.getenv("OPENAI_API_KEY")
# client = OpenAI(api_key=api_key)
# GPT_MODEL = "gpt-4o"

# Retrieval components are loaded once at module import; reused across requests.
collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder("BAAI/bge-reranker-large")


def reformulate_query(user_question, model_name=QWEN_MODEL):
    """Rewrite *user_question* into a retrieval-friendly search query.

    Asks the LLM to expand the question with synonyms and alternative
    phrasings so that semantic search over the knowledge base recalls
    more relevant documents.

    Args:
        user_question: Raw question text from the user.
        model_name: OpenRouter model used for the rewrite.

    Returns:
        The reformulated query string (stripped of surrounding whitespace).
    """
    prompt = f"""你是一个BangDream知识检索助手。请把用户的问题扩写或转写为适合知识库语义检索的检索语句,涵盖所有可能的提问方式或同义关键词。
用户问题:{user_question}
"""
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,  # low temperature: we want a faithful rewrite, not creativity
        max_tokens=4096,
    )
    return resp.choices[0].message.content.strip()


def build_rag_prompt(query, context, system_message):
    """Assemble the final RAG prompt sent to the answering model.

    Args:
        query: The user's original question.
        context: Retrieved document snippets — either an already-joined
            string or a list/tuple of snippet strings.
        system_message: System prompt to embed at the top of the prompt.

    Returns:
        The fully formatted prompt string.
    """
    # FIX: callers pass a list of snippets; interpolating the list directly
    # would embed its Python repr (brackets, quotes, escapes) in the prompt.
    # Join entries into plain readable text instead.
    if isinstance(context, (list, tuple)):
        context = "\n\n".join(str(snippet) for snippet in context)
    prompt = f"""{system_message}
你将获得多个独立的资料片段,请充分查阅每一条资料.
已知资料如下:
{context}
用户提问:{query}
规则:
1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。
2. 如果有多个相关答案或不同观点,可以考虑是否全部分点列出
3. 如果只能在部分资料里找到答案,也请说明是参考哪些资料内容
4. 如果不能确定答案,请如实说明理由,不要凭空编造。
"""
    return prompt


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio chat handler: retrieve, rerank, and stream an answer.

    Pipeline:
      1. Reformulate the question for better semantic recall.
      2. Retrieve + cross-encoder rerank for both the original and the
         reformulated query.
      3. Deduplicate by chapter/event, expand top hits with neighboring
         chunks, and build the RAG prompt.
      4. Stream the LLM answer back to the UI.

    Args:
        message: Current user input.
        history: Prior turns as [{"role": ..., "content": ...}, ...]
            (supplied by Gradio; not currently folded into the prompt).
        system_message: Custom system prompt from the UI (falls back to a
            default when empty).
        max_tokens / temperature / top_p: Generation parameters.

    Yields:
        The partial answer text, growing as stream chunks arrive.
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()

    # Step 1: query reformulation for broader recall.
    print("Reformulating...")
    reformulated_query_text = reformulate_query(message)
    print(f"[DEBUG] reformulated query: {reformulated_query_text}")
    print("Thinking...\n...")

    # Step 2a: retrieve + rerank with the original query.
    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=20)
    reranked = query_rerank(reranker, message, results, top_n=10)

    # Step 2b: retrieve + rerank with the reformulated query.
    reformulated_query_vec = encode_query(encoder, reformulated_query_text)
    reformulated_results = retrieve_docs(collection, reformulated_query_vec, top_k=20)
    reformulated_reranked = query_rerank(
        reranker, reformulated_query_text, reformulated_results, top_n=10
    )

    # Step 3: merge both result lists, dedup, and expand top hits with
    # neighboring chunks for more coherent context.
    total_reranked = reranked + reformulated_reranked
    deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    # Each result appears to be a (text, score, meta) tuple — take the text.
    # FIX: replaces `context.append(text[0]) if text else ""`, a conditional
    # expression misused as a statement with a dead "" arm.
    context = [item[0] for item in expanded_results if item]

    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    # Step 4: stream the answer, yielding the accumulated text each chunk.
    response = ""
    stream = client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    # Server-side debug trace of the final answer and retrieved documents.
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    # FIX: loop variable renamed from `context`, which shadowed the
    # retrieved-context list built above.
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)


# ========== Gradio ChatInterface with extra sidebar inputs ==========
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            value="你是BangDream知识问答助手, 只能基于提供资料内容作答。",
            label="System message",
        ),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # examples=[
    #     ["在水族馆里爱音和灯发生了什么?"],
    #     ["RAS的目标是什么?"],
    # ],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答\n\nGitHub项目地址: [GitHub repo](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)",
    title="Dr-Bang RAG QA Chatbot",
)

with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
            "[GitHub project](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)\n\n"
        )
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()