laverdes committed
Commit f0a8f60 · verified · 1 Parent(s): 1853d1e

feat: basic tool-use langchain agent alongside langgraph

Files changed (1):
basic_agent.py +156 -116
basic_agent.py CHANGED
@@ -8,13 +8,9 @@ from rich.markdown import Markdown
 from rich.json import JSON
 
 from typing import TypedDict, Sequence, Annotated
-from langchain_core.messages import BaseMessage
-from langgraph.graph.message import add_messages
-from langgraph.graph import StateGraph, START, END
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
 from langchain_openai import ChatOpenAI
-from langgraph.prebuilt import ToolNode, tools_condition
-from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
-from tqdm import tqdm
+
 
 
 def print_conversation(messages):
@@ -22,18 +18,55 @@ def print_conversation(messages):
 
     for msg in messages:
         role = msg.get("role", "unknown").capitalize()
+
         content = msg.get("content", "")
 
         try:
-            parsed_json = json.loads(content)
-            rendered_content = JSON.from_data(parsed_json)
+            if isinstance(content, str):
+                content = json.loads(content)
+
+            elif isinstance(content, dict) and 'output' in content.keys():
+                if isinstance(content['output'], HumanMessage):
+                    content['output'] = content['output'].content
+
+            elif isinstance(content, HumanMessage):
+                content = content.content
+
+            rendered_content = JSON.from_data(content)
+
         except (json.JSONDecodeError, TypeError):
-            rendered_content = Markdown(content.strip())
+            try:
+                rendered_content = Markdown(content.strip())
+            except AttributeError:
+                # from gemini
+                try:
+                    rendered_content = {
+                        'query': content.get('query', 'QueryKeyNotFound').content[0]['text'],
+                        'output': content.get('output', 'OutputKeyNotFound'),
+                    }
+                    rendered_content = JSON.from_data(rendered_content)
+
+                except Exception as e:
+                    print(f"Failed to render content for role: {role}. Content: {content}")
+                    print("Error:", e)
+
+
+        border_style_color = "red"
+        if "Assistant" in role:
+            border_style_color = "magenta"
+        elif "User" in role:
+            border_style_color = "green"
+        elif "System" in role:
+            border_style_color = "blue"
+        elif "Tool" in role:
+            border_style_color = "yellow"
+        elif "Token" in role:
+            border_style_color = "white"
 
         panel = Panel(
            rendered_content,
            title=f"[bold blue]{role}[/]",
-           border_style="green" if role == "User" else "magenta",
+           border_style=border_style_color,
            expand=True
        )
 
@@ -43,18 +76,44 @@
 def generate_final_answer(qa: dict[str, str]) -> str:
     """Invokes gpt-4o-mini to generate a final answer based on the query, response, and metadata"""
 
-    final_answer_llm = ChatOpenAI(model="gpt-4o", temperature=0)
+    final_answer_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
 
     system_prompt = (
-        "You will receive a JSON string containing a user's query, a response, and metadata. "
-        "Extract and return only the final answer to the query as a plain string. "
-        "Do not return anything else. "
-        "Avoid any labels, prefixes, or explanation. "
-        "Return only the exact value that satisfies the query, suitable for string comparison."
-        "If the query is not answerable due to a missing file in the input and is reflected in the response, answer with 'File not found'. "
+        "You will be given a JSON object containing a user's query, a response from an AI assistant, and optional metadata. "
+        "Your task is to extract and return a final answer to the query as a plain string, strictly suitable for exact match evaluation. "
+
+        "Do NOT answer the query yourself. Use the response as the source of truth. "
+        "Use the query only as context to interpret the response and extract a final, normalized answer. "
+
+        "Your output must be:\n"
+        "- A **single plain string** with **no prefixes, labels, or explanations**.\n"
+        "- Suitable for exact string comparison.\n"
+        "- Clean and deterministic: no variation in formatting, casing, or punctuation.\n"
+
+        "Special rules:\n"
+        "- If the response shows inability to process attached media (images, audio, video), return: **'File not found'**.\n"
+        "- If the response is a list of search results, aggregate the information before constructing an answer.\n"
+        "- If the query is quantitative (How many...?), **aggregate the results of the tool(s) call(s) and return the numeric answer** only.\n"
+        "- If the query is unanswerable from the response, return: **'No answer found: <brief reason>'**.\n"
+
+        "Examples:\n"
+        "- Query: 'What’s in the attached image?'\n"
+        "  Response: 'I'm unable to view images directly...'\n"
+        "  Output: 'File not found'\n\n"
+        "- Query: 'What’s the total population of X'\n"
+        "  Response: '{title: demographics of X, content: 1. City A: 2M, 2. City B: 3M, title: history of X, content: currently there are Y number of inhabitants in X...'\n"
+        "  Output: '5000000'\n"
+
+        "Strictly follow these rules. Some final answers will require more analysis of the provided response. "
+        "You can reason to get to the answer but always consider the response as the base knowledge (keep coherence). "
+        "Return only the final string answer. Do not include any other content."
    )
 
     system_message = SystemMessage(content=system_prompt)
+
+    if isinstance(qa['response']['query'], HumanMessage):
+        qa['response'] = qa['response']['output']
+
     messages = [
         system_message,
         HumanMessage(content=f'Generate the final answer for the following query:\n\n{json.dumps(qa)}')
@@ -63,124 +122,105 @@ def generate_final_answer(qa: dict[str, str]) -> str:
     response = final_answer_llm.invoke(messages)
 
     return response.content
-
 
-class AgentState(TypedDict):
-    messages: Annotated[Sequence[BaseMessage], add_messages]
 
-
-class BasicOpenAIAgentWorkflow:
+class ToolAgent:
     """Basic custom class from an agent prompted for tool-use pattern"""
 
     def __init__(self, tools: list, model='gpt-4o', backstory:str="", streaming=False):
-        self.name = "Basic OpenAI Agent Workflow"
+        self.name = "GAIA Tool-Use Agent"
         self.tools = tools
-        self.llm = ChatOpenAI(model=model, temperature=0, streaming=streaming)
-        self.graph = None
-        self.history = []
-        self.history_messages = []  # Store messages in LangChain format
+        self.llm = ChatOpenAI(model=model, temperature=0, streaming=streaming, max_retries=5)
+        self.executor = None
         self.backstory = backstory if backstory else "You are a helpful assistant that can use tools to answer questions. Your name is Gaia."
 
-        role_message = {'role': 'system', 'content': self.backstory}
-        self.history.append(role_message)
-
-
-    def _call_llm(self, state: AgentState):
-        """invokes the assigned llm"""
-        return {'messages': [self.llm.invoke(state['messages'])]}
-
-
-    def _convert_history_to_messages(self):
-        """Convert self.history to LangChain-compatible messages"""
-        converted = []
-        for msg in self.history:
-            content = msg['content']
-
-            if not isinstance(content, str):
-                raise ValueError(f"Expected string content, got: {type(content)} — {content}")
-
-            if msg['role'] == 'user':
-                converted.append(HumanMessage(content=content))
-            elif msg['role'] == 'assistant':
-                converted.append(AIMessage(content=content))
-            elif msg['role'] == 'system':
-                converted.append(SystemMessage(content=content))
-            else:
-                raise ValueError(f"Unknown role in message: {msg}")
-        self.history_messages = converted
-
 
     def create_basic_tool_use_agent_state_graph(self, custom_tools_nm="tools"):
         """Binds tools, creates and compiles graph"""
-        self.llm = self.llm.bind_tools(self.tools)
-
-        # Graph Init
-        graph = StateGraph(AgentState)
-
-        # Nodes
-        graph.add_node('agent', self._call_llm)
-        tools_node = ToolNode(self.tools)
-        graph.add_node(custom_tools_nm, tools_node)
 
-        # Edges
-        graph.add_edge(START, "agent")
-        graph.add_conditional_edges('agent', tools_condition, {'tools': custom_tools_nm, END: END})
-
-        self.graph = graph.compile()
-
-
-    def chat(self, query, verbose=2, only_final_answer=False):
-        """Simple agent call"""
-        if isinstance(query, dict):
-            query = query["messages"]
-
-        user_message = {'role': 'user', 'content': query}
-        self.history.append(user_message)
-
-        # Ensure history has at least 1 message
-        if not self.history:
-            raise ValueError("History is empty. Cannot proceed.")
-
-        self._convert_history_to_messages()
-
-        if not self.history_messages:
-            raise ValueError("Converted message history is empty. Something went wrong.")
-
-        response = self.graph.invoke({'messages': self.history_messages})  # invoke with all the history to keep context (dummy mem)
-        response = response['messages'][-1].content
-
-        if only_final_answer:
-            final_answer_content = {
-                'query': query,
-                'response': response,
-                'metadata': {}
-            }
-            response = generate_final_answer(final_answer_content)
-
-        assistant_message = {'role': 'assistant', 'content': response}
-        self.history.append(assistant_message)
-
-        if verbose==2:
-            print_conversation(self.history)
-        elif verbose==1:
-            print_conversation([assistant_message])
-
-        return response
-
-
-    def invoke(self, input_str: str):
-        """Invoke the compiled graph with the input data"""
-        _ = self.chat(input_str)  # prints response in terminal
-        self._convert_history_to_messages()
-        return {'messages': self.history_messages}
-
-
-    def chat_batch(self, queries=None, only_final_answer=False):
-        """Send several simple agent calls to the llm using the compiled graph"""
-        if queries is None:
-            queries = []
-        for i, query in tqdm(enumerate(queries, start=1)):
-            if i == len(queries):
-                self.chat(query, verbose=2, only_final_answer=only_final_answer)
-            else:
-                self.chat(query, verbose=0, only_final_answer=only_final_answer)
+        tools_info = '\n\n'.join([f'{tool.name}: {tool.description}: {tool.args}' for tool in self.tools])
+        chatgpt_with_tools = self.llm.bind_tools(self.tools)
+
+        prompt_template = ChatPromptTemplate.from_messages(
+            [
+                ("system", self.backstory),
+                MessagesPlaceholder(variable_name="history", optional=True),
+                ("human", "{query}"),
+                MessagesPlaceholder(variable_name="agent_scratchpad"),
+            ]
+        )
+
+        agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
+        self.executor = AgentExecutor(
+            agent=agent,
+            tools=self.tools,
+            early_stopping_method='force',
+            max_iterations=10
+        )
+
+
+    def chat(self, query: str, metadata):
+        """Perform a single step in the conversation with the tool agent executor."""
+        if metadata is None:
+            metadata = {}
+
+        with_attachments = False
+        query_message = HumanMessage(content=query)
+
+        if "image_path" in metadata:
+
+            # Create a HumanMessage with image content
+            query_message = HumanMessage(
+                content=[
+                    {"type": "text", "text": query},
+                    {"type": "text", "text": f"image_path: {metadata['image_path']}"},
+                ]
+            )
+            with_attachments = True
+
+        user_message = {'role': 'user', 'content': query if not with_attachments else query_message}
+        print_conversation([user_message])
+
+        response = self.executor.invoke({
+            "query": query if not with_attachments else query_message,
+        })
+        response_message = {'role': 'assistant', 'content': response}
+        print_conversation([response_message])
+
+        final_answer = generate_final_answer({
+            'query': query,
+            'response': response,
+        })
+        final_answer_message = {'role': 'Final Answer', 'content': final_answer}
+        print_conversation([final_answer_message])
+        return final_answer
+
+
+    def invoke(self, q_data):
+        """Invoke the executor with the input data"""
+        query = q_data.get("query", "")
+        metadata = q_data.get("metadata", None)
+
+        try:
+            response = self.chat(query, metadata)
+            time.sleep(3)
+        except RateLimitError:
+            response = 'Rate limit error encountered. Retrying after a short pause...'
+            error_message = {'role': 'Rate-limit-hit', 'content': response}
+            print_conversation([error_message])
+            time.sleep(5)
+
+            try:
+                response = self.chat(query, metadata)
+            except RateLimitError:
+                response = 'Rate limit error encountered again. Skipping this query.'
+                error_message = {'role': 'Rate-limit-hit', 'content': response}
+                print_conversation([error_message])
+
+        print()
+        return response
+
+
+    def __call__(self, q_data):
+        """Call the invoke method from the agent executor."""
+        return self.invoke(q_data)
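Note that the new `ToolAgent` path depends on names that never appear in the visible import hunk (`ChatPromptTemplate`, `MessagesPlaceholder`, `create_tool_calling_agent`, `AgentExecutor`, `time`, `RateLimitError`); presumably they are imported above line 8 of `basic_agent.py`, outside the diff context. Below is a minimal usage sketch, assuming those imports and a hypothetical `multiply` tool; only the `ToolAgent` API itself comes from this commit.

import time
from openai import RateLimitError
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool

from basic_agent import ToolAgent  # the class added in this commit


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b


# Build the executor; the method keeps its old LangGraph-era name.
agent = ToolAgent(tools=[multiply], model="gpt-4o")
agent.create_basic_tool_use_agent_state_graph()

# __call__ -> invoke -> chat -> AgentExecutor.invoke({"query": ...})
final_answer = agent({"query": "What is 6 times 7?", "metadata": {}})
print(final_answer)  # expected: '42'

Since `invoke` sleeps three seconds after every successful call and retries once on `RateLimitError`, the agent is already paced for back-to-back use.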
 
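The commit also drops the old `chat_batch` helper, so batching now has to live outside the class. A sketch of an equivalent external loop, continuing the example above (the question list and `image_path` are illustrative):

questions = [
    {"query": "What is 6 times 7?", "metadata": {}},
    {"query": "What's in the attached image?", "metadata": {"image_path": "data/chart.png"}},
]

# ToolAgent.invoke already handles RateLimitError and paces calls with
# time.sleep, so a plain loop is enough to replace chat_batch.
answers = [agent(q_data) for q_data in questions]

for q_data, answer in zip(questions, answers):
    print(f"{q_data['query']!r} -> {answer!r}")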