import os
import getpass
from langchain.chat_models import init_chat_model
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import json
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain import hub
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing_extensions import List, TypedDict

from langchain_cohere import CohereEmbeddings

import re
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
import nltk

# Force nltk to use the bundled nltk_data directory next to this file
nltk.data.path.append(os.path.join(os.path.dirname(__file__), "nltk_data"))

# Disable nltk downloads at runtime (the Hugging Face Space filesystem is read-only)
def no_download(*args, **kwargs):
    return None

nltk.download = no_download

'''
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
'''

load_dotenv()

# avoid printing secret values
print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
print("HUGGING_FACE_API_KEY set:", bool(os.getenv("HUGGING_FACE_API_KEY")))
print("COHERE_API_KEY set:", bool(os.getenv("COHERE_API_KEY") or os.getenv("COHERE")))


# Chat model served through Groq
llm = init_chat_model("moonshotai/kimi-k2-instruct-0905", model_provider="groq", api_key=os.getenv("GROQ_API_KEY"))
'''
# Alternative: Hugging Face Inference API embeddings
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv("HUGGING_FACE_API_KEY"),
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
'''

# Cohere embeddings used for both indexing and query-time retrieval
embeddings = CohereEmbeddings(
    cohere_api_key=os.getenv("COHERE_API_KEY") or os.getenv("COHERE"),
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings"
)

# In-memory vector store; the index is rebuilt from comb.json on every startup
vector_store = InMemoryVectorStore(embedding=embeddings)

# Load and parse comb.json instead of using the markdown loader
json_path = os.path.join(os.path.dirname(__file__), "comb.json")
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

def _collect_texts(node):
    """Recursively collect text strings from a JSON structure.

    Supports: string, list, dict with common keys like 'text', 'content', 'body',
    or a list of documents. Falls back to joining stringifiable values.
    """
    texts = []
    if isinstance(node, str):
        texts.append(node)
    elif isinstance(node, list):
        for item in node:
            texts.extend(_collect_texts(item))
    elif isinstance(node, dict):
        # common text keys
        if "text" in node and isinstance(node["text"], str):
            texts.append(node["text"])
        elif "content" in node and isinstance(node["content"], str):
            texts.append(node["content"])
        elif "body" in node and isinstance(node["body"], str):
            texts.append(node["body"])
        elif "documents" in node and isinstance(node["documents"], list):
            texts.extend(_collect_texts(node["documents"]))
        else:
            # fallback: stringify simple values
            joined = " ".join(str(v) for v in node.values() if isinstance(v, (str, int, float)))
            if joined:
                texts.append(joined)
    else:
        texts.append(str(node))
    return texts
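
# Example (hypothetical input) of how _collect_texts flattens nested JSON into
# plain strings; the sample dict below is illustrative, not taken from comb.json:
#   _collect_texts({"documents": [{"text": "Office hours: 9-5"}, "Call the helpdesk."]})
#   -> ["Office hours: 9-5", "Call the helpdesk."]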

raw_texts = _collect_texts(data)

# Split each raw text into chunks
all_splits = []
for t in raw_texts:
    if not t:
        continue
    splits = text_splitter.split_text(t)
    all_splits.extend(splits)

# Build Documents from split strings
docs = [Document(page_content=text) for text in all_splits]
_ = vector_store.add_documents(documents=docs)
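
# Optional sanity check: confirm that some chunks were actually indexed before serving requests
print(f"Indexed {len(docs)} chunks from comb.json")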


# Standard RAG prompt template pulled from the LangChain hub
prompt = hub.pull("rlm/rag-prompt")

class State(TypedDict):
    question: str
    context: List[Document]
    answer: str

def retrieve(state: State):
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}

def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}

graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
'''
response = graph.invoke({"question": "Who should I contact for help?"})
print(response["answer"])
'''

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)

@app.get("/ping")
async def ping():
    return "Pong!"

class Query(BaseModel):
    question: str

@app.post("/chat")
async def chat(request: Query):
    # run the blocking graph.invoke without blocking the event loop
    result = await run_in_threadpool(lambda: graph.invoke({"question": request.question}))
    answer = result.get("answer", "")
    answer = str(answer)
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    return {"response": answer}
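
# Example request once the server is running (question text is illustrative):
#   curl -X POST http://127.0.0.1:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who should I contact for help?"}'
# Expected response shape: {"response": "<model answer>"}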

if __name__ == "__main__":
    import uvicorn
    print("Starting uvicorn server on http://127.0.0.1:8000")
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")