# (removed page-scrape artifacts: "Spaces:" / "Sleeping" status lines from the hosting page)
import base64
import inspect
import json
import os
import time
from datetime import datetime, timezone
from pathlib import PurePath
from typing import Callable, Union
from urllib.parse import urlparse, parse_qs

import trafilatura
from dotenv import load_dotenv
from langchain.chains.summarize import load_summarize_chain
from langchain.globals import set_debug
from langchain.prompts import PromptTemplate
from langchain.tools import tool
from langchain_community.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain_community.tools import YouTubeSearchTool
from langchain_community.tools.file_management.read import ReadFileTool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_google_community import SpeechToTextLoader
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
from langchain_openai import ChatOpenAI
from langchain_tavily import TavilySearch, TavilyExtract
from markitdown import MarkItDown
from youtube_transcript_api import YouTubeTranscriptApi

from basic_agent import print_conversation
# Disable LangChain's global debug tracing; this module does its own logging below.
set_debug(False)
# Toggle for the custom tool-call / tool-response tracing used throughout this module.
CUSTOM_DEBUG = True
# Load API credentials (Tavily, OpenAI, Google, ...) from a local .env file.
load_dotenv()
def encode_image_to_base64(path):
    """Read the file at *path* and return its bytes as a base64 string (UTF-8)."""
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def print_tool_call(tool: Callable, tool_name: str, args: dict):
    """Prints the tool call for debugging purposes."""
    # Show the tool's signature next to its name so the trace reads like a call site.
    signature = inspect.signature(tool)
    call_message = {'role': 'Tool-Call', 'content': f"Calling `{tool_name}`{signature}"}
    args_message = {'role': 'Tool-Args', 'content': args}
    print_conversation(messages=[call_message, args_message])
def print_tool_response(response: str):
    """Prints the tool response for debugging purposes."""
    message = {'role': 'Tool-Response', 'content': response}
    print_conversation(messages=[message])
# Shared Tavily clients: web search (top-3 hits) and full-page content extraction.
search_tool = TavilySearch(max_results=3)
extract_tool = TavilyExtract()
def search_and_extract(query: str) -> list[dict]:
    """Performs a web search and returns structured content extracted from top results.

    Each returned dict holds: title, url, snippet (search-result summary) and
    raw_content (extracted page text, truncated to MAX_NUMBER_OF_CHARS).
    Returns a single-element error list when the search yields no URLs.
    """
    time.sleep(3)  # To avoid hitting the API rate limit in the llm-apis when calling the tool multiple times in a row.
    MAX_NUMBER_OF_CHARS = 15_000
    if CUSTOM_DEBUG:
        print_tool_call(
            search_and_extract,
            tool_name='search_and_extract',
            args={'query': query, 'max_number_of_chars': MAX_NUMBER_OF_CHARS},
        )
    search_response = search_tool.invoke({"query": query})
    raw_results = search_response.get("results", [])
    urls = [r["url"] for r in raw_results if r.get("url")]
    if not urls:
        return [{"error": "No URLs found to extract from."}]
    extracted = extract_tool.invoke({"urls": urls})
    # BUG FIX: the original zipped search results with extracted results by
    # position; extraction can fail or reorder for some URLs, silently pairing
    # content with the wrong result. Index extracted pages by URL instead.
    content_by_url = {
        doc.get("url"): doc.get("raw_content", "")
        for doc in extracted.get("results", [])
    }
    structured_results = []
    for result in raw_results:
        doc_content = content_by_url.get(result.get("url"), "")
        structured_results.append({
            "title": result.get("title"),
            "url": result.get("url"),
            "snippet": result.get("content"),
            # Slicing is already a no-op for shorter strings; no length check needed.
            "raw_content": doc_content[:MAX_NUMBER_OF_CHARS],
        })
    if CUSTOM_DEBUG:
        # Keep the console trace readable by omitting the bulky raw_content field.
        console_structured_results = [
            {k: v for k, v in item.items() if k != "raw_content"}
            for item in structured_results
        ]
        print_tool_response(json.dumps(console_structured_results))
    return structured_results
