brainsqueeze committed
Commit c751e97 · verified · 1 Parent(s): b73a55e

Feedback and style changes

Simplified style of Candid resource links. Added feedback tab.

app.py CHANGED
@@ -4,7 +4,6 @@ import os
 import gradio as gr
 
 from langchain_core.language_models.llms import LLM
-
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_aws import ChatBedrock
 import boto3
@@ -15,9 +14,12 @@ from ask_candid.base.config.data import ALL_INDICES
 from ask_candid.utils import format_chat_ag_response
 from ask_candid.chat import run_chat
 
+try:
+    from feedback import FeedbackApi
+except ImportError:
+    from demos.feedback import FeedbackApi
+
 ROOT = os.path.dirname(os.path.abspath(__file__))
-BUCKET = "candid-data-science-reporting"
-PREFIX = "Assistant"
 
 class LoggedComponents(TypedDict):
     context: List[gr.components.Component]
@@ -27,6 +29,31 @@ class LoggedComponents(TypedDict):
     email: gr.components.Component
 
 
+def send_feedback(
+    chat_context,
+    found_helpful,
+    will_recommend,
+    comments,
+    email
+):
+    api = FeedbackApi()
+    total_submissions = 0
+
+    try:
+        response = api(
+            context=chat_context,
+            found_helpful=found_helpful,
+            will_recommend=will_recommend,
+            comments=comments,
+            email=email
+        )
+        total_submissions = response.get("response", 0)
+        gr.Info("Thank you for submitting feedback")
+    except Exception as ex:
+        raise gr.Error(f"Error submitting feedback: {ex}")
+    return total_submissions
+
+
 def select_foundation_model(model_name: str, max_new_tokens: int) -> LLM:
     if model_name == "gpt-4o":
         llm = ChatOpenAI(
@@ -92,7 +119,7 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
     )
     llmname = gr.Radio(
         label="Language model",
-        value="gpt-4o",
+        value="claude-3.5-haiku",
         choices=list(Name2Endpoint.keys()),
         interactive=True,
     )
@@ -118,8 +145,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
         type="messages",
         show_label=False,
         show_copy_button=True,
-        show_share_button=True,
-        show_copy_all_button=True,
+        show_share_button=None,
+        show_copy_all_button=False,
     )
     msg = gr.MultimodalTextbox(label="Your message", interactive=True)
     thread_id = gr.Text(visible=False, value="", label="thread_id")
@@ -132,12 +159,54 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
         outputs=[msg, chatbot, thread_id],
     )
     chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")
-    logged = LoggedComponents(context=[thread_id, chatbot])
+    logged = LoggedComponents(context=chatbot)
     return logged, demo
 
 
+def build_feedback(components: LoggedComponents) -> gr.Blocks:
+    with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as demo:
+        gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")
+
+        with gr.Row():
+            with gr.Column():
+                found_helpful = gr.Radio(
+                    [True, False], label="Did you find what you were looking for?"
+                )
+                will_recommend = gr.Radio(
+                    [True, False],
+                    label="Will you recommend this Chatbot to others?",
+                )
+                comment = gr.Textbox(label="Additional comments (optional)", lines=4)
+                email = gr.Textbox(label="Your email (optional)", lines=1)
+                submit = gr.Button("Submit Feedback")
+
+        components["found_helpful"] = found_helpful
+        components["will_recommend"] = will_recommend
+        components["comments"] = comment
+        components["email"] = email
+
+        # pylint: disable=no-member
+        submit.click(
+            fn=send_feedback,
+            inputs=[
+                components["context"],
+                components["found_helpful"],
+                components["will_recommend"],
+                components["comments"],
+                components["email"]
+            ],
+            outputs=None,
+            show_api=False,
+            api_name=False,
+            preprocess=False,
+        )
+
+    return demo
+
+
 def build_app():
-    _, candid_chat = build_rag_chat()
+    logger, candid_chat = build_rag_chat()
+    feedback = build_feedback(logger)
 
     with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
         css_chat = f.read()
@@ -145,10 +214,13 @@ def build_app():
     demo = gr.TabbedInterface(
         interface_list=[
             candid_chat,
+            feedback
         ],
        tab_names=[
-            "AskCandid",
+            "Ask Candid",
+            "Feedback"
         ],
+        title="Ask Candid Assistant",
         theme=gr.themes.Soft(),
         css=css_chat,
     )
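For reference, the feedback tab wired up above only needs the app entry point to be launched. A minimal sketch, assuming `build_app` returns the `gr.TabbedInterface` it constructs (the `queue`/`launch` calls here are illustrative assumptions, not part of this commit):

from app import build_app

if __name__ == "__main__":
    demo = build_app()  # assumed to return the TabbedInterface with the chat and feedback tabs
    demo.queue()  # serialize chat and feedback callbacks through one queue
    demo.launch(show_api=False)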
ask_candid/agents/schema.py ADDED
@@ -0,0 +1,27 @@
+from typing import List, Dict, TypedDict, Sequence, Union, Annotated
+
+from langchain_core.messages import BaseMessage
+from langgraph.graph.message import add_messages
+
+
+class Context(TypedDict):
+    """PCS + geonames context payload for common tasks like recommendations.
+    """
+    subject: List[str]
+    population: List[str]
+    geography: List[Union[str, int]]
+
+
+class AgentState(TypedDict):
+    """State of the chat agent for the execution graph(s).
+    """
+    # The add_messages function defines how an update should be processed
+    # Default is to replace. add_messages says "append"
+    messages: Annotated[Sequence[BaseMessage], add_messages]
+    user_input: str
+    org_dict: Dict
+
+    # Recommendation-specific fields
+    intent: str
+    context: Context
+    recommendation: str
ask_candid/base/config/constants.py ADDED
@@ -0,0 +1,4 @@
+START_SYSTEM_PROMPT = (
+    "You are a Candid subject matter expert on the social sector and philanthropy. "
+    "You should address the user's queries and stay on topic."
+)
ask_candid/chat.py CHANGED
@@ -1,204 +1,12 @@
-from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
-from functools import partial
-import logging
-import os
+from typing import List, Dict, Tuple, Optional, Any
 
 import gradio as gr
-
-from langchain_core.messages import AIMessage, BaseMessage
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.language_models.llms import LLM
-
-from langgraph.prebuilt import tools_condition, ToolNode
 from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph.state import StateGraph
-from langgraph.graph.message import add_messages
-from langgraph.constants import START, END
-
-from ask_candid.tools.org_seach import extract_org_links_from_chatbot, embed_org_links_in_text, generate_org_link_dict
-from ask_candid.tools.question_reformulation import reformulate_question_using_history
-from ask_candid.utils import html_format_docs_chat, get_session_id
-from ask_candid.retrieval.elastic import retriever_tool
-
-ROOT = os.path.dirname(os.path.abspath(__file__))
-logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/
-
-
-class AgentState(TypedDict):
-    # The add_messages function defines how an update should be processed
-    # Default is to replace. add_messages says "append"
-    messages: Annotated[Sequence[BaseMessage], add_messages]
-    user_input: str
-    org_dict: Dict
-
-
-def search_agent(state, llm: LLM, tools) -> AgentState:
-    """Invokes the agent model to generate a response based on the current state. Given
-    the question, it will decide to retrieve using the retriever tool, or simply end.
-
-    Parameters
-    ----------
-    state : _type_
-        The current state
-    llm : LLM
-    tools : _type_
-        _description_
-
-    Returns
-    -------
-    AgentState
-        The updated state with the agent response appended to messages
-    """
-
-    logger.info("---SEARCH AGENT---")
-    messages = state["messages"]
-    question = messages[-1].content
-
-    model = llm.bind_tools(tools)
-    response = model.invoke(messages)
-    # return a list, because this will get added to the existing list
-    return {"messages": [response], "user_input": question}
-
-
-def generate_with_context(state, llm: LLM) -> AgentState:
-    """Generate answer.
-
-    Parameters
-    ----------
-    state : _type_
-        The current state
-    llm : LLM
-    tools : _type_
-        _description_
-
-    Returns
-    -------
-    AgentState
-        The updated state with the agent response appended to messages
-    """
-
-    logger.info("---GENERATE ANSWER---")
-    messages = state["messages"]
-    question = state["user_input"]
-    last_message = messages[-1]
-
-    sources_str = last_message.content
-    sources_list = last_message.artifact  # cannot use directly as list of Documents
-    # converting to html string
-    sources_html = html_format_docs_chat(sources_list)
-    if sources_list:
-        logger.info("---ADD SOURCES---")
-        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
-
-    # Prompt
-    qa_system_prompt = """
-    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
-    Use the following pieces of retrieved context to answer the question at the end. \n
-    If you don't know the answer, just say that you don't know. \n
-    Keep the response professional, friendly, and as concise as possible. \n
-    Question: {question}
-    Context: {context}
-    Answer:
-    """
-
-    qa_prompt = ChatPromptTemplate(
-        [
-            ("system", qa_system_prompt),
-            ("human", question),
-        ]
-    )
 
-    rag_chain = qa_prompt | llm | StrOutputParser()
-    response = rag_chain.invoke({"context": sources_str, "question": question})
-    # couldn't figure out why returning usual "response" was seen as HumanMessage
-    return {"messages": [AIMessage(content=response)], "user_input": question}
-
-
-def has_org_name(state: AgentState) -> AgentState:
-    """
-    Processes the latest message to extract organization links and determine the next step.
-
-    Args:
-        state (AgentState): The current state of the agent, including a list of messages.
-
-    Returns:
-        dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
-    """
-    logger.info("---HAS ORG NAMES?---")
-    messages = state["messages"]
-    last_message = messages[-1].content
-    output_list = extract_org_links_from_chatbot(last_message)
-    link_dict = generate_org_link_dict(output_list) if output_list else {}
-    if link_dict:
-        logger.info("---FOUND ORG NAMES---")
-        return {"next": "insert_org_link", "org_dict": link_dict}
-    logger.info("---NO ORG NAMES FOUND---")
-    return {"next": END, "messages": messages}
-
-
-def insert_org_link(state: AgentState) -> AgentState:
-    """
-    Embeds organization links in the latest message content and returns it as an AI message.
-
-    Args:
-        state (dict): The current state, including the organization links and latest message.
-
-    Returns:
-        dict: A dictionary with the updated message content as an AIMessage.
-    """
-    logger.info("---INSERT ORG LINKS---")
-    messages = state["messages"]
-    last_message = messages[-1].content
-    messages.pop(-1)  # Deleting the original message because we will append the same one but with links
-    link_dict = state["org_dict"]
-    last_message = embed_org_links_in_text(last_message, link_dict)
-    return {"messages": [AIMessage(content=last_message)]}
-
-
-def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
-    candid_retriever_tool = retriever_tool(indices=indices)
-    retrieve = ToolNode([candid_retriever_tool])
-    tools = [candid_retriever_tool]
-
-    G = StateGraph(AgentState)
-    # Add nodes
-    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
-    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
-    G.add_node("retrieve", retrieve)
-    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
-    G.add_node("has_org_name", has_org_name)
-    G.add_node("insert_org_link", insert_org_link)
-
-    # Add edges
-    G.add_edge(START, "reformulate")
-    G.add_edge("reformulate", "search_agent")
-    # Conditional edges from search_agent
-    G.add_conditional_edges(
-        source="search_agent",
-        path=tools_condition,
-        path_map={
-            "tools": "retrieve",
-            END: "has_org_name",
-        },
-    )
-    G.add_edge("retrieve", "generate_with_context")
-
-    # Add edges
-    G.add_edge("generate_with_context", "has_org_name")
-    # Use add_conditional_edges for has_org_name
-    G.add_conditional_edges(
-        "has_org_name",
-        lambda x: x["next"],  # Now we're accessing the 'next' key from the dict
-        {"insert_org_link": "insert_org_link", END: END},
-    )
-    G.add_edge("insert_org_link", END)
-
-    return G
+from ask_candid.utils import get_session_id
+from ask_candid.graph import build_compute_graph
+from ask_candid.base.config.constants import START_SYSTEM_PROMPT
 
 
 def run_chat(
@@ -207,17 +15,10 @@ def run_chat(
     history: List[Dict],
     llm: LLM,
     indices: Optional[List[str]] = None,
-):
-    # https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
-
+    enable_recommendations: bool = False
+) -> Tuple[gr.MultimodalTextbox, List[Dict[str, Any]], str]:
     if len(history) == 0:
-        history.append({
-            "role": "system",
-            "content": (
-                "You are a Candid subject matter expert on the social sector and philanthropy. "
-                "You should address the user's queries and stay on topic."
-            )
-        })
+        history.append({"role": "system", "content": START_SYSTEM_PROMPT})
 
     history.append({"role": "user", "content": user_input["text"]})
     inputs = {"messages": history}
@@ -225,14 +26,22 @@
     thread_id = get_session_id(thread_id)
     config = {"configurable": {"thread_id": thread_id}}
 
-    workflow = build_compute_graph(llm=llm, indices=indices)
+    workflow = build_compute_graph(llm=llm, indices=indices, enable_recommendations=enable_recommendations)
 
     memory = MemorySaver()  # TODO: don't use for Prod
     graph = workflow.compile(checkpointer=memory)
     response = graph.invoke(inputs, config=config)
     messages = response["messages"]
-    last_message = messages[-1]
-    ai_answer = last_message.content
+
+    # Return the recommendation if there is any
+    recommendation = response.get("recommendation", None)
+    if recommendation:
+        ai_answer = recommendation
+    else:
+        # Fallback to the chatbot response
+        last_message = messages[-1]
+        ai_answer = last_message.content
+
     sources_html = ""
     for message in messages[-2:]:
         if message.type == "HTML":
@@ -240,12 +49,10 @@
 
     history.append({"role": "assistant", "content": ai_answer})
     if sources_html:
-        history.append(
-            {
-                "role": "assistant",
-                "content": sources_html,
-                "metadata": {"title": "Sources HTML"},
-            }
-        )
+        history.append({
+            "role": "assistant",
+            "content": sources_html,
+            "metadata": {"title": "Sources HTML"},
+        })
 
     return gr.MultimodalTextbox(value=None, interactive=True), history, thread_id
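Since the chatbot is configured with type="messages", run_chat consumes and returns OpenAI-style role/content dicts. A sketch of the history shape it builds, with invented message text:

from ask_candid.base.config.constants import START_SYSTEM_PROMPT

history = [
    {"role": "system", "content": START_SYSTEM_PROMPT},
    {"role": "user", "content": "What is a donor-advised fund?"},
    {"role": "assistant", "content": "A donor-advised fund is..."},
    # When sources were retrieved, an extra assistant turn carries the HTML panel:
    {"role": "assistant", "content": "<div>...sources HTML...</div>", "metadata": {"title": "Sources HTML"}},
]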
ask_candid/graph.py ADDED
@@ -0,0 +1,177 @@
+from typing import List
+from functools import partial
+import logging
+
+from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.language_models.llms import LLM
+
+from langgraph.prebuilt import tools_condition, ToolNode
+from langgraph.graph.state import StateGraph
+from langgraph.constants import START, END
+
+from ask_candid.retrieval.elastic import retriever_tool
+from ask_candid.tools.recommendation import (
+    detect_intent_with_llm,
+    determine_context,
+    make_recommendation
+)
+from ask_candid.tools.question_reformulation import reformulate_question_using_history
+from ask_candid.tools.org_seach import has_org_name, insert_org_link
+from ask_candid.tools.search import search_agent
+from ask_candid.agents.schema import AgentState
+
+from ask_candid.utils import html_format_docs_chat
+
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
+    """Generate answer.
+
+    Parameters
+    ----------
+    state : AgentState
+        The current state
+    llm : LLM
+
+    Returns
+    -------
+    AgentState
+        The updated state with the agent response appended to messages
+    """
+
+    logger.info("---GENERATE ANSWER---")
+    messages = state["messages"]
+    question = state["user_input"]
+    last_message = messages[-1]
+
+    sources_str = last_message.content
+    sources_list = last_message.artifact  # cannot use directly as list of Documents
+    # converting to html string
+    sources_html = html_format_docs_chat(sources_list)
+    if sources_list:
+        logger.info("---ADD SOURCES---")
+        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
+
+    # Prompt
+    qa_system_prompt = """
+    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
+    Use the following pieces of retrieved context to answer the question at the end. \n
+    If you don't know the answer, just say that you don't know. \n
+    Keep the response professional, friendly, and as concise as possible. \n
+    Question: {question}
+    Context: {context}
+    Answer:
+    """
+
+    qa_prompt = ChatPromptTemplate([
+        ("system", qa_system_prompt),
+        ("human", question),
+    ])
+
+    rag_chain = qa_prompt | llm | StrOutputParser()
+    response = rag_chain.invoke({"context": sources_str, "question": question})
+    return {"messages": [AIMessage(content=response)], "user_input": question}
+
+
+def add_recommendations_pipeline_(
+    G: StateGraph,
+    reformulation_node_name: str = "reformulate",
+    search_node_name: str = "search_agent"
+) -> None:
+    """Adds execution sub-graph for recommendation engine flow. Graph changes are in-place.
+
+    Parameters
+    ----------
+    G : StateGraph
+        Execution graph
+    reformulation_node_name : str, optional
+        Name of the node which reformulates input queries, by default "reformulate"
+    search_node_name : str, optional
+        Name of the node which executes document search + retrieval, by default "search_agent"
+    """
+
+    # Nodes for recommendation functionalities
+    G.add_node("detect_intent_with_llm", detect_intent_with_llm)
+    G.add_node("determine_context", determine_context)
+    G.add_node("make_recommendation", make_recommendation)
+
+    # Check for recommendation query first
+    # Execute until reaching END if user asks for recommendation
+    G.add_edge(reformulation_node_name, "detect_intent_with_llm")
+    G.add_conditional_edges(
+        source="detect_intent_with_llm",
+        path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
+        path_map={
+            "determine_context": "determine_context",
+            search_node_name: search_node_name
+        },
+    )
+    G.add_edge("determine_context", "make_recommendation")
+    G.add_edge("make_recommendation", END)
+
+
+def build_compute_graph(
+    llm: LLM,
+    indices: List[str],
+    enable_recommendations: bool = False
+) -> StateGraph:
+    """Execution graph builder, the output is the execution flow for an interaction with the assistant.
+
+    Parameters
+    ----------
+    llm : LLM
+    indices : List[str]
+        Semantic index names to search over
+    enable_recommendations : bool, optional
+        Set to `True` to allow the flow to generate recommendations based on context, by default False
+
+    Returns
+    -------
+    StateGraph
+        Execution graph
+    """
+
+    candid_retriever_tool = retriever_tool(indices=indices)
+    retrieve = ToolNode([candid_retriever_tool])
+    tools = [candid_retriever_tool]
+
+    G = StateGraph(AgentState)
+
+    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
+    G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
+    G.add_node("retrieve", retrieve)
+    G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
+    G.add_node("has_org_name", partial(has_org_name, llm=llm))
+    G.add_node("insert_org_link", insert_org_link)
+
+    if enable_recommendations:
+        add_recommendations_pipeline_(G, reformulation_node_name="reformulate", search_node_name="search_agent")
+    else:
+        G.add_edge("reformulate", "search_agent")
+
+    G.add_edge(START, "reformulate")
+    G.add_conditional_edges(
+        source="search_agent",
+        path=tools_condition,
+        path_map={
+            "tools": "retrieve",
+            END: "has_org_name",
+        },
+    )
+    G.add_edge("retrieve", "generate_with_context")
+    G.add_edge("generate_with_context", "has_org_name")
+    G.add_conditional_edges(
+        source="has_org_name",
+        path=lambda x: x["next"],  # Now we're accessing the 'next' key from the dict
+        path_map={
+            "insert_org_link": "insert_org_link",
+            END: END
+        },
+    )
+    G.add_edge("insert_org_link", END)
+    return G
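A minimal sketch of exercising the new graph builder outside of Gradio; the model choice and index name are placeholder assumptions, and MemorySaver mirrors the in-memory checkpointing run_chat uses:

from langchain_openai.chat_models import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver

from ask_candid.graph import build_compute_graph

llm = ChatOpenAI(model="gpt-4o")  # placeholder; any chat model that supports .bind_tools
G = build_compute_graph(llm=llm, indices=["news"], enable_recommendations=True)

graph = G.compile(checkpointer=MemorySaver())  # in-memory checkpoints; not for production
state = graph.invoke(
    {"messages": [{"role": "user", "content": "Recommend funders for youth programs in Ohio"}]},
    config={"configurable": {"thread_id": "demo-thread"}},
)
print(state.get("recommendation") or state["messages"][-1].content)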
ask_candid/retrieval/elastic.py CHANGED
@@ -1,14 +1,17 @@
-from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Any
+from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
 from dataclasses import dataclass
 from functools import partial
 from itertools import groupby
 
+from torch.nn import functional as F
+
 from pydantic import BaseModel, Field
 from langchain_core.documents import Document
 from langchain_core.tools import Tool
 
 from elasticsearch import Elasticsearch
 
+from ask_candid.services.small_lm import CandidSLM
 from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
 from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES
 
@@ -34,6 +37,21 @@ def build_text_expansion_query(
     fields: Tuple[str],
     model_id: str = ".elser_model_2_linux-x86_64"
 ) -> Dict[str, Any]:
+    """Builds a valid Elasticsearch text expansion query payload
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+    fields : Tuple[str]
+        Semantic text field names
+    model_id : str, optional
+        ID of model deployed in Elasticsearch, by default ".elser_model_2_linux-x86_64"
+
+    Returns
+    -------
+    Dict[str, Any]
+    """
 
     output = []
 
@@ -60,7 +78,21 @@ def build_text_expansion_query(
     return {"query": {"bool": {"should": output}}}
 
 
-def query_builder(query: str, indices: List[str]):
+def query_builder(query: str, indices: List[str]) -> List[Dict[str, Any]]:
+    """Builds Elasticsearch multi-search query payload
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+    indices : List[str]
+        Semantic index names to search over
+
+    Returns
+    -------
+    List[Dict[str, Any]]
+    """
+
     queries = []
     if indices is None:
         indices = list(ALL_INDICES)
@@ -111,7 +143,19 @@ def query_builder(query: str, indices: List[str]):
     return queries
 
 
-def multi_search(queries: List[ElasticHitsResult]):
+def multi_search(queries: List[Dict[str, Any]]) -> List[ElasticHitsResult]:
+    """Runs multi-search query
+
+    Parameters
+    ----------
+    queries : List[Dict[str, Any]]
+        Pre-built multi-search query payload
+
+    Returns
+    -------
+    List[ElasticHitsResult]
+    """
+
     results = []
     with Elasticsearch(
         cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
@@ -132,13 +176,77 @@ def multi_search(queries: List[ElasticHitsResult]):
     return results
 
 
-def get_query_results(search_text: str, indices: Optional[List[str]] = None):
+def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
+    """Builds and executes Elasticsearch data queries from a search string.
+
+    Parameters
+    ----------
+    search_text : str
+        Search context string
+    indices : Optional[List[str]], optional
+        Semantic index names to search over, by default None
+
+    Returns
+    -------
+    List[ElasticHitsResult]
+    """
+
     queries = query_builder(query=search_text, indices=indices)
     return multi_search(queries)
 
 
+def retrieved_text(hits: Dict[str, Any]) -> str:
+    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
+    re-scoring by a secondary language model.
+
+    Parameters
+    ----------
+    hits : Dict[str, Any]
+
+    Returns
+    -------
+    str
+    """
+
+    text = []
+    for _, v in hits.items():
+        for h in (v.get("hits", {}).get("hits") or []):
+            for _, field in h.get("fields", {}).items():
+                for chunk in field:
+                    if chunk.get("chunk"):
+                        text.extend(chunk["chunk"])
+    return '\n'.join(text)
+
+
+def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
+    """Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
+    relevance to the original query.
+
+    Parameters
+    ----------
+    query : str
+        Search context string
+    contexts : List[str]
+        Semantic field sub-texts, order is by document retrieved from the original multi-search query.
+
+    Returns
+    -------
+    List[float]
+        Scores in the same order as the input document contexts
+    """
+
+    nlp = CandidSLM()
+    X = nlp.encode([query, *contexts]).vectors
+    X = F.normalize(X, dim=-1, p=2.)
+    cosine = X[1:] @ X[:1].T
+    return cosine.flatten().cpu().numpy().tolist()
+
+
-def reranker(query_results: Iterable[ElasticHitsResult]) -> Iterator[ElasticHitsResult]:
-    """Reranks Elasticsearch hits coming from multiple indicies/queries which may have scores on different scales.
+def reranker(
+    query_results: Iterable[ElasticHitsResult],
+    search_text: Optional[str] = None
+) -> Iterator[ElasticHitsResult]:
+    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
     This will shuffle results
 
     Parameters
@@ -151,6 +259,7 @@ def reranker(query_results: Iterable[ElasticHitsResult]) -> Iterator[ElasticHits
     """
 
     results: List[ElasticHitsResult] = []
+    texts: List[str] = []
     for _, data in groupby(query_results, key=lambda x: x.index):
         data = list(data)
         max_score = max(data, key=lambda x: x.score).score
@@ -160,19 +269,44 @@ def reranker(query_results: Iterable[ElasticHitsResult]) -> Iterator[ElasticHits
             d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
             results.append(d)
 
+            if search_text:
+                text = retrieved_text(d.inner_hits)
+                texts.append(text)
+
+    # if search_text and len(texts) == len(results):
+    #     scores = cosine_rescore(search_text, texts)
+    #     for r, s in zip(results, scores):
+    #         r.score = s
+
     yield from sorted(results, key=lambda x: x.score, reverse=True)
 
 
-def get_results(user_input: str, indices: List[str]) -> List[ElasticHitsResult]:
+def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
+    """End-to-end search and re-rank function.
+
+    Parameters
+    ----------
+    user_input : str
+        Search context string
+    indices : List[str]
+        Semantic index names to search over
+
+    Returns
+    -------
+    Tuple[str, List[Document]]
+        (concatenated text from search results, documents list)
+    """
+
     output = ["Search didn't return any Candid sources"]
-    page_content=[]
+    page_content = []
     content = "Search didn't return any Candid sources"
     results = get_query_results(search_text=user_input, indices=indices)
     if results:
-        output = get_reranked_results(results)
+        output = get_reranked_results(results, search_text=user_input)
         for doc in output:
             page_content.append(doc.page_content)
-        content = "/n/n".join(page_content)
+        content = "\n\n".join(page_content)
+
     # for the tool we need to return a tuple for content_and_artifact type
     return content, output
 
@@ -197,27 +331,38 @@ def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1
     chunks_with_context = []
     long_text = hit.source.get(f"{field_name}", "")
     inner_hits_field = f"embeddings.{field_name}.chunks"
-    inner_hits = hit.inner_hits
-    found_chunks = inner_hits.get(inner_hits_field, {})
+    found_chunks = hit.inner_hits.get(inner_hits_field, {})
     if found_chunks:
         hits = found_chunks.get("hits", {}).get("hits", [])
         for h in hits:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
-            chunk = chunk[3:-3]  # cutting the middle because we may have tokenizing artefacts there
+
+            # cutting the middle because we may have tokenizing artifacts there
+            chunk = chunk[3: -3]
+
             # Find the start and end indices of the chunk in the large text
             start_index = long_text.find(chunk)
             if start_index != -1:  # Chunk is found
                 end_index = start_index + len(chunk)
                 pre_start_index = max(0, start_index - context_length)
                 post_end_index = min(len(long_text), end_index + context_length)
-                context = long_text[pre_start_index:post_end_index]
-                chunks_with_context.append(context)
-    chunks_with_context_txt = '\n\n'.join(chunks_with_context)
+                chunks_with_context.append(long_text[pre_start_index:post_end_index])
 
-    return chunks_with_context_txt
+    return '\n\n'.join(chunks_with_context)
 
 
-def process_hit(hit: ElasticHitsResult) -> Document | None:
+def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
+    """Parse Elasticsearch hit results into data structures handled by the RAG pipeline.
+
+    Parameters
+    ----------
+    hit : ElasticHitsResult
+
+    Returns
+    -------
+    Union[Document, None]
+    """
+
     if "issuelab-elser" in hit.index:
         combined_item_description = hit.source.get("combined_item_description", "")  # title inside
         description = hit.source.get("description", "")
@@ -300,17 +445,43 @@ def process_hit(hit: ElasticHitsResult) -> Document | None:
     return doc
 
 
-def get_reranked_results(results: List[ElasticHitsResult]) -> List[Document]:
+def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
+    """Run data re-ranking and document building for tool usage.
+
+    Parameters
+    ----------
+    results : List[ElasticHitsResult]
+    search_text : Optional[str], optional
+        Search context string, by default None
+
+    Returns
+    -------
+    List[Document]
+    """
+
     output = []
-    for r in reranker(results):
+    for r in reranker(results, search_text=search_text):
         hit = process_hit(r)
-        output.append(hit)
+        if hit is not None:
+            output.append(hit)
     return output
 
 
 def retriever_tool(indices: List[str]) -> Tool:
-    # cannot use create_retriever_tool because it only provides content losing all metadata on the way
-    # https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
+    """Tool component for use in conditional edge building for RAG execution graph.
+    Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
+    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
+
+    Parameters
+    ----------
+    indices : List[str]
+        Semantic index names to search over
+
+    Returns
+    -------
+    Tool
+    """
+
     return Tool(
         name="retrieve_social_sector_information",
         func=partial(get_results, indices=indices),
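The reranker above min-max normalizes scores within each index before merging, since hits from different indices are scored on different scales. A standalone sketch of that normalization with made-up scores:

from itertools import groupby

hits = [  # (index, raw score) pairs, grouped by index; values invented for illustration
    ("issuelab-elser", 12.0), ("issuelab-elser", 7.5),
    ("news-elser", 3.1), ("news-elser", 1.2),
]

normalized = []
for _, group in groupby(hits, key=lambda h: h[0]):
    group = list(group)
    scores = [s for _, s in group]
    lo, hi = min(scores), max(scores)
    for name, s in group:
        normalized.append((name, (s - lo) / (hi - lo + 1e-9)))  # epsilon avoids division by zero

print(sorted(normalized, key=lambda h: h[1], reverse=True))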
ask_candid/services/small_lm.py CHANGED
@@ -15,6 +15,7 @@ class Encoding:
 
 class CandidSLM(LambdaInvokeBase):
     """Wrapper around Candid's custom small language model.
+    For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.
     This services includes:
         * text encoding
         * document summarization
@@ -43,7 +44,7 @@ class CandidSLM(LambdaInvokeBase):
         )
 
     def encode(self, text: List[str]) -> Encoding:
-        response = self._submit_request({"text": text})
+        response = self._submit_request({"text": text, "path": self.Tasks.ENCODE.value})
 
         output = Encoding(
             inputs=(response.get("inputs") or []),
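The corrected encode call is what cosine_rescore in ask_candid/retrieval/elastic.py relies on. A sketch of that usage; note that CandidSLM invokes an internal Candid Lambda, so this only runs with Candid AWS credentials:

from torch.nn import functional as F

from ask_candid.services.small_lm import CandidSLM

nlp = CandidSLM()
X = nlp.encode(["query text", "candidate context"]).vectors  # one row per input string
X = F.normalize(X, dim=-1, p=2.0)
score = (X[1] @ X[0]).item()  # cosine similarity of context vs. query, in [-1, 1]
print(score)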
ask_candid/tools/org_seach.py CHANGED
@@ -1,18 +1,26 @@
 from typing import List
+import logging
 import re
 
 from fuzzywuzzy import fuzz
 
 from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-from langchain_openai.chat_models import ChatOpenAI
+# from langchain_openai.chat_models import ChatOpenAI
 from langchain_core.runnables import RunnableSequence
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.language_models.llms import LLM
+from langchain_core.messages import AIMessage
+from langgraph.constants import END
 from pydantic import BaseModel, Field
 
+from ask_candid.agents.schema import AgentState
 from ask_candid.services.org_search import OrgSearch
-from ask_candid.base.config.rest import OPENAI
+# from ask_candid.base.config.rest import OPENAI
 
 search = OrgSearch()
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 class OrganizationNames(BaseModel):
@@ -20,7 +28,7 @@ class OrganizationNames(BaseModel):
     orgnames: List[str] = Field(description="List of organization names")
 
 
-def extract_org_links_from_chatbot(chatbot_output: str):
+def extract_org_links_from_chatbot(chatbot_output: str, llm: LLM):
     """
     Extracts a list of organization names from the provided text.
 
@@ -51,9 +59,10 @@ def extract_org_links_from_chatbot(chatbot_output: str, llm: LLM):
 
     try:
        parser = JsonOutputToolsParser()
-        llm = ChatOpenAI(model="gpt-4o", api_key=OPENAI["key"]).bind_tools([OrganizationNames])
+        # llm = ChatOpenAI(model="gpt-4o", api_key=OPENAI["key"]).bind_tools([OrganizationNames])
+        model = llm.bind_tools([OrganizationNames])
         prompt = ChatPromptTemplate.from_template(prompt)
-        chain = RunnableSequence(prompt, llm, parser)
+        chain = RunnableSequence(prompt, model, parser)
 
         # Run the chain with the input data
         result = chain.invoke({"chatbot_output": chatbot_output})
@@ -192,3 +201,44 @@ def embed_org_links_in_text(input_text: str, org_link_dict: dict):
         return input_text
 
     return input_text
+
+
+def has_org_name(state: AgentState, llm: LLM) -> AgentState:
+    """
+    Processes the latest message to extract organization links and determine the next step.
+
+    Args:
+        state (AgentState): The current state of the agent, including a list of messages.
+
+    Returns:
+        dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
+    """
+    logger.info("---HAS ORG NAMES?---")
+    messages = state["messages"]
+    last_message = messages[-1].content
+    output_list = extract_org_links_from_chatbot(last_message, llm=llm)
+    link_dict = generate_org_link_dict(output_list) if output_list else {}
+    if link_dict:
+        logger.info("---FOUND ORG NAMES---")
+        return {"next": "insert_org_link", "org_dict": link_dict}
+    logger.info("---NO ORG NAMES FOUND---")
+    return {"next": END, "messages": messages}
+
+
+def insert_org_link(state: AgentState) -> AgentState:
+    """
+    Embeds organization links in the latest message content and returns it as an AI message.
+
+    Args:
+        state (dict): The current state, including the organization links and latest message.
+
+    Returns:
+        dict: A dictionary with the updated message content as an AIMessage.
+    """
+    logger.info("---INSERT ORG LINKS---")
+    messages = state["messages"]
+    last_message = messages[-1].content
+    messages.pop(-1)  # Deleting the original message because we will append the same one but with links
+    link_dict = state["org_dict"]
+    last_message = embed_org_links_in_text(last_message, link_dict)
+    return {"messages": [AIMessage(content=last_message)]}
ask_candid/tools/recommendation.py ADDED
@@ -0,0 +1,189 @@
+import logging
+import os
+
+from openai import OpenAI
+import requests
+
+from ask_candid.agents.schema import AgentState, Context
+from ask_candid.base.config.rest import OPENAI
+
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def detect_intent_with_llm(state: AgentState) -> AgentState:
+    """Detect query intent (which type of recommendation) and update the state."""
+    logger.info("---DETECT INTENT---")
+    client = OpenAI(api_key=OPENAI['key'])
+    # query = state["messages"][-1]["content"]
+    query = state["messages"][-1].content
+    prompt = f"""Classify the following query as one of the following categories:
+    - 'none': The query is not asking for funding recommendations.
+    - 'funder': The query is asking for recommendations about funders, such as foundations or donors, who might provide longer-term or general funding.
+    - 'rfp': The query is asking for recommendations about specific, active Requests for Proposals (RFPs) that typically focus on short-term, recent opportunities with a deadline.
+
+    Your classification should consider:
+    1. RFPs often focus on active and time-bound opportunities for specific projects or programs.
+    2. Funders refer to broader, long-term funding sources like organizations or individuals who offer grants.
+
+    Query: "{query}"
+    """
+
+    response = client.chat.completions.create(
+        # model="gpt-4-turbo",
+        model="gpt-4o",
+        messages=[{"role": "system", "content": prompt}],
+        max_tokens=10,
+        stop=["\n"]
+    )
+
+    intent = response.choices[0].message.content.strip().lower()
+    state["intent"] = intent.strip("'").strip('"')  # Remove extra quotes
+    return state
+
+
+def determine_context(state: AgentState) -> AgentState:
+    """Extract subject/geography/population codes and update the state."""
+    logger.info("---GETTING RECOMMENDATION CONTEXT---")
+    query = state["messages"][-1].content
+
+    subject_codes, population_codes, geo_ids = [], [], []
+
+    # subject and population
+    autocoding_headers = {
+        'x-api-key': os.getenv("AUTOCODING_API_KEY"),
+        'Content-Type': 'application/json'
+    }
+    autocoding_params = {
+        'text': query,
+        'taxonomy': 'pcs-v3'
+    }
+    autocoding_response = requests.get(
+        os.getenv("AUTOCODING_API_URL"),
+        headers=autocoding_headers,
+        params=autocoding_params,
+        timeout=30
+    )
+    if autocoding_response.status_code == 200:
+        returned_pcs = autocoding_response.json()["data"]
+        population_codes = [item['full_code'] for item in returned_pcs.get("population", [])]
+        subject_codes = [item['full_code'] for item in returned_pcs.get("subject", [])]
+
+    # geography
+    geo_headers = {
+        'x-api-key': os.getenv("GEO_API_KEY"),
+        'Content-Type': 'application/json'
+    }
+    geo_data = {
+        'text': query
+    }
+    geo_response = requests.post(os.getenv("GEO_API_URL"), headers=geo_headers, json=geo_data, timeout=30)
+    if geo_response.status_code == 200:
+        entities = geo_response.json()['data']['entities']
+        geo_ids = [entity['geo']['id'] for entity in entities if 'id' in entity['geo']]
+
+    state["context"] = Context(
+        subject=subject_codes,
+        population=population_codes,
+        geography=geo_ids
+    )
+    return state
+
+
+def make_recommendation(state: AgentState) -> AgentState:
+    """Make an API call based on the extracted context and update the state."""
+    # query = state["messages"][-1]["content"]
+    logger.info("---RECOMMENDING---")
+    org_id = "6908122"
+    funder_or_rfp = state["intent"]
+
+    # Extract context
+    contexts = state["context"]
+    subject_codes = contexts.get("subject", [])
+    population_codes = contexts.get("population", [])
+    geo_ids = contexts.get("geography", [])
+
+    # Prepare parameters
+    params = {
+        "subjects": ",".join(subject_codes),
+        "geos": ",".join([str(geo) for geo in geo_ids]),
+        "populations": ",".join(population_codes)
+    }
+    headers = {"x-api-key": os.getenv("FUNDER_REC_API_KEY")}
+    base_url = os.getenv("FUNDER_REC_API_URL")
+
+    # Initialize response
+    response = None
+
+    recommendation_display_text = ""
+
+    try:
+        # Make the API call based on intent
+        if funder_or_rfp == "funder":
+            response = requests.get(base_url, headers=headers, params=params, timeout=30)
+        elif funder_or_rfp == "rfp":
+            params["candid_entity_id"] = org_id  # placeholder
+            response = requests.get(f"{base_url}/rfp", headers=headers, params=params, timeout=30)
+        else:
+            # Handle unknown intent
+            state["recommendation"] = "Unknown intent. Intent 'funder' or 'rfp' expected."
+            return state
+
+        # Validate response
+        if response and response.status_code == 200:
+            recommendations = response.json().get("recommendations", [])
+            if recommendations:
+                if funder_or_rfp == "funder":
+                    # Format recommendations
+                    recommendation_display_text = "Here are the top 10 recommendations. Click their profiles to learn more:\n" + "\n".join([
+                        f"{recommendation['funder_data']['main_sort_name']} - Profile: https://app.candid.org/profile/{recommendation['funder_id']}"
+                        for recommendation in recommendations
+                    ])
+                elif funder_or_rfp == "rfp":
+                    recommendation_display_text = "Here are the top recommendations:\n" + "\n".join([
+                        f"Title: {rec['title']}\n"
+                        f"Funder: {rec['funder_name']}\n"
+                        f"Amount: {rec.get('amount', 'Not specified')}\n"
+                        f"Description: {rec.get('description', 'No description available')}\n"
+                        f"Deadline: {rec.get('deadline', 'No deadline provided')}\n"
+                        f"Application URL: {rec.get('application_url', 'No URL available')}\n"
+                        for rec in recommendations
+                    ])
+            else:
+                # No recommendations found
+                recommendation_display_text = "No recommendations were found for your query. Please try refining your search criteria."
+        elif response and response.status_code == 400:
+            # Handle bad request
+            error_details = response.json()
+            recommendation_display_text = (
+                "An error occurred while processing your request. "
+                f"Details: {error_details.get('message', 'Unknown error.')}"
+            )
+        elif response:
+            # Handle other unexpected status codes
+            recommendation_display_text = (
+                f"An unexpected error occurred (Status Code: {response.status_code}). "
+                "Please try again later or contact support if the problem persists."
+            )
+        else:
+            # Handle case where response is None
+            recommendation_display_text = "No response from the server. Please check your connection or try again later."
+
+    except requests.exceptions.RequestException as e:
+        # Handle network-related errors
+        recommendation_display_text = (
+            "A network error occurred while trying to connect to the recommendation service. "
+            f"Details: {str(e)}"
+        )
+    except Exception as e:
+        # Handle other unexpected errors
+        print(params)
+        recommendation_display_text = (
+            "An unexpected error occurred while processing your request. "
+            f"Details: {str(e)}"
+        )
+
+    # Update state with recommendations or error messages
+    state["recommendation"] = recommendation_display_text
+    return state
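The recommendation tools read their service endpoints and keys from environment variables. A sketch of the expected configuration; the URLs are placeholders, since the real endpoints are internal to Candid:

import os

os.environ.setdefault("AUTOCODING_API_URL", "https://example.org/autocoding")  # PCS subject/population codes
os.environ.setdefault("AUTOCODING_API_KEY", "<key>")
os.environ.setdefault("GEO_API_URL", "https://example.org/geo")  # geonames entity extraction
os.environ.setdefault("GEO_API_KEY", "<key>")
os.environ.setdefault("FUNDER_REC_API_URL", "https://example.org/recommend")  # funder/RFP recommendations
os.environ.setdefault("FUNDER_REC_API_KEY", "<key>")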
ask_candid/tools/search.py ADDED
@@ -0,0 +1,38 @@
+from typing import List
+import logging
+
+from langchain_core.language_models.llms import LLM
+from langchain_core.tools import Tool
+
+from ask_candid.agents.schema import AgentState
+
+logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def search_agent(state, llm: LLM, tools: List[Tool]) -> AgentState:
+    """Invokes the agent model to generate a response based on the current state. Given
+    the question, it will decide to retrieve using the retriever tool, or simply end.
+
+    Parameters
+    ----------
+    state : _type_
+        The current state
+    llm : LLM
+    tools : List[Tool]
+
+    Returns
+    -------
+    AgentState
+        The updated state with the agent response appended to messages
+    """
+
+    logger.info("---SEARCH AGENT---")
+    messages = state["messages"]
+    question = messages[-1].content
+
+    model = llm.bind_tools(tools)
+    response = model.invoke(messages)
+    # return a list, because this will get added to the existing list
+    return {"messages": [response], "user_input": question}
static/chatStyle.css CHANGED
@@ -18,21 +18,22 @@
 }
 
 .source-item {
-    display: block;
+    /* display: block;
     min-width: 100px;
-    max-width: 45%;
+    max-width: 80%;
+    margin: 0 auto;
     height: fit-content;
     background-color: #febe10;
     color: black;
     padding: 5px;
     text-align: center;
     border-radius: 5px;
-    box-shadow: 0 2px 5px 0 rgba(0, 0, 0,0.2);
+    box-shadow: 0 2px 5px 0 rgba(0, 0, 0,0.2); */
     font-size: medium;
 }
 
 .ssearch-source {
-    text-decoration: none;
+    /* text-decoration: none; */
     box-sizing: border-box;
 }
 
@@ -42,7 +43,7 @@ button.upload-button.svelte-1d7elt4 {
 
 .candid-org-link {
     font-weight: bold;
-    text-decoration: none;
+    /* text-decoration: none; */
 }
 
 .candid-app-link {
static/elastic_agent_worflow.jpeg ADDED