Spaces:
Sleeping
Sleeping
Update scripts/router_chain.py
Browse files
- scripts/router_chain.py +39 -10
scripts/router_chain.py
CHANGED
|
@@ -44,23 +44,52 @@ User request: {input}
|
|
| 44 |
# chain = prompt | llm | StrOutputParser()
|
| 45 |
# return {"result": chain.invoke({"input": input_dict["input"]})}
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
elif category == "summarize":
|
| 48 |
-
# 1. Use RAG to retrieve relevant docs
|
| 49 |
rag_result = general_qa({"query": input_dict["input"]})
|
| 50 |
|
| 51 |
-
# 2. Extract docs and prepare text
|
| 52 |
-
source_docs = rag_result.get("source_documents", [])
|
| 53 |
-
combined_text = "\n\n".join([doc.page_content for doc in source_docs])
|
| 54 |
|
| 55 |
-
# 3. Run the summarizer chain on the retrieved text
|
|
|
|
| 56 |
from scripts.summarizer import get_summarizer
|
|
|
|
| 57 |
summarizer_chain = get_summarizer()
|
| 58 |
-
summary = summarizer_chain.run(combined_text)
|
| 59 |
|
| 60 |
-
#
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
return {"result": summary}
|
| 66 |
|
|
|
|
| 44 |
# chain = prompt | llm | StrOutputParser()
|
| 45 |
# return {"result": chain.invoke({"input": input_dict["input"]})}
|
| 46 |
|
| 47 |
+
#elif category == "summarize":
|
| 48 |
+
# # 1. Use RAG to retrieve relevant docs
|
| 49 |
+
# rag_result = general_qa({"query": input_dict["input"]})
|
| 50 |
+
|
| 51 |
+
# # 2. Extract docs and prepare text
|
| 52 |
+
# source_docs = rag_result.get("source_documents", [])
|
| 53 |
+
# combined_text = "\n\n".join([doc.page_content for doc in source_docs])
|
| 54 |
+
|
| 55 |
+
# # 3. Run the summarizer chain on the retrieved text
|
| 56 |
+
# from scripts.summarizer import get_summarizer
|
| 57 |
+
# summarizer_chain = get_summarizer()
|
| 58 |
+
# summary = summarizer_chain.run(combined_text)
|
| 59 |
+
|
| 60 |
+
# # 4. Add sources if any
|
| 61 |
+
# sources = list({str(doc.metadata.get("source", "unknown")) for doc in source_docs})
|
| 62 |
+
# if sources:
|
| 63 |
+
# summary += f"\n\n📚 Sources: {', '.join(sources)}"
|
| 64 |
+
|
| 65 |
+
# return {"result": summary}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
elif category == "summarize":
|
| 69 |
+
# 1) Retrieve relevant documents via your existing RAG chain
|
| 70 |
rag_result = general_qa({"query": input_dict["input"]})
|
| 71 |
|
| 72 |
+
# 2) Get the retrieved docs (already LangChain Document objects)
|
| 73 |
+
source_docs = rag_result.get("source_documents", []) or []
|
|
|
|
| 74 |
|
| 75 |
+
# 3) Build the summarizer and prepare the docs list
|
| 76 |
+
from langchain.docstore.document import Document
|
| 77 |
from scripts.summarizer import get_summarizer
|
| 78 |
+
|
| 79 |
summarizer_chain = get_summarizer()
|
|
|
|
| 80 |
|
| 81 |
+
# If retrieval returned nothing, fall back to summarizing the user’s text
|
| 82 |
+
docs = source_docs if source_docs else [Document(page_content=input_dict["input"])]
|
| 83 |
+
|
| 84 |
+
# 4) Summarize — load_summarize_chain returns {"output_text": "..."}
|
| 85 |
+
out = summarizer_chain.invoke(docs)
|
| 86 |
+
summary = out["output_text"] if isinstance(out, dict) and "output_text" in out else str(out)
|
| 87 |
+
|
| 88 |
+
# 5) Append sources (only if we actually had retrieved docs)
|
| 89 |
+
if source_docs:
|
| 90 |
+
sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
|
| 91 |
+
if sources:
|
| 92 |
+
summary += f"\n\n📚 Sources: {', '.join(sources)}"
|
| 93 |
|
| 94 |
return {"result": summary}
|
| 95 |
|