def aggregate_information(query: str, results: list[str]) -> str:
    """
    Processes a list of unstructured text chunks (e.g., search results) and produces a concise, query-specific summary.
    Each input text is filtered and summarized individually in the context of the provided query. Irrelevant results are discarded.
    Relevant content is aggregated and synthesized into a final, coherent answer that directly addresses the query.
    """
    if CUSTOM_DEBUG:
        print_tool_call(
            aggregate_information,
            tool_name='aggregate_information',
            args={'results': results, 'query': query},
        )
    if not results:
        empty_message = "No search results provided."
        if CUSTOM_DEBUG:
            print_tool_response(empty_message)
        return empty_message
    # Wrap each raw chunk as a LangChain Document so the summarize chain can consume it.
    documents = [Document(page_content=text) for text in results]
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    # Map step: summarize each document in light of the query (or emit 'IGNORE').
    map_prompt = PromptTemplate.from_template(
        "You are analyzing a search result in the context of the question: '{query}'.\n\n"
        "Search result:\n{text}\n\n"
        "Instructions:\n"
        "- If the result contains information relevant to answering the query, summarize the relevant parts clearly.\n"
        "- If the result is not helpful or unrelated, return 'IGNORE'.\n"
        "- Do not include generic information or filler.\n"
        "- Focus on extracting facts, key statements, or numbers that directly support the query.\n\n"
        "Relevant Summary:"
    )
    # Reduce step: merge the per-document summaries into one supporting context.
    combine_prompt = PromptTemplate.from_template(
        "You are aggregating information to provide context to answer the following question: '{query}'.\n\n"
        "Here are the summaries from filtered search results:\n{text}\n\n"
        "Use the provided summaries to construct a context that directly supports the query without answering it.\n"
        "Context:"
    )
    chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        map_prompt=map_prompt.partial(query=query),
        combine_prompt=combine_prompt.partial(query=query),
    )
    chain_result = chain.invoke({'input_documents': documents})
    final_text = chain_result.get('output_text', str(chain_result))
    payload = json.dumps({'summary': final_text})
    if CUSTOM_DEBUG:
        print_tool_response(payload)
    return payload
# Shared Gemini multimodal model, used by image_query_tool below.
gemini = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
def image_query_tool(image_path: str, question: str) -> str:
    """
    Uses Gemini Vision to answer a question about an image.
    - image_path: file path to the image to analyze (.png)
    - question: the query to ask about the image
    """
    # Reading the image can fail on a bad path or unreadable file; report and bail.
    try:
        encoded = encode_image_to_base64(image_path)
    except OSError:
        error_text = f"OSError: Invalid argument (invalid image path or file format): {image_path}. Please provide a valid PNG image."
        print_tool_response(error_text)
        return error_text
    data_url = f"data:image/png;base64,{encoded}"
    if CUSTOM_DEBUG:
        print_tool_call(
            image_query_tool,
            tool_name='image_query_tool',
            # Only the first 100 chars of the data URL — the full payload is huge.
            args={'base64_image': data_url[:100], 'question': question},
        )
    message = HumanMessage(content=[
        {"type": "text", "text": question},
        {"type": "image_url", "image_url": data_url},
    ])
    try:
        answer = gemini.invoke([message])
    except ChatGoogleGenerativeAIError:
        error_text = "ChatGoogleGenerativeAIError: Invalid argument provided to Gemini: 400 Provided image is not valid"
        print_tool_response(error_text)
        return error_text
    if CUSTOM_DEBUG:
        print_tool_response(answer.content)
    return answer.content
