Update app.py
Browse files
app.py
CHANGED
|
@@ -37,6 +37,7 @@ from typing import List, Annotated, Any
|
|
| 37 |
import re, operator
|
| 38 |
|
| 39 |
|
|
|
|
| 40 |
class MultiAgentState(BaseModel):
|
| 41 |
state: List[str] = []
|
| 42 |
messages: Annotated[list[AnyMessage], add_messages]
|
|
@@ -54,22 +55,38 @@ class StoryState(BaseModel):
|
|
| 54 |
stories_lst: Annotated[list, operator.add]
|
| 55 |
|
| 56 |
class DocumentRAG:
|
| 57 |
-
def __init__(self):
|
| 58 |
self.document_store = None
|
| 59 |
self.qa_chain = None
|
| 60 |
self.document_summary = ""
|
| 61 |
self.chat_history = []
|
| 62 |
self.last_processed_time = None
|
| 63 |
-
self.api_key = os.getenv("OPENAI_API_KEY")
|
| 64 |
self.init_time = datetime.now(pytz.UTC)
|
| 65 |
-
|
| 66 |
-
if not self.api_key:
|
| 67 |
-
raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
|
| 68 |
|
| 69 |
# Persistent directory for Chroma to avoid tenant-related errors
|
| 70 |
self.chroma_persist_dir = "./chroma_storage"
|
| 71 |
os.makedirs(self.chroma_persist_dir, exist_ok=True)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def process_documents(self, uploaded_files):
|
| 74 |
"""Process uploaded files by saving them temporarily and extracting content."""
|
| 75 |
if not self.api_key:
|
|
@@ -118,7 +135,7 @@ class DocumentRAG:
|
|
| 118 |
self.document_text = " ".join([doc.page_content for doc in documents]) # Store for later use
|
| 119 |
|
| 120 |
# Create embeddings and initialize retrieval chain
|
| 121 |
-
embeddings =
|
| 122 |
self.document_store = Chroma.from_documents(
|
| 123 |
documents,
|
| 124 |
embeddings,
|
|
@@ -294,29 +311,53 @@ class DocumentRAG:
|
|
| 294 |
def topic_extractor(self, state: MultiAgentState):
|
| 295 |
return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
|
| 296 |
|
| 297 |
-
def retrieve_docs(self, state: StoryState):
|
| 298 |
-
retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
|
| 299 |
-
docs = retriever.get_relevant_documents(f"information about {state.story_topic}")
|
| 300 |
-
return {"retrieved_docs": docs}
|
| 301 |
|
| 302 |
-
def
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
|
|
|
| 307 |
|
| 308 |
-
def rerank_docs(self, state: StoryState):
|
| 309 |
topic = state.story_topic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
docs = state.retrieved_docs
|
| 311 |
texts = [doc.page_content for doc in docs]
|
| 312 |
|
| 313 |
-
# Fallback: return top 5 if no reranker available
|
| 314 |
if not texts:
|
| 315 |
-
return {"reranked_docs": []}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
|
| 322 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
|
@@ -324,9 +365,9 @@ class DocumentRAG:
|
|
| 324 |
|
| 325 |
# Define the story subgraph with reranking
|
| 326 |
story_graph = StateGraph(StoryState)
|
| 327 |
-
story_graph.add_node("Retrieve", self.
|
| 328 |
-
story_graph.add_node("Rerank", self.
|
| 329 |
-
story_graph.add_node("Generate", self.
|
| 330 |
story_graph.set_entry_point("Retrieve")
|
| 331 |
story_graph.add_edge("Retrieve", "Rerank")
|
| 332 |
story_graph.add_edge("Rerank", "Generate")
|
|
@@ -365,13 +406,9 @@ class DocumentRAG:
|
|
| 365 |
return result
|
| 366 |
|
| 367 |
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
# Initialize RAG system in session state
|
| 372 |
-
if "rag_system" not in st.session_state:
|
| 373 |
-
st.session_state.rag_system = DocumentRAG()
|
| 374 |
-
|
| 375 |
|
| 376 |
|
| 377 |
|
|
|
|
| 37 |
import re, operator
|
| 38 |
|
| 39 |
|
| 40 |
+
|
| 41 |
class MultiAgentState(BaseModel):
|
| 42 |
state: List[str] = []
|
| 43 |
messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
| 55 |
stories_lst: Annotated[list, operator.add]
|
| 56 |
|
| 57 |
class DocumentRAG:
|
| 58 |
+
def __init__(self, embedding_choice="OpenAI"):
    """Create an empty RAG session and ensure the Chroma directory exists.

    Args:
        embedding_choice: Name of the embedding backend. It is stored
            unchanged on the instance as ``self.embedding_choice``.
    """
    # Nothing has been ingested or asked yet.
    self.document_store = self.qa_chain = None
    self.document_summary = ""
    self.chat_history = []
    self.last_processed_time = None

    # Settings captured once, at construction time.
    self.api_key = os.getenv("OPENAI_API_KEY")
    self.init_time = datetime.now(pytz.UTC)
    self.embedding_choice = embedding_choice

    # Persistent directory for Chroma to avoid tenant-related errors
    self.chroma_persist_dir = "./chroma_storage"
    os.makedirs(self.chroma_persist_dir, exist_ok=True)
|
| 71 |
|
| 72 |
+
|
| 73 |
+
def _get_embedding_model(self):
    """Build the embedding model selected by ``self.embedding_choice``.

    Returns:
        An ``OpenAIEmbeddings`` instance when the choice is ``"OpenAI"``.
        Any other choice yields a ``CohereEmbeddings`` instance using the
        ``embed-multilingual-light-v3.0`` model and the ``COHERE_API_KEY``
        environment variable.

    Raises:
        ValueError: If the OpenAI backend is selected but ``self.api_key``
            is empty or missing.
    """
    if self.embedding_choice == "OpenAI":
        # This check previously sat after both return statements, where it
        # could never execute; it now runs before the key is used.
        if not self.api_key:
            raise ValueError("API Key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
        return OpenAIEmbeddings(api_key=self.api_key)

    # Imported inside the branch, so it is only resolved when Cohere is chosen.
    from langchain.embeddings import CohereEmbeddings

    return CohereEmbeddings(
        model="embed-multilingual-light-v3.0",
        cohere_api_key=os.getenv("COHERE_API_KEY"),
    )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
def process_documents(self, uploaded_files):
|
| 91 |
"""Process uploaded files by saving them temporarily and extracting content."""
|
| 92 |
if not self.api_key:
|
|
|
|
| 135 |
self.document_text = " ".join([doc.page_content for doc in documents]) # Store for later use
|
| 136 |
|
| 137 |
# Create embeddings and initialize retrieval chain
|
| 138 |
+
embeddings = self._get_embedding_model()
|
| 139 |
self.document_store = Chroma.from_documents(
|
| 140 |
documents,
|
| 141 |
embeddings,
|
|
|
|
| 311 |
def topic_extractor(self, state: MultiAgentState):
|
| 312 |
return {"sub_topic_list": self.extract_subtopics(state.sub_topics)}
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
+
def retrieve_node(self, state: StoryState):
    """Fetch candidate documents for the story topic.

    The store already held in ``self.document_store`` is reused when it is
    set. Only when it is ``None`` is the persisted Chroma directory
    reopened with a freshly built embedding model.

    Args:
        state: Graph state; ``state.story_topic`` is read to build the query.

    Returns:
        A dict with ``"retrieved_docs"`` (whatever the retriever returns
        for the query) and ``"question"`` (the query string itself).
    """
    topic = state.story_topic
    query = f"information about {topic}"

    store = self.document_store
    if store is None:
        # No in-memory store on this instance: fall back to the data
        # persisted on disk.
        store = Chroma(
            persist_directory=self.chroma_persist_dir,
            embedding_function=self._get_embedding_model(),
        )

    retriever = store.as_retriever(search_kwargs={"k": 20})
    docs = retriever.get_relevant_documents(query)
    return {"retrieved_docs": docs, "question": query}
|
| 326 |
+
|
| 327 |
+
def rerank_node(self, state: StoryState):
    """Keep up to five retrieved texts, longest first.

    This is a length-based stand-in rather than a relevance-based rerank:
    page contents are ordered by descending length, and texts of equal
    length keep their original relative order.

    Args:
        state: Graph state; ``story_topic`` and ``retrieved_docs`` are read.

    Returns:
        A dict with ``"reranked_docs"`` (a list of at most five page-content
        strings, empty when nothing was retrieved) and ``"question"``
        (the rerank instruction string built from the topic).
    """
    query = f"Rerank documents based on how well they explain the topic {state.story_topic}"
    contents = [document.page_content for document in state.retrieved_docs]

    if contents:
        # Quick fallback: rank by length
        longest_first = sorted(contents, key=len, reverse=True)
        selected = longest_first[:5]
    else:
        selected = []

    return {"reranked_docs": selected, "question": query}
|
| 339 |
+
|
| 340 |
+
|
| 341 |
|
| 342 |
+
def generate_story_node(self, state: StoryState):
    """Ask the LLM for a middle-school-level story about the topic.

    Args:
        state: Graph state; ``reranked_docs`` (strings, joined with blank
            lines to form the context) and ``story_topic`` are read.

    Returns:
        A dict with ``"stories"`` set to whatever ``self.llm.invoke``
        returns for the system + human message pair.
    """
    context = "\n\n".join(state.reranked_docs)
    topic = state.story_topic

    # Persona / tone instructions for the model.
    system_message = f"""
Suppose you're a brilliant science storyteller.
You write stories that help middle schoolers understand complex science topics with fun and clarity.
Add subtle humor and make it engaging.
"""
    # Task prompt carrying the topic and the joined context.
    prompt = f"""
Use the following context to write a fun and simple story explaining **{topic}** to a middle schooler:\n
Context:\n{context}\n\n
Story:
"""

    conversation = [SystemMessage(system_message), HumanMessage(prompt)]
    reply = self.llm.invoke(conversation)
    return {"stories": reply}
|
| 359 |
+
|
| 360 |
+
|
| 361 |
|
| 362 |
|
| 363 |
def run_multiagent_storygraph(self, topic: str, context: str):
|
|
|
|
| 365 |
|
| 366 |
# Define the story subgraph with reranking
|
| 367 |
story_graph = StateGraph(StoryState)
|
| 368 |
+
story_graph.add_node("Retrieve", self.retrieve_node)
|
| 369 |
+
story_graph.add_node("Rerank", self.rerank_node)
|
| 370 |
+
story_graph.add_node("Generate", self.generate_story_node)
|
| 371 |
story_graph.set_entry_point("Retrieve")
|
| 372 |
story_graph.add_edge("Retrieve", "Rerank")
|
| 373 |
story_graph.add_edge("Rerank", "Generate")
|
|
|
|
| 406 |
return result
|
| 407 |
|
| 408 |
|
| 409 |
+
# (Re)build the RAG system when none exists yet in the session or when the
# selected embedding backend differs from the one it was built with.
needs_new_rag_system = (
    "rag_system" not in st.session_state
    or st.session_state.embedding_model != embedding_choice
)
if needs_new_rag_system:
    st.session_state.embedding_model = embedding_choice
    st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
|
| 414 |
|