laverdes committed on
Commit
e235492
·
verified ·
1 Parent(s): 5026d1a

feat: extract_text tool for LangGraph agent

Browse files
Files changed (1) hide show
  1. retriever.py +31 -29
retriever.py CHANGED
@@ -1,35 +1,12 @@
1
- from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
- from langchain.docstore.document import Document
4
  import datasets
5
 
6
-
7
- class GuestInfoRetrieverTool(Tool):
8
- name = "guest_info_retriever"
9
- description = "Retrieves detailed information about gala guests based on their name or relation."
10
- inputs = {
11
- "query": {
12
- "type": "string",
13
- "description": "The name or relation of the guest you want information about."
14
- }
15
- }
16
- output_type = "string"
17
-
18
- def __init__(self, docs):
19
- self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
22
-
23
- def forward(self, query: str):
24
- results = self.retriever.get_relevant_documents(query)
25
- if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
- else:
28
- return "No matching guest information found."
29
 
30
 
31
  def load_guest_dataset():
32
- # Load the dataset
33
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
 
35
  # Convert dataset entries into Document objects
@@ -46,8 +23,33 @@ def load_guest_dataset():
46
  for guest in guest_dataset
47
  ]
48
 
49
- # Return the tool
50
- return GuestInfoRetrieverTool(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
 
 
 
52
 
53
 
 
 
 
 
 
 
 
1
  import datasets
2
 
3
+ from langchain.docstore.document import Document
4
+ from langchain_community.retrievers import BM25Retriever
5
+ from langchain.tools import Tool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def load_guest_dataset():
9
+ """Loads the guest dataset and converts it into Document objects."""
10
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
11
 
12
  # Convert dataset entries into Document objects
 
23
  for guest in guest_dataset
24
  ]
25
 
26
+ # Return the documents
27
+ return docs
28
+
29
+
30
+ # Load the dataset
31
+ docs = load_guest_dataset()
32
+
33
+ # Initialize the retriever
34
+ bm25_retriever = BM25Retriever.from_documents(docs)
35
+
36
+
37
+ def extract_text(query: str) -> str:
38
+ """Retrieves detailed information about gala guests based on their name or relation."""
39
+ results = bm25_retriever.invoke(query)
40
+ if results:
41
+ return results[0].page_content # [doc.page_content for doc in results[:1]]), :3
42
+ else:
43
+ return "No matching guest information found."
44
+
45
 
46
+ guest_info_tool = Tool(
47
+ name="guest_info_retriever",
48
+ func=extract_text,
49
+ description="Retrieves detailed information about gala guests based on their name or relation."
50
+ )
51
 
52
 
53
+ if __name__ == "__main__":
54
+ query = "Marie"
55
+ print(f"query: {query}:\nretrieval: {extract_text(query)}")