# novatra / main.py
import os
import re

import nltk
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_cohere import CohereEmbeddings
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing import List
from typing_extensions import TypedDict
# Force nltk to use your bundled data
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))
# Disable downloading at runtime (since Hugging Face is read-only)
def no_download(*args, **kwargs):
    return None

nltk.download = no_download
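# From this point on, any nltk.download(...) call is a no-op, so every corpus
# the unstructured loader needs (e.g. punkt) must already be in ./nltk_data.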
# Fallback for local development (commented out; would also need `import getpass`):
# if not os.environ.get("GROQ_API_KEY"):
#     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
load_dotenv()
# avoid printing secret values
print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))
llm = init_chat_model("moonshotai/kimi-k2-instruct-0905", model_provider="groq", api_key=os.getenv("GROQ_API_KEY"))
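# Kimi K2 instruct served through Groq; any provider/model pair supported by
# init_chat_model could be swapped in here.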
# Alternative backend (commented out; would also need HuggingFaceInferenceAPIEmbeddings
# from langchain_community.embeddings):
# embeddings = HuggingFaceInferenceAPIEmbeddings(
#     api_key=os.getenv("HUGGING_FACE_API_KEY"),
#     model_name="sentence-transformers/all-MiniLM-L6-v2",
# )
embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)
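# Optional sanity check (a sketch; assumes a valid Cohere key is loaded --
# embed-english-v3.0 should return 1024-dimensional vectors):
# vec = embeddings.embed_query("hello")
# print(len(vec))  # expected: 1024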
vector_store = InMemoryVectorStore(embedding=embeddings)
md_loader = UnstructuredMarkdownLoader('comb.md')
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
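# Rough illustration of the splitter settings: a separator-free 1200-character
# string should come back as three chunks (500 + 500 + 400 characters), each
# overlapping its neighbor by 100.
# for chunk in text_splitter.split_text("x" * 1200):
#     print(len(chunk))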
# split_documents() already returns Document objects, so the chunks can be
# indexed directly.
all_splits = text_splitter.split_documents(md_loader.load())
_ = vector_store.add_documents(documents=all_splits)
prompt = hub.pull("rlm/rag-prompt")
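# rlm/rag-prompt is the stock RAG prompt on the LangChain Hub; it expects
# "question" and "context" inputs, which generate() fills in below.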
class State(TypedDict):
    # Shared state threaded through the graph: the user question, the
    # retrieved chunks, and the generated answer.
    question: str
    context: List[Document]
    answer: str
def retrieve(state: State):
    # Fetch the chunks most similar to the question from the vector store.
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}
def generate(state: State):
    # Stuff the retrieved chunks into the RAG prompt and call the LLM.
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
# Quick local smoke test (commented out):
# response = graph.invoke({"question": "Who should I contact for help?"})
# print(response["answer"])
app = FastAPI()
origins = ["*"]
# Note: browsers reject credentialed requests against a wildcard origin, so
# list explicit origins here if cookies or auth headers are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)
@app.get("/ping")
async def ping():
    return "Pong!"
class Query(BaseModel):
    question: str
@app.post("/chat")
async def chat(request: Query):
    # Run the blocking graph.invoke in a worker thread so the event loop stays free.
    result = await run_in_threadpool(lambda: graph.invoke({"question": request.question}))
    answer = str(result.get("answer", ""))
    # Strip any <think>...</think> reasoning traces the model may emit.
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    return {"response": answer}
if __name__ == "__main__":
    import uvicorn

    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")