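"""LangGraph execution-graph assembly for the Ask Candid assistant.

Builds the retrieval-augmented generation flow (question reformulation ->
search -> retrieval -> answer generation -> organization-link insertion),
with an optional recommendation-engine sub-graph.
"""
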
from typing import List
from functools import partial
import logging

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph.state import StateGraph
from langgraph.constants import START, END

from ask_candid.retrieval.elastic import retriever_tool
from ask_candid.tools.recommendation import (
    detect_intent_with_llm,
    determine_context,
    make_recommendation
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.tools.org_seach import has_org_name, insert_org_link
from ask_candid.tools.search import search_agent
from ask_candid.agents.schema import AgentState
from ask_candid.utils import html_format_docs_chat

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
    """Generate an answer to the user's question from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        Language model used to generate the final answer

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]

    sources_str = last_message.content
    sources_list = last_message.artifact  # cannot be used directly as a list of Documents
    # convert the source documents to an HTML string for chat display
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        logger.info("---ADD SOURCES---")
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n
    Question: {question}
    Context: {context}
    Answer:
    """
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # Pass the question through a template variable rather than inlining the
        # raw string, so braces in user input cannot break template parsing.
        ("human", "{question}"),
    ])

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}


def add_recommendations_pipeline_(
    G: StateGraph,
    llm: LLM,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Adds the execution sub-graph for the recommendation engine flow. Graph changes are in-place.

    Parameters
    ----------
    G : StateGraph
        Execution graph
    llm : LLM
        Language model used to detect recommendation intent
    reformulation_node_name : str, optional
        Name of the node which reformulates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """

    # Nodes for recommendation functionalities
    G.add_node("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm))
    G.add_node("determine_context", determine_context)
    G.add_node("make_recommendation", make_recommendation)

    # Check for a recommendation query first: if the user asks for a
    # recommendation, execute the recommendation branch until reaching END;
    # otherwise fall through to the regular search flow.
    G.add_edge(reformulation_node_name, "detect_intent_with_llm")
    G.add_conditional_edges(
        source="detect_intent_with_llm",
        path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
        path_map={
            "determine_context": "determine_context",
            search_node_name: search_node_name
        },
    )
    G.add_edge("determine_context", "make_recommendation")
    G.add_edge("make_recommendation", END)


def build_compute_graph(
    llm: LLM,
    indices: List[str],
    enable_recommendations: bool = False
) -> StateGraph:
    """Execution graph builder; the output is the execution flow for an interaction with the assistant.

    Parameters
    ----------
    llm : LLM
        Language model driving each agent node
    indices : List[str]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False

    Returns
    -------
    StateGraph
        Execution graph
    """

    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)
    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations))
    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
    G.add_node("retrieve", retrieve)
    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
    G.add_node("has_org_name", partial(has_org_name, llm=llm))
    G.add_node("insert_org_link", insert_org_link)

    if enable_recommendations:
        add_recommendations_pipeline_(G, llm=llm, reformulation_node_name="reformulate", search_node_name="search_agent")
    else:
        G.add_edge("reformulate", "search_agent")

    G.add_edge(START, "reformulate")
    # If the search agent issued tool calls, run retrieval; otherwise skip
    # straight to organization-name detection.
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={
            "tools": "retrieve",
            END: "has_org_name",
        },
    )
    G.add_edge("retrieve", "generate_with_context")
    G.add_edge("generate_with_context", "has_org_name")
    G.add_conditional_edges(
        source="has_org_name",
        path=lambda x: x["next"],  # route based on the 'next' key set by the has_org_name node
        path_map={
            "insert_org_link": "insert_org_link",
            END: END
        },
    )
    G.add_edge("insert_org_link", END)
    return G
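

# Minimal usage sketch, not part of the production flow: builds and compiles
# the graph. The index name below is hypothetical, and FakeListLLM is only a
# stand-in; any langchain_core-compatible LLM can be passed instead.
if __name__ == "__main__":
    from langchain_core.language_models import FakeListLLM  # illustrative stand-in model

    graph = build_compute_graph(
        llm=FakeListLLM(responses=["placeholder answer"]),
        indices=["news"],  # hypothetical semantic index name
        enable_recommendations=False,
    )
    app = graph.compile()  # compile the StateGraph into a runnable LangGraph app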