biaotang commited on
Commit
1255246
·
1 Parent(s): 8c41b14

RAG tool for guest story

Browse files
Files changed (3) hide show
  1. .gitignore +3 -2
  2. app.py +14 -3
  3. retriever.py +48 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
- .venv
2
- .env
 
 
1
+ .venv/
2
+ .env
3
+ __pycache__/
app.py CHANGED
@@ -1,7 +1,18 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ from smolagents import CodeAgent, HfApiModel
3
+ from retriever import load_guest_dataset
4
 
5
+ # Initialize the Hugging Face model
6
+ model = HfApiModel()
7
+ guest_info_tool = load_guest_dataset()
8
 
9
+ # Create Alfred, our gala agent, with the guest info tool
10
+ alfred = CodeAgent(tools=[guest_info_tool], model=model)
11
+
12
+
13
+ def query_guest(guest_name):
14
+ return alfred.run(f"Tell me about our guest named '{guest_name}'.")
15
+
16
+
17
+ demo = gr.Interface(fn=query_guest, inputs="text", outputs="text")
18
  demo.launch()
retriever.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from langchain.docstore.document import Document
3
+
4
+ from smolagents import Tool
5
+ from langchain_community.retrievers import BM25Retriever
6
+
7
+ class GuestInfoRetrieverTool(Tool):
8
+ name = "guest_info_retriever"
9
+ description = "Retrieves detailed information about gala guests based on their name or relation."
10
+ inputs = {
11
+ "query": {
12
+ "type": "string",
13
+ "description": "The name or relation of the guest you want information about."
14
+ }
15
+ }
16
+ output_type = "string"
17
+
18
+ def __init__(self, docs):
19
+ self.is_initialized = False
20
+ self.retriever = BM25Retriever.from_documents(docs)
21
+
22
+ def forward(self, query: str):
23
+ results = self.retriever.invoke(query)
24
+ if results:
25
+ return "\n\n".join([doc.page_content for doc in results[:3]])
26
+ else:
27
+ return "No matching guest information found."
28
+
29
+
30
+ def load_guest_dataset():
31
+ # Load the dataset
32
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
33
+
34
+ # Convert dataset entries into Document objects
35
+ docs = [
36
+ Document(
37
+ page_content="\n".join([
38
+ f"Name: {guest['name']}",
39
+ f"Relation: {guest['relation']}",
40
+ f"Description: {guest['description']}",
41
+ f"Email: {guest['email']}"
42
+ ]),
43
+ metadata={"name": guest["name"]}
44
+ )
45
+ for guest in guest_dataset
46
+ ]
47
+
48
+ return GuestInfoRetrieverTool(docs)