brainsqueeze's picture
UI callbacks and style changes
cc80c3d verified
raw
history blame
4.09 kB
from typing import List, Tuple, Callable, Optional, Any
from functools import partial
import logging
from pydantic import BaseModel, Field
from langchain_core.language_models.llms import LLM
from langchain_core.documents import Document
from langchain_core.tools import Tool
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
from ask_candid.base.config.data import DataIndices
from ask_candid.agents.schema import AgentState
# Module-wide logging: configure the root handler format once, then expose a
# module-scoped logger that emits INFO and above.
logging.basicConfig(
    format="[%(levelname)s] (%(asctime)s) :: %(message)s"
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""

    # The only argument the retrieval tool accepts (this model is passed as
    # `args_schema` by `retriever_tool`). `...` marks the field as required,
    # exactly like a field declared without a default.
    user_input: str = Field(..., description="query to look up in retriever")
def get_search_results(
    user_input: str,
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of app state, by default None.
        The notification is best-effort: any exception raised by the callback
        is logged as a warning and the search continues.

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list).
        The text joins each document's ``page_content`` with blank lines.
        When the search or the re-ranking produces no documents, the text is a
        fixed "no sources" message and the documents list is empty.
    """
    no_sources_message = "Search didn't return any Candid sources"

    if user_callback is not None:
        try:
            user_callback("Searching for relevant information")
        except Exception as ex:  # deliberate best-effort: a broken UI hook must not block retrieval
            logger.warning("User callback was passed in but failed: %s", ex)

    results = get_query_results(search_text=user_input, indices=indices)
    if not results:
        # Keep the artifact a (empty) list of Documents; previously a bare
        # string was placed here, breaking consumers that expect Documents.
        return no_sources_message, []

    # Materialize so the returned artifact is a real, re-iterable list even if
    # the re-ranker yields lazily.
    documents = list(get_reranked_results(results, search_text=user_input))
    if not documents:
        # Re-ranking filtered everything out; report it instead of returning "".
        return no_sources_message, []

    content = "\n\n".join(doc.page_content for doc in documents)
    # for the tool we need to return a tuple for content_and_artifact type
    return content, documents
def retriever_tool(
    indices: List[DataIndices],
    user_callback: Optional[Callable[[str], Any]] = None
) -> Tool:
    """Build the retrieval tool used in conditional edge building for the RAG execution graph.

    `create_retriever_tool` is not used because it only provides the text
    content and loses all document metadata. Instead, the tool is declared with
    ``response_format="content_and_artifact"`` so that the tuple returned by
    `get_search_results` is split into message content and an artifact.
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[DataIndices]
        Semantic index names to search over
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of app state, by default None

    Returns
    -------
    Tool
        Tool named ``retrieve_social_sector_information`` that runs
        `get_search_results` with ``indices`` and ``user_callback`` pre-bound and
        takes its single argument as described by `RetrieverInput`.
    """
    # Pre-bind everything except the query, which the tool supplies at call time.
    bound_search = partial(get_search_results, indices=indices, user_callback=user_callback)
    tool_description = (
        "Return additional information about social and philanthropic sector, including "
        "nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
    )
    return Tool(
        name="retrieve_social_sector_information",
        description=tool_description,
        func=bound_search,
        response_format="content_and_artifact",
        args_schema=RetrieverInput,
    )
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
    """Run one agent step: the tool-aware model either calls a tool or answers.

    The model is bound to ``tools`` and invoked on the full message history.
    Given the question, it decides whether to retrieve via the retriever tool
    or simply end.

    Parameters
    ----------
    state : AgentState
        Current graph state. ``state["messages"]`` must be non-empty; the last
        entry's ``content`` is treated as the user's question.
    llm : LLM
        Model for this step. It must provide ``bind_tools``.
        NOTE(review): ``bind_tools`` is typically a chat-model method, so this
        annotation may be narrower or different than intended — confirm.
    tools : List[Tool]
        Tools the model may call.

    Returns
    -------
    AgentState
        Partial state update. ``messages`` holds only the new model response,
        which gets added to the existing list. ``user_input`` is the last
        incoming message's content.
    """
    logger.info("---SEARCH AGENT---")
    history = state["messages"]
    # Read the question before invoking the model (same order as before).
    latest_question = history[-1].content

    tool_aware_model = llm.bind_tools(tools)
    reply = tool_aware_model.invoke(history)

    # return a list, because this will get added to the existing list
    return {"messages": [reply], "user_input": latest_question}