import json
import os
import re

import nltk
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_cohere import CohereEmbeddings
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing_extensions import List, TypedDict
# Force nltk to use the bundled data directory instead of the default search paths.
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))

# Disable nltk downloads at runtime (the Hugging Face Spaces filesystem is read-only).
def no_download(*args, **kwargs):
    return None

nltk.download = no_download
# Interactive fallback for local development (disabled; keys come from .env instead):
# if not os.environ.get("GROQ_API_KEY"):
#     import getpass
#     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
load_dotenv()
# avoid printing secret values
print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))
llm = init_chat_model("moonshotai/kimi-k2-instruct-0905", model_provider="groq", api_key=os.getenv("GROQ_API_KEY"))
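# Note: init_chat_model dispatches on model_provider, so swapping in a different
# hosted model only requires changing the model string and provider name
# (assuming the matching integration package is installed).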
# Alternative: embeddings via the Hugging Face Inference API (disabled):
# from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
# embeddings = HuggingFaceInferenceAPIEmbeddings(
#     api_key=os.getenv("HUGGING_FACE_API_KEY"),
#     model_name="sentence-transformers/all-MiniLM-L6-v2",
# )
embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)
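# Note: embed-english-v3.0 returns 1024-dimensional vectors (per Cohere's docs);
# InMemoryVectorStore infers the dimension from the embeddings it receives.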
vector_store = InMemoryVectorStore(embedding=embeddings)

# Load and parse comb.json instead of using the markdown loader.
json_path = os.path.join(os.path.dirname(__file__), "comb.json")
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
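# With chunk_size=500 and chunk_overlap=100, consecutive chunks share up to
# 100 characters, so sentences that straddle a chunk boundary stay retrievable.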
def _collect_texts(node):
    """Recursively collect text strings from a JSON structure.

    Supports: strings, lists, and dicts with common keys like 'text',
    'content', 'body', or a 'documents' list. Falls back to joining
    stringifiable values.
    """
    texts = []
    if isinstance(node, str):
        texts.append(node)
    elif isinstance(node, list):
        for item in node:
            texts.extend(_collect_texts(item))
    elif isinstance(node, dict):
        # Prefer common text-bearing keys when present.
        if "text" in node and isinstance(node["text"], str):
            texts.append(node["text"])
        elif "content" in node and isinstance(node["content"], str):
            texts.append(node["content"])
        elif "body" in node and isinstance(node["body"], str):
            texts.append(node["body"])
        elif "documents" in node and isinstance(node["documents"], list):
            texts.extend(_collect_texts(node["documents"]))
        else:
            # Fallback: stringify simple scalar values.
            joined = " ".join(str(v) for v in node.values() if isinstance(v, (str, int, float)))
            if joined:
                texts.append(joined)
            else:
                texts.append(str(node))
    return texts
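# Example (hypothetical input):
#   _collect_texts({"documents": [{"text": "Contact support."}, "See the FAQ."]})
#   -> ["Contact support.", "See the FAQ."]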
raw_texts = _collect_texts(data)

# Split each raw text into overlapping chunks.
all_splits = []
for t in raw_texts:
    if not t:
        continue
    all_splits.extend(text_splitter.split_text(t))

# Build Documents from the split strings and index them in the vector store.
docs = [Document(page_content=text) for text in all_splits]
_ = vector_store.add_documents(documents=docs)
prompt = hub.pull("rlm/rag-prompt")
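# "rlm/rag-prompt" is a public LangChain Hub prompt that expects two input
# variables: "question" and "context".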
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
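# add_sequence wires retrieve -> generate in order, and the START edge makes
# retrieve the entry point, so a single invoke runs both steps.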
# Quick local smoke test (disabled):
# response = graph.invoke({"question": "Who should I contact for help?"})
# print(response["answer"])
app = FastAPI()

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)
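# Note: per the CORS spec, browsers reject credentialed responses when the
# allowed origin is the wildcard "*", so allow_credentials=True is effectively
# inert here; list explicit origins if cookies or auth headers are needed.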
@app.get("/ping")
async def ping():
    return "Pong!"


class Query(BaseModel):
    question: str
@app.post("/chat")
async def chat(request: Query):
    # Run the blocking graph.invoke in a worker thread so it does not block the event loop.
    result = await run_in_threadpool(lambda: graph.invoke({"question": request.question}))
    answer = str(result.get("answer", ""))
    # Strip any <think>...</think> reasoning traces the model may emit.
    answer = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL)
    return {"response": answer}
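# Example request (assuming the server is running locally):
#   curl -X POST http://127.0.0.1:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who should I contact for help?"}'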
if __name__ == "__main__":
    import uvicorn

    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")