import getpass
import os
import re

from groq import Groq
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain import hub
from langgraph.graph import START, StateGraph
from pydantic.main import BaseModel
from typing_extensions import List, TypedDict
from langchain_cohere import CohereEmbeddings
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
import nltk

# Force nltk to use the data directory bundled next to this file.
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))


# Disable downloading at runtime (since Hugging Face is read-only).
def no_download(*args, **kwargs):
    """Stand-in for ``nltk.download``: accept any arguments and do nothing."""
    return None


nltk.download = no_download

# Disabled interactive fallback, kept for reference:
# if not os.environ.get("GROQ_API_KEY"):
#     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

load_dotenv()

# Report only whether each key is present; never print the secret values.
print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))

llm = init_chat_model(
    "moonshotai/kimi-k2-instruct-0905",
    model_provider="groq",
    api_key=os.getenv("GROQ_API_KEY"),
)

# Disabled Hugging Face embedding backend, kept for reference:
# embeddings = HuggingFaceInferenceAPIEmbeddings(
#     api_key=os.getenv('HUGGING_FACE_API_KEY'),
#     model_name="sentence-transformers/all-MiniLM-L6-v2"
# )

# The Cohere key is read from COHERE_API_KEY, falling back to COHERE.
embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)

vector_store = InMemoryVectorStore(embedding=embeddings)

md_loader = UnstructuredMarkdownLoader('comb.md')
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

# Earlier variants split raw strings instead of loaded documents:
# all_splits = text_splitter.split_text(data_1 + "\n\n" + data_2 + "\n\n" + data_3 + "\n\n" + data_4)
# all_splits = text_splitter.split_text(comb)
all_splits = text_splitter.split_documents(md_loader.load())

# docs = [Document(page_content=text) for text in all_splits]
# Rebuild each split as a fresh Document carrying the same content and metadata.
docs = [
    Document(page_content=split.page_content, metadata=split.metadata)
    for split in all_splits
]
_ = vector_store.add_documents(documents=docs)

prompt = hub.pull("rlm/rag-prompt")


class State(TypedDict):
    """State dictionary passed between the graph's nodes."""

    # The user's question, supplied when the graph is invoked.
    question: str
    # Documents written by ``retrieve`` and read by ``generate``.
    context: List[Document]
    # The model's reply text, written by ``generate``.
    answer: str


def retrieve(state: State):
    """Search the vector store for the question and return the hits as ``context``."""
    matches = vector_store.similarity_search(state["question"])
    return {"context": matches}


def generate(state: State):
    """Fill the prompt with the question and retrieved text, then ask the LLM.

    Returns a state update whose ``answer`` is the content of the LLM response.
    """
    # Concatenate the retrieved chunks, separated by blank lines.
    context_text = "\n\n".join([chunk.page_content for chunk in state["context"]])
    prompt_value = prompt.invoke(
        {"question": state["question"], "context": context_text}
    )
    reply = llm.invoke(prompt_value)
    return {"answer": reply.content}


# Two-node pipeline: START -> retrieve -> generate.
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

# Manual smoke test, kept for reference:
# response = graph.invoke({"question": "Who should i contact for help ?"})
# print(response["answer"])

app = FastAPI()

# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this is intended before exposing the service publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)


@app.get("/ping")
async def ping():
    """Liveness probe; always answers with a fixed string."""
    return "Pong!"
class Query(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # The question forwarded to the retrieval/generation graph.
    question: str


@app.post("/chat")
async def chat(request: Query):
    """Run the graph for the submitted question and return its answer.

    Args:
        request: Parsed request body holding the user's question.

    Returns:
        A dict of the form ``{"response": <answer text>}``.
    """
    # graph.invoke is blocking, so run it in the thread pool to keep the
    # event loop free.
    result = await run_in_threadpool(graph.invoke, {"question": request.question})
    answer = str(result.get("answer", ""))
    # Remove model reasoning blocks before replying. A bare r'.*?' must not be
    # used here: on Python 3.7+ it matches every character in turn and would
    # blank out the whole answer.
    # NOTE(review): the <think>...</think> tag name is assumed -- confirm it
    # matches what the configured model actually emits.
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    return {"response": answer}


if __name__ == "__main__":
    import uvicorn

    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")