import re

import pandas as pd
from datasets import load_dataset
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from smolagents import CodeAgent, InferenceClientModel, Tool

#%%


class QuestionRetrieverTool(Tool):
    """smolagents tool that returns example question/answer pairs via BM25.

    The BM25 index is built once, at construction time, from the supplied
    documents. Each call to :meth:`forward` runs a BM25 keyword search over
    those documents and concatenates the ``page_content`` of the hits into a
    single string for the agent to read.
    """

    name = "Question_retriever"
    # BM25 is a lexical (keyword-overlap) ranking, not a semantic/embedding
    # search; the description is what the LLM sees, so it should say so.
    description = (
        "Uses BM25 keyword search to retrieve relevant example questions "
        "given the class, difficulty, and topic inputs by the user."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "Search query describing the class, difficulty, and topic "
                "of the question and answer pairs to retrieve."
            ),
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k: int = 5, **kwargs):
        """Build the BM25 retriever over ``docs``.

        Args:
            docs: Documents passed to ``BM25Retriever.from_documents``; the
                ``page_content`` of each hit is what :meth:`forward` returns.
            k: Number of top-ranked documents to retrieve per query
                (defaults to 5, the previously hard-coded value).
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=k)

    def forward(self, query: str) -> str:
        """Return the retrieved documents for ``query`` as one formatted string.

        Each hit is prefixed with a numbered ``===== Q and A pairs i =====``
        header (numbering starts at 0), after a fixed leading title line.

        Raises:
            TypeError: If ``query`` is not a ``str``. This is an explicit check
                rather than ``assert`` so it is not stripped under ``python -O``.
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.retriever.invoke(query)
        return "\nRetrieved example question and answer pairs:\n" + "".join(
            f"\n\n===== Q and A pairs {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )