# STEM-Question-Generator / QuestionRetrieverTool.py
# Committed by bhardwaj08sarthak in f61abee ("Create QuestionRetrieverTool.py").
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
from smolagents import CodeAgent, InferenceClientModel
from datasets import load_dataset
import re
import pandas as pd
#%%
class QuestionRetrieverTool(Tool):
    """smolagents tool that returns example question/answer pairs via BM25.

    The documents given to the constructor are indexed once with
    ``BM25Retriever``. Each call to :meth:`forward` returns the ``k``
    best-matching documents, concatenated into a single string for the agent.
    """

    name = "Question_retriever"
    # BM25 ranks documents by keyword overlap, not by embedding similarity, so
    # the description says so; this lets the agent phrase queries as keywords.
    description = (
        "Uses keyword-based (BM25) search to retrieve relevant example question "
        "and answer pairs given the class, difficulty, and topic inputs by the user."
    )
    inputs = {
        "query": {
            "type": "string",
            # Describes what the agent should pass in, not what the tool returns.
            "description": (
                "The search query: keywords describing the class, difficulty, "
                "and topic of the questions to retrieve."
            ),
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k=5, **kwargs):
        """Build the BM25 index over ``docs``.

        Args:
            docs: Iterable of LangChain ``Document`` objects to search over.
            k: Number of documents returned per query (defaults to 5, the
                previously hard-coded value).
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If ``docs`` contains no documents.
        """
        super().__init__(**kwargs)
        # Materialize first so an empty list or an exhausted generator is
        # rejected here with a clear message, not deep inside the retriever.
        docs = list(docs)
        if not docs:
            raise ValueError("QuestionRetrieverTool requires at least one document")
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def forward(self, query: str) -> str:
        """Return the top-``k`` matches for ``query`` as one formatted string.

        Each retrieved document's ``page_content`` is preceded by a
        ``===== Q and A pairs <i> =====`` header, numbered from 0.

        Raises:
            TypeError: If ``query`` is not a string. This is an explicit check
                rather than ``assert`` so it still runs under ``python -O``.
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.retriever.invoke(query)
        return "\nRetrieved example question and answer pairs:\n" + "".join(
            f"\n\n===== Q and A pairs {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )