# Author: MaryamKarimi080
# Commit 6ca7edb (verified): "Update scripts/rag_chat.py"
# File size: 1.51 kB
import os
from pathlib import Path
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.prompts import PromptTemplate
# Two directory levels above this file, with symlinks resolved; this is the
# project root when this file lives in scripts/.
BASE_DIR = Path(__file__).resolve().parent.parent
# Persist directory of the Chroma vector store used by the QA chain.
DB_DIR = BASE_DIR.joinpath("db")
def build_general_qa_chain(model_name=None, k=4):
    """Build a RetrievalQA chain over the persisted Chroma store in ``DB_DIR``.

    If the vector store has not been built yet (``DB_DIR`` is missing or is an
    empty directory), the ingestion pipeline from the ``scripts`` package is
    run first.

    Args:
        model_name: Chat model name passed to ``ChatOpenAI``. When ``None``
            (or any falsy value), ``"gpt-4o-mini"`` is used.
        k: Number of chunks the retriever returns per question. Defaults to 4,
            the value that was previously hard-coded.

    Returns:
        A ``RetrievalQA`` chain built with ``return_source_documents=True``,
        so its output includes the retrieved documents alongside the answer.

    Raises:
        ValueError: If ``k`` is not a positive integer.
    """
    if not isinstance(k, int) or k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Checking only exists() would treat an empty directory (e.g. one left by
    # an interrupted build) as a valid store, skip ingestion, and leave the
    # retriever with nothing to search.
    if not DB_DIR.is_dir() or not any(DB_DIR.iterdir()):
        print("📦 No DB found. Building vectorstore...")
        # Imported lazily: the ingestion scripts are only needed on first run.
        from scripts import load_documents, chunk_and_embed, setup_vectorstore
        load_documents.main()
        chunk_and_embed.main()
        setup_vectorstore.main()

    # NOTE(review): this must be the same embedding model that was used when
    # the store was built (presumably in chunk_and_embed / setup_vectorstore).
    # Confirm against those scripts.
    embedding = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = Chroma(persist_directory=str(DB_DIR), embedding_function=embedding)

    template = """Use the following context to answer the question.
If the answer isn't found in the context, use your general knowledge but say so.
Always cite your sources at the end with 'Source: <filename>' when using course materials.
Context: {context}
Question: {question}
Helpful Answer:"""
    qa_prompt = PromptTemplate(
        template=template,
        input_variables=["context", "question"],
    )

    # temperature=0.0 minimises sampling randomness for context-grounded answers.
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # NOTE(review): the "stuff" chain formats each retrieved document with its
    # default document prompt, which contains only page_content. The filenames
    # the template asks the model to cite therefore never reach it. Passing a
    # "document_prompt" that includes {source} in chain_type_kwargs would fix
    # this, but it raises if any chunk lacks "source" metadata. Confirm the
    # ingestion scripts set it before enabling that.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
        chain_type_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain