File size: 1,109 Bytes
c751e97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import List
import logging

from langchain_core.language_models.llms import LLM
from langchain_core.tools import Tool

from ask_candid.agents.schema import AgentState

# Import-time logging setup for this module.
# basicConfig configures the *root* logger; per the stdlib docs it does nothing
# if the root logger already has handlers, so an application that configured
# logging first keeps its own setup.
_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

# Module-level logger; pinned to INFO so the node banner messages are emitted.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)


def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Invoke the tool-bound model on the conversation so far and record its reply.

    The model is called with the given tools bound to it. Given the question, it
    may respond with a tool call (e.g. to a retriever tool) or answer directly.

    Parameters
    ----------
    state : AgentState
        The current state. Only ``state["messages"]`` is read. The ``.content``
        of its last element is treated as the user's question.
    llm : LLM
        Model to invoke. It must provide ``bind_tools``.
        NOTE(review): in langchain_core, ``bind_tools`` is a chat-model API
        (``BaseChatModel``). Confirm the ``LLM`` annotation matches what
        callers actually pass.
    tools : List[Tool]
        Tools exposed to the model via ``bind_tools``.

    Returns
    -------
    AgentState
        A partial state update with two keys:
        - ``messages``: a one-element list holding the model response, which is
          added to the existing message list.
        - ``user_input``: the question text taken from the last message.
    """

    logger.info("---SEARCH AGENT---")
    messages = state["messages"]
    # NOTE(review): this is whatever message is last, not necessarily the
    # human turn. If this node is re-entered after a tool call, confirm that
    # overwriting user_input with that message's content is intended.
    question = messages[-1].content

    model = llm.bind_tools(tools)
    response = model.invoke(messages)
    # Return a list, because this will get added to the existing message list.
    return {"messages": [response], "user_input": question}