def extract_video_id(url: str) -> str:
    """Return the YouTube video id embedded in *url*, or "" if none is found.

    Generalized beyond the original ?v= query form: also handles youtu.be
    short links and /embed/<id> or /shorts/<id> paths. Standard watch URLs
    behave exactly as before.
    """
    parsed = urlparse(url)
    # Standard form: https://www.youtube.com/watch?v=<id>
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids and query_ids[0]:
        return query_ids[0]
    # Short link: https://youtu.be/<id>
    if parsed.netloc.endswith("youtu.be"):
        return parsed.path.lstrip("/").split("/")[0]
    # Embedded / shorts forms: /embed/<id>, /shorts/<id>
    parts = [segment for segment in parsed.path.split("/") if segment]
    if len(parts) >= 2 and parts[0] in ("embed", "shorts"):
        return parts[1]
    return ""
def get_audio_from_youtube(urls: list[str], save_dir: str = "./tmp/") -> list[str]:
    """Download the audio tracks of the given YouTube video URLs.

    Audio files are saved under *save_dir*; the local file paths are returned
    as strings. (FIX: the original annotation `list[str | PurePath | None]`
    was wrong — every element is passed through str().)
    """
    if CUSTOM_DEBUG:
        print_tool_call(
            get_audio_from_youtube,
            tool_name='get_audio_from_youtube',
            args={'urls': urls, 'save_dir': save_dir},
        )
    loader = YoutubeAudioLoader(urls, save_dir)
    # yield_blobs() performs the download; blob.path is the on-disk location.
    paths = [str(blob.path) for blob in loader.yield_blobs()]
    if CUSTOM_DEBUG:
        print_tool_response(json.dumps({'paths': paths}))
    return paths
def load_youtube_transcript(url: str) -> str:
    """Load a YouTube transcript using youtube_transcript_api."""
    video_id = extract_video_id(url)
    if CUSTOM_DEBUG:
        print_tool_call(
            load_youtube_transcript,
            tool_name='load_youtube_transcript',
            args={'url': url},
        )
    try:
        api_client = YouTubeTranscriptApi()
        snippets = api_client.fetch(video_id=video_id)
        # Join the non-blank snippet texts into one flat transcript string.
        parts = [snippet.text for snippet in snippets if snippet.text.strip()]
        transcript = " ".join(parts)
        if transcript and CUSTOM_DEBUG:
            print_tool_response(transcript)
        return transcript
    except Exception as e:
        # Best-effort: surface the failure as a string the caller can detect.
        error_str = f"Error loading transcript: {e}. Assuming no transcript for this video."
        print_tool_response(error_str)
        return error_str
# Shared LangChain YouTube search client, used by youtube_search_tool below.
youtube_search_api = YouTubeSearchTool()
def youtube_search_tool(query: str, number_of_results: int = 3) -> list:
    """Search YouTube for a query and return the top number_of_results."""
    if CUSTOM_DEBUG:
        print_tool_call(
            youtube_search_tool,
            tool_name='youtube_search_tool',
            # BUG FIX: the key was the bare variable `number_of_results` (an
            # int), not the string 'number_of_results'.
            args={'query': query, 'number_of_results': number_of_results},
        )
    # YouTubeSearchTool expects a single "query,count" string.
    response = youtube_search_api.run(f"{query},{number_of_results}")
    if CUSTOM_DEBUG:
        print_tool_response(response)
    return response
def search_and_extract_from_wikipedia(query: str) -> list:
    """Search Wikipedia for a query and extract useful information."""
    # A fresh wrapper per call keeps this function self-contained.
    wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    if CUSTOM_DEBUG:
        print_tool_call(
            search_and_extract_from_wikipedia,
            tool_name='search_and_extract_from_wikipedia',
            args={'query': query},
        )
    wiki_result = wikipedia_tool.invoke(query)
    if CUSTOM_DEBUG:
        print_tool_response(wiki_result)
    return wiki_result
