import os
import getpass
from groq import Groq
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import json
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain import hub
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing_extensions import List, TypedDict
from langchain_cohere import CohereEmbeddings
import re
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
import nltk

# Force nltk to use the bundled data directory shipped with the app
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))

# Disable downloading at runtime (the Hugging Face filesystem is read-only)
def no_download(*args, **kwargs):
    return None

nltk.download = no_download
'''
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
'''
load_dotenv()
# avoid printing secret values
print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))
llm = init_chat_model("moonshotai/kimi-k2-instruct-0905", model_provider="groq", api_key=os.getenv("GROQ_API_KEY"))
'''
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv('HUGGING_FACE_API_KEY'),
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
'''
embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)
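# Optional sanity check (commented out) -- a minimal sketch assuming the Cohere key is
# set: embed_query is the standard LangChain Embeddings call, and the sample string is
# just a placeholder.
'''
vec = embeddings.embed_query("hello world")
print("embedding dimensions:", len(vec))
'''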
vector_store = InMemoryVectorStore(embedding=embeddings)
# Load and parse comb.json instead of using the markdown loader
json_path = os.path.join(os.path.dirname(__file__), "comb.json")
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
def _collect_texts(node):
    """Recursively collect text strings from a JSON structure.

    Supports: string, list, dict with common keys like 'text', 'content', 'body',
    or a list of documents. Falls back to joining stringifiable values.
    """
    texts = []
    if isinstance(node, str):
        texts.append(node)
    elif isinstance(node, list):
        for item in node:
            texts.extend(_collect_texts(item))
    elif isinstance(node, dict):
        # common text keys
        if "text" in node and isinstance(node["text"], str):
            texts.append(node["text"])
        elif "content" in node and isinstance(node["content"], str):
            texts.append(node["content"])
        elif "body" in node and isinstance(node["body"], str):
            texts.append(node["body"])
        elif "documents" in node and isinstance(node["documents"], list):
            texts.extend(_collect_texts(node["documents"]))
        else:
            # fallback: stringify simple values
            joined = " ".join(str(v) for v in node.values() if isinstance(v, (str, int, float)))
            if joined:
                texts.append(joined)
            else:
                texts.append(str(node))
    return texts
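# Illustrative only (comb.json's real schema isn't shown here, so this record is a
# made-up example): given
#     {"title": "FAQ", "documents": [{"text": "Email support for access issues."}, "Hours: 9-5"]}
# _collect_texts walks the "documents" list recursively and returns
#     ["Email support for access issues.", "Hours: 9-5"]
# while simple keys like "title" on that node are ignored, because the "documents"
# branch matches first.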
raw_texts = _collect_texts(data)
# Split each raw text into chunks
all_splits = []
for t in raw_texts:
    if not t:
        continue
    splits = text_splitter.split_text(t)
    all_splits.extend(splits)
# Build Documents from split strings
docs = [Document(page_content=text) for text in all_splits]
_ = vector_store.add_documents(documents=docs)
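# Optional sanity check (commented out): similarity_search is the same call the
# retrieve step uses below; the question text is only a placeholder.
'''
hits = vector_store.similarity_search("Who should I contact for help?")
print(len(hits), hits[0].page_content[:80] if hits else "no matches")
'''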
prompt = hub.pull("rlm/rag-prompt")
# State shared between the retrieve and generate steps of the graph
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str
def retrieve(state: State):
    # Pull the chunks most similar to the question from the in-memory vector store
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}

def generate(state: State):
    # Stuff the retrieved chunks into the RAG prompt and ask the LLM for an answer
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
'''
response = graph.invoke({"question": "Who should I contact for help?"})
print(response["answer"])
'''
app = FastAPI()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)
@app.get("/ping")
async def ping():
    return "Pong!"
class Query(BaseModel):
    question: str
@app.post("/chat")
async def chat(request: Query):
    # Run the blocking graph.invoke in a worker thread so the event loop stays free
    result = await run_in_threadpool(lambda: graph.invoke({"question": request.question}))
    answer = str(result.get("answer", ""))
    # Strip any <think>...</think> reasoning blocks the model may emit
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    return {"response": answer}
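# Example request against the running server (host/port match the uvicorn call below;
# the question text is just a placeholder):
#   curl -X POST http://127.0.0.1:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who should I contact for help?"}'
# Expected response shape: {"response": "<answer text>"}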
if __name__ == "__main__":
    import uvicorn

    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")