laverdes commited on
Commit
1853d1e
·
verified ·
1 Parent(s): 5e3dc23

feat: search tools (tavily and wikipedia), youtube transcripts and image_query

Browse files
Files changed (1) hide show
  1. tools.py +205 -0
tools.py CHANGED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import inspect
4
+ import time
5
+ from typing import Callable
6
+
7
+ from datetime import datetime, timezone
8
+
9
+ from langchain.tools import tool
10
+
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+ from langchain_core.messages import HumanMessage
13
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
14
+
15
+ from markitdown import MarkItDown
16
+ from langchain_tavily import TavilySearch, TavilyExtract
17
+ from langchain_google_genai import ChatGoogleGenerativeAI
18
+ from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
19
+ from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
20
+ from youtube_transcript_api import YouTubeTranscriptApi
21
+
22
+ from basic_agent import print_conversation
23
+
24
+ from dotenv import load_dotenv
25
+ from langchain.globals import set_debug
26
+ from urllib.parse import urlparse, parse_qs
27
+
28
+
29
+ set_debug(False)
30
+ CUSTOM_DEBUG = True
31
+
32
+ load_dotenv()
33
+
34
+
35
+ def encode_image_to_base64(path):
36
+ with open(path, "rb") as image_file:
37
+ return base64.b64encode(image_file.read()).decode("utf-8")
38
+
39
+
40
+ def print_tool_call(tool: Callable, tool_name: str, args: dict):
41
+ """Prints the tool call for debugging purposes."""
42
+ sig = inspect.signature(tool)
43
+ print_conversation(
44
+ messages=[
45
+ {
46
+ 'role': 'Tool-Call',
47
+ 'content': f"Calling `{tool_name}`{sig}"
48
+ },
49
+ {
50
+ 'role': 'Tool-Args',
51
+ 'content': args
52
+ }
53
+ ],
54
+ )
55
+
56
+
57
+ def print_tool_response(response: str):
58
+ """Prints the tool response for debugging purposes."""
59
+ print_conversation(
60
+ messages=[
61
+ {
62
+ 'role': 'Tool-Response',
63
+ 'content': response
64
+ }
65
+ ],
66
+ )
67
+
68
+
69
+ search_tool = TavilySearch(max_results=5)
70
+ extract_tool = TavilyExtract()
71
+
72
+
73
+ @tool
74
+ def search_and_extract(query: str) -> list[dict]:
75
+ """Performs a web search and returns structured content extracted from top results."""
76
+ time.sleep(3) # To avoid hitting the API rate limit in the llm-apis when calling the tool multiple times in a row.
77
+ if query in cache:
78
+ print(f"Cache hit for query: {query}")
79
+ return cache[query]
80
+ MAX_NUMBER_OF_CHARS = 10_000
81
+
82
+ if CUSTOM_DEBUG:
83
+ print_tool_call(
84
+ search_and_extract,
85
+ tool_name='search_and_extract',
86
+ args={'query': query, 'max_number_of_chars': MAX_NUMBER_OF_CHARS},
87
+ )
88
+
89
+ results = search_tool.invoke({"query": query})
90
+ raw_results = results.get("results", [])
91
+ urls = [r["url"] for r in raw_results if r.get("url")]
92
+
93
+ if not urls:
94
+ return [{"error": "No URLs found to extract from."}]
95
+
96
+ extracted = extract_tool.invoke({"urls": urls})
97
+ results = extracted.get("results", [])
98
+
99
+ structured_results = []
100
+ raw_contents = [doc.get("raw_content", "") for doc in results]
101
+
102
+ for result, doc_content in zip(raw_results, raw_contents):
103
+ doc_content_trunc = doc_content[0:MAX_NUMBER_OF_CHARS] if len(doc_content) > MAX_NUMBER_OF_CHARS else doc_content
104
+ structured_results.append({
105
+ "title": result.get("title"),
106
+ "url": result.get("url"),
107
+ "snippet": result.get("content"),
108
+ "raw_content": doc_content_trunc
109
+ })
110
+
111
+ if CUSTOM_DEBUG:
112
+ console_structured_results = [{k: v for k, v in result_dicti.items() if k != "raw_content"} for result_dicti in
113
+ structured_results]
114
+ print_tool_response(json.dumps(console_structured_results))
115
+ return structured_results
116
+
117
+
118
+
119
+ def extract_video_id(url: str) -> str:
120
+ parsed = urlparse(url)
121
+ return parse_qs(parsed.query).get("v", [""])[0]
122
+
123
+
124
+ @tool
125
+ def load_youtube_transcript(url: str) -> str:
126
+ """Load a YouTube transcript using youtube_transcript_api."""
127
+
128
+ video_id = extract_video_id(url)
129
+
130
+ if CUSTOM_DEBUG:
131
+ print_tool_call(
132
+ load_youtube_transcript,
133
+ tool_name='load_youtube_transcript',
134
+ args={'url': url},
135
+ )
136
+ try:
137
+ youtube_api_client = YouTubeTranscriptApi()
138
+ fetched_transcript = youtube_api_client.fetch(video_id=video_id)
139
+ transcript = " ".join(entry.text for entry in fetched_transcript if entry.text.strip())
140
+
141
+ if transcript and CUSTOM_DEBUG:
142
+ print_tool_response(transcript)
143
+
144
+ return transcript
145
+
146
+ except Exception as e:
147
+ error_str = f"Error loading transcript: {e}. Assuming no transcript for this video."
148
+ print_tool_response(error_str)
149
+ return error_str
150
+
151
+
152
+
153
+ gemini = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
154
+
155
+ @tool
156
+ def image_query_tool(image_path: str, question: str) -> str:
157
+ """
158
+ Uses Gemini Vision to answer a question about an image.
159
+ - image_path: file path to the image to analyze (.png)
160
+ - question: the query to ask about the image
161
+ """
162
+ try:
163
+ base64_img = encode_image_to_base64(image_path)
164
+ except OSError:
165
+ response = f"OSError: Invalid argument (invalid image path or file format): {image_path}. Please provide a valid PNG image."
166
+ print_tool_response(response)
167
+ return response
168
+
169
+ base64_img_str = f"data:image/png;base64,{base64_img}"
170
+ if CUSTOM_DEBUG:
171
+ print_tool_call(
172
+ image_query_tool,
173
+ tool_name='image_query_tool',
174
+ args={'base64_image': base64_img_str[:100], 'question': question},
175
+ )
176
+ msg = HumanMessage(content=[
177
+ {"type": "text", "text": question},
178
+ {"type": "image_url", "image_url": base64_img_str},
179
+ ])
180
+ try:
181
+ response = gemini.invoke([msg])
182
+ except ChatGoogleGenerativeAIError:
183
+ response = "ChatGoogleGenerativeAIError: Invalid argument provided to Gemini: 400 Provided image is not valid"
184
+ print_tool_response(response)
185
+ return response
186
+ if CUSTOM_DEBUG:
187
+ print_tool_response(response.content)
188
+ return response.content
189
+
190
+
191
+ @tool
192
+ def search_and_extract_from_wikipedia(query: str) -> list:
193
+ """Search Wikipedia for a query and extract useful information."""
194
+ wiki_api_wrapper = WikipediaAPIWrapper()
195
+ wiki_tool = WikipediaQueryRun(api_wrapper=wiki_api_wrapper)
196
+ if CUSTOM_DEBUG:
197
+ print_tool_call(
198
+ search_and_extract_from_wikipedia,
199
+ tool_name='search_and_extract_from_wikipedia',
200
+ args={'query': query},
201
+ )
202
+ response = wiki_tool.invoke(query)
203
+ if CUSTOM_DEBUG:
204
+ print_tool_response(response)
205
+ return response