import logging
from functools import partial
from typing import List

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph.state import StateGraph
from langgraph.constants import START, END

from ask_candid.retrieval.elastic import retriever_tool
from ask_candid.tools.recommendation import (
    detect_intent_with_llm,
    determine_context,
    make_recommendation
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.tools.org_seach import has_org_name, insert_org_link
from ask_candid.tools.search import search_agent
from ask_candid.agents.schema import AgentState
from ask_candid.utils import html_format_docs_chat
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
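
# NOTE (inferred from usage below, not an authoritative schema): the imported
# AgentState is expected to carry at least the keys "messages", "user_input",
# "intent", and "next", since the nodes in this module read and write those fields.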

def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
    """Generate the final answer from the retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state
    llm : LLM
        Language model used to generate the answer

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """

    logger.info("---GENERATE ANSWER---")
    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]

    sources_str = last_message.content
    sources_list = last_message.artifact  # cannot use directly as a list of Documents
    # converting to an HTML string for display in the chat UI
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        logger.info("---ADD SOURCES---")
        # carry the rendered sources along as an HTML-typed message
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n
    Question: {question}
    Context: {context}
    Answer:
    """

    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        # Use the template variable rather than the raw question text: literal
        # braces in user input would otherwise be parsed as template fields.
        ("human", "{question}"),
    ])
    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}

def add_recommendations_pipeline_(
    G: StateGraph,
    llm: LLM,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Adds the execution sub-graph for the recommendation engine flow. Graph changes are in-place.

    Parameters
    ----------
    G : StateGraph
        Execution graph
    llm : LLM
        Language model used for intent detection
    reformulation_node_name : str, optional
        Name of the node which reformulates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """
    # Nodes for recommendation functionalities
    G.add_node("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm))
    G.add_node("determine_context", determine_context)
    G.add_node("make_recommendation", make_recommendation)

    # Check for a recommendation query first; if the user asks for a
    # recommendation, execute that branch until reaching END
    G.add_edge(reformulation_node_name, "detect_intent_with_llm")
    G.add_conditional_edges(
        source="detect_intent_with_llm",
        path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
        path_map={
            "determine_context": "determine_context",
            search_node_name: search_node_name
        },
    )
    G.add_edge("determine_context", "make_recommendation")
    G.add_edge("make_recommendation", END)

def build_compute_graph(
    llm: LLM,
    indices: List[str],
    enable_recommendations: bool = False
) -> StateGraph:
    """Execution graph builder; the output is the execution flow for an interaction with the assistant.

    Parameters
    ----------
    llm : LLM
        Language model powering the agent and generation nodes
    indices : List[str]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False

    Returns
    -------
    StateGraph
        Execution graph
    """
    candid_retriever_tool = retriever_tool(indices=indices)
    retrieve = ToolNode([candid_retriever_tool])
    tools = [candid_retriever_tool]

    G = StateGraph(AgentState)

    G.add_node(
        "reformulate",
        partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
    )
    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
    G.add_node("retrieve", retrieve)
    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
    G.add_node("has_org_name", partial(has_org_name, llm=llm))
    G.add_node("insert_org_link", insert_org_link)

    if enable_recommendations:
        add_recommendations_pipeline_(
            G, llm=llm, reformulation_node_name="reformulate", search_node_name="search_agent"
        )
    else:
        G.add_edge("reformulate", "search_agent")

    G.add_edge(START, "reformulate")
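    # tools_condition inspects the last message emitted by the search agent: it
    # routes to "tools" when the model requested a tool call, and to END otherwise.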
    G.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={
            "tools": "retrieve",
            END: "has_org_name",
        },
    )
G.add_edge("retrieve", "generate_with_context") | |
G.add_edge("generate_with_context", "has_org_name") | |
G.add_conditional_edges( | |
source="has_org_name", | |
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict | |
path_map={ | |
"insert_org_link": "insert_org_link", | |
END: END | |
}, | |
) | |
G.add_edge("insert_org_link", END) | |
return G | |
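

# Minimal usage sketch, not part of the original module: it assumes a chat model
# that supports tool binding (langchain_openai.ChatOpenAI is used here purely as
# an example dependency) and a reachable index behind retriever_tool; the index
# name below is a hypothetical placeholder.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI  # assumed dependency for this sketch

    demo_llm = ChatOpenAI(model="gpt-4o-mini")
    graph = build_compute_graph(llm=demo_llm, indices=["candid-index"], enable_recommendations=False)
    app = graph.compile()  # StateGraph.compile() returns a runnable LangGraph app
    result = app.invoke({"messages": [], "user_input": "Which funders support youth programs?"})
    print(result["messages"][-1].content)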