from typing import List, Tuple, Callable, Optional, Any
from functools import partial
import logging

from pydantic import BaseModel, Field
from langchain_core.language_models.llms import LLM
from langchain_core.documents import Document
from langchain_core.tools import Tool

from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
from ask_candid.base.config.data import DataIndices
from ask_candid.agents.schema import AgentState

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""

    user_input: str = Field(description="query to look up in retriever")


def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of app state, by default None

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    if user_callback is not None:
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    # defaults in case the search comes back empty
    content = "Search didn't return any Candid sources"
    output = ["Search didn't return any Candid sources"]
    page_content = []

    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        output = get_reranked_results(results, search_text=user_input)
        for doc in output:
            page_content.append(doc.page_content)
        content = "\n\n".join(page_content)

    # the tool requires a (content, artifact) tuple for the "content_and_artifact" response format
    return content, output


def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Tool component for use in conditional edge building for the RAG execution graph.

    `create_retriever_tool` cannot be used here because it returns only the page content,
    losing all document metadata along the way; see
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of app state, by default None

    Returns
    -------
    Tool
    """

    return Tool(
        name="retrieve_social_sector_information",
        func=partial(get_search_results, indices=indices, user_callback=user_callback),
        description=(
            "Return additional information about the social and philanthropic sector, "
            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
        ),
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )
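

# --- Illustrative usage sketch (not part of the original module) ---------------------------
# A minimal example of calling the retriever directly, showing the (content, artifact) split
# that `response_format="content_and_artifact"` relies on. The query string and the `print`
# callback are placeholders for illustration only.
def _example_direct_search(indices: List[DataIndices]) -> None:
    """Sketch: run the search directly and inspect both return values."""
    content, documents = get_search_results(
        user_input="capacity building grants for small nonprofits",  # placeholder query
        indices=indices,
        user_callback=print,  # any callable accepting a status string works here
    )
    # `content` is the text block handed to the LLM; `documents` keeps the metadata
    # needed downstream for source citations.
    logger.info("Retrieved %d documents, %d characters of context", len(documents), len(content))
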

def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Invokes the agent model to generate a response based on the current state.
    Given the question, it will decide to retrieve using the retriever tool, or simply end.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        Language model which supports tool binding
    tools : List[Tool]
        Tools the agent may decide to call

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---SEARCH AGENT---")
    messages = state["messages"]
    question = messages[-1].content

    # bind the retriever tool(s) so the model can emit tool calls
    model = llm.bind_tools(tools)
    response = model.invoke(messages)

    # return a list, because this will get added to the existing messages list
    return {"messages": [response], "user_input": question}
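

# --- Illustrative wiring sketch (not part of the original module) --------------------------
# One possible way to connect `search_agent` and `retriever_tool` into a RAG execution graph
# with a conditional edge, as referenced in the `retriever_tool` docstring. The node names
# and the `langgraph` dependency are assumptions; adapt them to the real graph definition.
def _example_build_search_graph(llm: LLM, indices: List[DataIndices]):
    """Sketch: compile a minimal agent/retriever graph with a conditional edge."""
    from langgraph.graph import StateGraph, END  # assumed dependency
    from langgraph.prebuilt import ToolNode, tools_condition

    tool = retriever_tool(indices=indices)

    graph = StateGraph(AgentState)
    # the agent node decides whether to call the retriever or answer directly
    graph.add_node("agent", partial(search_agent, llm=llm, tools=[tool]))
    # the tool node executes the retriever whenever the agent emits a tool call
    graph.add_node("retrieve", ToolNode([tool]))

    graph.set_entry_point("agent")
    # route to the retriever if the last message contains tool calls, otherwise finish
    graph.add_conditional_edges("agent", tools_condition, {"tools": "retrieve", END: END})
    graph.add_edge("retrieve", "agent")
    return graph.compile()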