import os
import getpass  # used only by the commented-out interactive key prompt below
import json
import re

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_cohere import CohereEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings  # used only by the commented-out fallback below
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing_extensions import List, TypedDict

import nltk

# Point NLTK at the bundled nltk_data directory and disable runtime downloads,
# so the app never tries to fetch corpora over the network at startup.
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))


def no_download(*args, **kwargs):
    return None


nltk.download = no_download

# Interactive fallback for supplying the Groq key (disabled; the .env file is used instead):
'''
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
'''

load_dotenv()

print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))

llm = init_chat_model(
    "moonshotai/kimi-k2-instruct-0905",
    model_provider="groq",
    api_key=os.getenv("GROQ_API_KEY"),
)

# Alternative embeddings backend (disabled): Hugging Face Inference API with MiniLM.
'''
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv("HUGGING_FACE_API_KEY"),
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
'''

embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)

# The index lives in process memory and is rebuilt from comb.json on every start.
vector_store = InMemoryVectorStore(embedding=embeddings)

json_path = os.path.join(os.path.dirname(__file__), "comb.json")
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
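# Illustration (not executed): split_text breaks a long string into pieces of at
# most 500 characters, with up to 100 characters repeated between neighbours so
# text that straddles a chunk boundary survives intact in at least one chunk.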


def _collect_texts(node):
    """Recursively collect text strings from a JSON structure.

    Supports: string, list, dict with common keys like 'text', 'content', 'body',
    or a list of documents. Falls back to joining stringifiable values.
    """
    texts = []
    if isinstance(node, str):
        texts.append(node)
    elif isinstance(node, list):
        for item in node:
            texts.extend(_collect_texts(item))
    elif isinstance(node, dict):
        if "text" in node and isinstance(node["text"], str):
            texts.append(node["text"])
        elif "content" in node and isinstance(node["content"], str):
            texts.append(node["content"])
        elif "body" in node and isinstance(node["body"], str):
            texts.append(node["body"])
        elif "documents" in node and isinstance(node["documents"], list):
            texts.extend(_collect_texts(node["documents"]))
        else:
            # No recognised text key: join the scalar values, or stringify the whole dict.
            joined = " ".join(str(v) for v in node.values() if isinstance(v, (str, int, float)))
            if joined:
                texts.append(joined)
            else:
                texts.append(str(node))
    return texts


raw_texts = _collect_texts(data)

# Split every collected text into overlapping chunks.
all_splits = []
for t in raw_texts:
    if not t:
        continue
    all_splits.extend(text_splitter.split_text(t))

# Embed and index the chunks.
docs = [Document(page_content=text) for text in all_splits]
_ = vector_store.add_documents(documents=docs)
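
# Quick sanity check (illustrative only; "refund policy" is a made-up query):
#   vector_store.similarity_search("refund policy", k=4)
# returns the 4 closest chunks as Documents (k=4 is the LangChain default).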

prompt = hub.pull("rlm/rag-prompt")
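# "rlm/rag-prompt" is a community RAG prompt on LangChain Hub; it expects
# "question" and "context" inputs, which generate() supplies below.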


class State(TypedDict):
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    # Fetch the indexed chunks most similar to the question.
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    # Stuff the retrieved chunks into the RAG prompt and ask the LLM.
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}


# Wire retrieve -> generate into a two-step LangGraph pipeline.
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

'''
response = graph.invoke({"question": "Who should I contact for help?"})
print(response["answer"])
'''

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Note: browsers restrict credentialed requests with a wildcard origin;
    # tighten `origins` to the real frontend URL(s) in production.
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)


@app.get("/ping")
async def ping():
    return "Pong!"


class Query(BaseModel):
    question: str


@app.post("/chat")
async def chat(request: Query):
    # graph.invoke is synchronous, so run it in a threadpool to keep the event loop free.
    result = await run_in_threadpool(lambda: graph.invoke({"question": request.question}))
    answer = str(result.get("answer", ""))
    # Strip <think>...</think> blocks that some reasoning models emit.
    answer = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL)
    return {"response": answer}


if __name__ == "__main__":
    import uvicorn

    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")