def transcribe_audio(file_path: str) -> list:
    """Transcribe audio from an audio file in file_path using Google Speech-to-Text."""
    if CUSTOM_DEBUG:
        print_tool_call(
            transcribe_audio,
            tool_name='transcribe_audio',
            args={'file_path': file_path},
        )
    docs = []
    docs_content = []
    # First try the short-audio API; on failure, print the error and retry with
    # long-running recognition before giving up with an error string.
    try:
        docs = SpeechToTextLoader(
            project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
            file_path=file_path,
            is_long=False,  # Set to True for long audio files
        ).load()
    except Exception as e:
        print(f"Error loading audio file: {e}")
        try:
            docs = SpeechToTextLoader(
                project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID"),
                file_path=file_path,
                is_long=True,  # Set to True for long audio files
            ).load()
        except Exception as e:
            docs_content = [f"Error loading audio file: {e}"]
    if docs:
        docs_content = [doc.page_content for doc in docs]
    if CUSTOM_DEBUG:
        print_tool_response(docs_content)
    return docs_content
def extract_clean_text_from_url(url: str) -> str:
    """Extract the main readable content from a webpage using trafilatura.

    Returns the extracted text, or a human-readable failure message when the
    download fails or no main content can be found.
    """
    if CUSTOM_DEBUG:
        print_tool_call(
            extract_clean_text_from_url,
            tool_name='extract_clean_text_from_url',
            args={'url': url},
        )
    downloaded = trafilatura.fetch_url(url)
    # FIX: the original branched by sniffing its own "Failed" sentinel out of
    # the response string; branch directly on the download result instead.
    if downloaded:
        # trafilatura.extract() returns None when no main content is found.
        response = trafilatura.extract(downloaded) or "No meaningful content found."
    else:
        response = "Failed to download the page. Please check the URL."
    if CUSTOM_DEBUG:
        print_tool_response(response)
    return response
# Plain-text file reader, used as a fallback by smart_read_file.
read_tool = ReadFileTool()
def smart_read_file(file_path: str) -> str:
    """
    Smart tool to read a file based on its type.
    - Audio files are transcribed; image files are described via Gemini.
    - URLs are fetched (YouTube watch links return their transcript, falling
      back to downloading the audio when no transcript exists).
    - Use MarkItDown for PDFs, Word, Excel, HTML, and other complex formats.
    - Use `read_file_tool` for simple text, CSV, code files (MarkItDown fallback).
    """
    if CUSTOM_DEBUG:
        print_tool_call(
            smart_read_file,
            tool_name='smart_read_file',
            args={'file_path': file_path},
        )
    _, ext = os.path.splitext(file_path.lower())
    if ext in [".mp3", ".wav", ".m4a", ".flac"]:
        # If the file is an audio file, transcribe it
        return transcribe_audio.invoke({"file_path": file_path})
    if ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp"]:
        # If the file is an image, use image_query_tool to analyze it
        q = "What can you tell me about this image?"
        return image_query_tool.invoke({"image_path": file_path, "question": q})
    # BUG FIX: the original tested `ext in url_pattern`, which is True for any
    # extension-less path (since "" is a substring of everything) and False for
    # real URLs ending in e.g. ".html". Test the path itself for URL markers.
    if any(url_pattern in file_path for url_pattern in ["http://", "https://", "www."]):
        if "youtube.com/watch?v=" in file_path:
            transcript = load_youtube_transcript.invoke({"url": file_path})
            if "Error loading" in transcript:
                # No transcript available: fall back to downloading the audio.
                return get_audio_from_youtube.invoke({'urls': [file_path], 'save_dir': './tmp/'})
            # BUG FIX: a successfully loaded transcript was previously discarded
            # and the URL fell through to MarkItDown.
            return transcript
        return extract_clean_text_from_url.invoke(file_path)
    md = MarkItDown()
    try:
        result = md.convert(file_path).text_content
    except Exception:
        # MarkItDown could not handle the file; fall back to a plain file read.
        result = read_tool.invoke({"file_path": file_path})
    if CUSTOM_DEBUG:
        print_tool_response(result)
    return result