Smarter document context retrieval
* Retrieved documents re-ranking w/ SPLADE-v3
* Enable news as a default source
- app.py +2 -1
- ask_candid/agents/elastic.py +246 -54
- ask_candid/retrieval/elastic.py +39 -160
- ask_candid/retrieval/sources/candid_blog.py +22 -1
- ask_candid/retrieval/sources/candid_help.py +20 -1
- ask_candid/retrieval/sources/candid_learning.py +22 -1
- ask_candid/retrieval/sources/candid_news.py +14 -1
- ask_candid/retrieval/sources/issuelab.py +27 -2
- ask_candid/retrieval/sources/schema.py +12 -1
- ask_candid/retrieval/sources/utils.py +47 -0
- ask_candid/retrieval/sources/youtube.py +20 -1
- ask_candid/retrieval/sparse_lexical.py +14 -4
- ask_candid/tools/elastic/index_search_tool.py +9 -2
- ask_candid/tools/question_reformulation.py +43 -39
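
The headline change is SPLADE-based re-ranking of retrieved documents, wired into `SpladeEncoder` (see `ask_candid/retrieval/sparse_lexical.py` below). The core computation, as a minimal standalone sketch; the checkpoint id is an assumption, since the diff does not show which model `SpladeEncoder` loads:

# Minimal sketch of the SPLADE re-ranking step (checkpoint id is an assumption).
from typing import List

import torch
from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_ID = "naver/splade-v3"  # assumption; any SPLADE-family MLM checkpoint works

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)

@torch.no_grad()
def splade_vectors(texts: List[str]) -> torch.Tensor:
    tokens = tokenizer(texts, return_tensors="pt", truncation=True, padding=True)
    logits = model(**tokens).logits
    # Sparse lexical vectors: max-pool log-saturated ReLU activations over the
    # sequence dimension, masking out padding positions.
    return torch.max(
        torch.log1p(torch.relu(logits)) * tokens.attention_mask.unsqueeze(-1),
        dim=1,
    )[0]

def rerank_scores(query: str, documents: List[str]) -> List[float]:
    vec = splade_vectors([query, *documents])
    xq = F.normalize(vec[:1], p=2.0, dim=-1)
    xd = F.normalize(vec[1:], p=2.0, dim=-1)
    return (xq * xd).sum(dim=-1).tolist()  # cosine similarity per document

Each text becomes a sparse vocabulary-sized vector; cosine similarity between the query vector and each document vector gives the re-ranking score used by `reranker` in `retrieval/elastic.py`.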
app.py
CHANGED
@@ -113,7 +113,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
     with gr.Accordion(label="Advanced settings", open=False):
         es_indices = gr.CheckboxGroup(
             choices=list(ALL_INDICES),
-            value=[idx for idx in ALL_INDICES if "news" not in idx],
+            # value=[idx for idx in ALL_INDICES if "news" not in idx],
+            value=list(ALL_INDICES),
             label="Sources to include",
             interactive=True,
         )
ask_candid/agents/elastic.py
CHANGED
@@ -2,6 +2,9 @@ from typing import TypedDict, List
 from functools import partial
 import json
 import ast
+from ask_candid.base.api_base import BaseAPI
+import os
+import pandas as pd
 from pydantic import BaseModel, Field
 
 from langchain_core.runnables import RunnableSequence
@@ -24,10 +27,118 @@ from ask_candid.tools.elastic.index_search_tool import create_search_tool
 tools = [
     IndexShowDataTool(),
     IndexDetailsTool(),
-    create_search_tool(),
+    create_search_tool(pcs_codes={}),
 ]
 
 
+class AutocodingAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=os.getenv("AUTOCODING_API_URL"),
+            headers={
+                "x-api-key": os.getenv("AUTOCODING_API_KEY"),
+                "Content-Type": "application/json",
+            },
+        )
+
+    def __call__(self, text: str, taxonomy: str = "pcs-v3"):
+        params = {"text": text, "taxonomy": taxonomy}
+        return self.get(**params)
+
+
+def find_subject_levels(filtered_df, subject_level_i, target_value):
+    """
+    Filters the DataFrame from the last valid NaN in 'Subject Level i' and retrieves corresponding values for lower levels.
+
+    Parameters:
+    filtered_df (pd.DataFrame): The input DataFrame.
+    subject_level_i (int): The subject level to filter from (1 to 4).
+    target_value (str): The value to search for in 'Subject Level i'.
+
+    Returns:
+    dict: A dictionary containing values for 'Subject Level i' to 'Subject Level 1'.
+    pd.DataFrame: The filtered DataFrame from the determined start index to the target_value row.
+    """
+    if subject_level_i < 1 or subject_level_i > 4:
+        raise ValueError("subject_level_i should be between 1 and 4")
+
+    # Define the target column dynamically
+    target_column = f"Subject Level {subject_level_i}"
+
+    # Find indices where the target column has the target value
+    target_indices = filtered_df[
+        filtered_df[target_column].astype(str).str.strip() == target_value
+    ].index
+
+    if target_indices.empty:
+        return {}, pd.DataFrame()  # Return empty if target_value is not found
+
+    # Get the first occurrence of the target value
+    first_target_index = target_indices[0]
+
+    # Initialize dictionary to store subject level values
+    subject_level_values = {target_column: target_value}
+
+    # Initialize subject level start index
+    subject_level_start = first_target_index
+
+    # Find the last non-NaN row for each subject level
+    for level in range(subject_level_i - 1, 0, -1):  # Loop from subject_level_i-1 to 1
+        column_name = f"Subject Level {level}"
+
+        # Start checking above the previous found index
+        current_index = subject_level_start - 1
+
+        while current_index >= 0 and pd.isna(
+            filtered_df.loc[current_index, column_name]
+        ):
+            current_index -= 1  # Move up while NaN is found
+
+        # Move one row down to get the last valid row in 'Subject Level level'
+        subject_level_start = current_index + 1
+
+        # Ensure we store the correct value at each subject level
+        if subject_level_start in filtered_df.index:
+            subject_level_values[column_name] = filtered_df.loc[
+                subject_level_start - 1, column_name
+            ]
+
+    # Ensure valid slicing range
+    min_start_index = subject_level_start
+
+    if min_start_index < first_target_index:
+        filtered_df = filtered_df.loc[min_start_index:first_target_index]
+    else:
+        filtered_df = pd.DataFrame()
+
+    return subject_level_values, filtered_df
+
+
+def extract_heirarchy(full_code, target_value):
+    # df = pd.read_excel(
+    #     r"C:\Users\mukul.rawat\OneDrive - Candid\Documents\Projects\Gen AI\azure_devops\ask-candid-assistant\PCS_Taxonomy_Definitions_2024.xlsx"
+    # )
+    df = pd.read_excel(r"C:\Users\siqi.deng\Downloads\PCS_Taxonomy_Definitions_2024.xlsx")
+    filtered_df = df[df["PCS Code"].str.startswith(full_code[:2], na=False)]
+    for i in range(1, 5):
+        column_name = f"Subject Level {i}"
+        if (df[column_name].str.strip() == target_value).any():
+            break
+
+    subject_level_values, filtered_df = find_subject_levels(
+        filtered_df, i, target_value
+    )
+    sorted_values = [
+        value
+        for key, value in sorted(
+            subject_level_values.items(), key=lambda x: int(x[0].split()[-1])
+        )
+    ]
+    # Joining values in the required format
+    result = " : ".join(sorted_values)
+    return result
+
+
 class GraphState(TypedDict):
     query: str = Field(
         ..., description="The user's query to be processed by the system."
@@ -47,6 +158,7 @@ class GraphState(TypedDict):
         ...,
         description="The Elasticsearch query result generated or used by the agent.",
     )
+    pcs_codes: dict = Field(..., description="pcs codes")
 
 
 class AnalysisResult(BaseModel):
@@ -334,8 +446,6 @@ def build_compute_graph(llm: LLM) -> StateGraph:
 
 
 class ElasticGraph(StateGraph):
-    """Elastic Seach Agent State Graph"""
-
     llm: LLM
     tools: List[Tool]
 
@@ -345,6 +455,41 @@ class ElasticGraph(StateGraph):
         self.tools = tools
         self.construct_graph()
 
+    def Extract_PCS_Codes(self, state):
+        """Todo: Add Subject heirarchies, Population, Geo"""
+        print("query", state["query"])
+        autocoding_api = AutocodingAPI()
+        autocoding_response = autocoding_api(text=state["query"]).get("data", {})
+        # population_served = autocoding_response.get("population", {})
+        subjects = autocoding_response.get("subject", {})
+        descriptions = []
+        heirarchy_string = []
+        if subjects and isinstance(subjects, list) and "description" in subjects[0]:
+            for subject in subjects:
+                # if subject['description'] in subjects_list:
+                descriptions.append(subject["description"])
+                heirarchy_string.append(
+                    extract_heirarchy(subject["full_code"], subject["description"])
+                )
+        print("descriptions", descriptions)
+
+        populations = autocoding_response.get("population", {})
+        population_dict = []
+        if (
+            populations
+            and isinstance(populations, list)
+            and "description" in populations[0]
+        ):
+            for population in populations:
+                population_dict.append(population["description"])
+        state["pcs_codes"] = {
+            "subject": descriptions,
+            "heirarchy_string": heirarchy_string,
+            "population": population_dict,
+        }
+        print("pcs_codes_new", state["pcs_codes"])
+        return state
+
     def agent_factory(self) -> AgentExecutor:
         """
         Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
@@ -387,7 +532,7 @@ class ElasticGraph(StateGraph):
             return_intermediate_steps=True,
         )
 
-    def agent_factory_claude(self) -> AgentExecutor:
+    def agent_factory_claude(self, pcs_codes, prefix) -> AgentExecutor:
         """
         Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
 
@@ -400,15 +545,6 @@ class ElasticGraph(StateGraph):
             AgentExecutor: Configured agent ready to execute tasks with specified tools,
             providing detailed intermediate steps for transparency.
         """
-        prefix = """
-        You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
-        Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
-        Guidelines for generating right elastic seach query:
-        1. Automatically determine whether to return document hits or aggregation results based on the query structure.
-        2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
-        3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
-        4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
-        """
         prompt = ChatPromptTemplate.from_messages(
             [
                 ("system", f"You are a helpful elasticsearch assistant. {prefix}"),
@@ -418,9 +554,19 @@ class ElasticGraph(StateGraph):
             ]
         )
 
-
+        tools = [
+            # ListIndicesTool(),
+            IndexShowDataTool(),
+            IndexDetailsTool(),
+            create_search_tool(pcs_codes=pcs_codes),
+        ]
+        agent = create_tool_calling_agent(self.llm, tools, prompt)
+
         agent_executor = AgentExecutor.from_agent_and_tools(
-            agent=agent,
+            agent=agent,
+            tools=tools,
+            verbose=True,
+            return_intermediate_steps=True,
         )
         # Create the agent
        return agent_executor
@@ -467,6 +613,8 @@ class ElasticGraph(StateGraph):
 
     def grant_index_agent(self, state: GraphState) -> GraphState:
         print("> Grant Index Agent")
+        # autocoding test
+
         input_data = {
             "input": f"""
             You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
@@ -479,52 +627,51 @@ class ElasticGraph(StateGraph):
             Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
             by searching for relevant keywords.
             6. Present the response in a clear and natural language format, addressing the user's question directly.
-
-
+
             Description of some of the fields in the index but rest of the fields which are not here should be easy to understand:
-
-
-
-
-            pcs_v3: PCS is taxonomy, describing the work of grantmakers, recipient organizations and the philanthropic transactions between those entities.
-            The facets of the PCS illuminate the work and answer the following questions about philanthropy:
-            Who? = Population Served
-            What? = Subject and Organization Type
-            How? = Support Strategy and Transaction Type
-            the Facets:
-            Subjects: Describes WHAT is being supported. Example: Elementary education or Clean water supply.
-            Populations: Describes WHO is being supported. Example: Girls or People with disabilities.
-            Organization Type: Describes WHAT type of organization is providing or receiving support.
-            Transaction Type: Describes HOW support is being provided.
-            Support Strategies: Describes HOW activities are being implemented.
-
-            pcs_v3 itself is in a json format:
-            key - subject
-            value: it is a list of dictionary so might need to loop around to find the particular aspect
-            hierarchy: (it is a list having subject name)
-            [
-                {{
-                    'name':
-                }},
-                {{
-                    'name':
-                }}
-            ]
-            Before Writing elastic search query think through which field to use
-
-            Note: first you should focus on query `text` then look into pcs_v3. Make sure you pick the right size for the query
+            *fiscal_year: Year when grantmaker allocates budget for funding and grants. format YYYY
+            *recipient_state: is abbreviated for eg. NY, FL, CA
+            *recipient_city - Full Name of the City e.g, New York City, Boston
+            *recipient_country - Country Abbreviation of the recipient organization e.g USA
 
+            Note: Do not include `title`, `program_area`, `text` field in the elastic search query
             User's query:
             ```{state["query"]}```
             """
         }
-
+        pcs_codes = state["pcs_codes"]
+        pcs_match_term = ""
+        for pcs_code in pcs_codes["subject"]:
+            if pcs_code != "Philanthropy":
+                pcs_match_term += f"*'pcs_v3.subject.value.name': {pcs_code}* \n"
+
+        for pcs_code in pcs_codes["population"]:
+            if pcs_code != "Other population":
+                pcs_match_term += f"*'pcs_v3.population.value.name': {pcs_code}* \n"
+        print("pcs_match_term", pcs_match_term)
+        prefix = f"""
+        You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
+        Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
+        Guidelines for generating right elastic seach query:
+        1. Automatically determine whether to return document hits or aggregation results based on the query structure.
+        2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
+        3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
+        4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
+
+        Instruction for pcs_v3 Field-
+        If {pcs_codes['subject']} not empty:
+        Only include all of the following match terms. No other pcs_v3 fields should be added, duplicated, or altered except for those listed below.
+        - {pcs_match_term}
+        """
+        agent_exec = self.agent_factory_claude(
+            pcs_codes=state["pcs_codes"], prefix=prefix
+        )
         res = agent_exec.invoke(input_data)
         state["agent_out"] = res["output"]
-
         es_queries, es_results = {}, {}
         for i, action in enumerate(res.get("intermediate_steps", []), start=1):
             if action[0].tool == "elastic_index_search_tool":
+                print("query", action[0].tool_input.get("query"))
                 es_queries[f"query_{i}"] = json.loads(
                     action[0].tool_input.get("query") or "{}"
                 )
@@ -550,6 +697,18 @@ class ElasticGraph(StateGraph):
         """
 
         print("> Org Index Agent")
+        mapping_description = """
+        "admin1_code": "state abbreviation"
+        "admin1_description": "Full name/label of the state"
+        "city": Full Name of the city with 1st letter being capital for e.g. New York City
+        "assets": "The assets value of the most recent fiscals available for the organization."
+        "country_code": "Country abbreviation"
+        "country_name": "Country name"
+        "fiscal_year": "The year of the most recent fiscals available for the organization. (YYYY format)"
+        "mission_statement": "The mission statement of the organization."
+        "roles": "grantmaker: Indicates the organization gives grants., recipient: Indicates the organization receives grants., company: Indicates the organization is a company/corporation."
+
+        """
         input_data = {
             "input": f"""
             You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
@@ -557,14 +716,45 @@ class ElasticGraph(StateGraph):
             1. Understand the user query to determine the required information.
             2. Query the indices in the Elasticsearch database.
             3. Retrieve the mappings and field names relevant to the query.
-            4. Use the `
+            4. Use the `organization_qa_ds1` index to extract the necessary data.
             5. Present the response in a clear and natural language format, addressing the user's question directly.
+
+
+            Given Below is mapping description of some of the fields
+            ```{mapping_description}```
+
 
-            User's
+            User's query:
             ```{state["query"]}```
             """
         }
-
+
+        pcs_codes = state["pcs_codes"]
+        pcs_match_term = ""
+        for pcs_code in pcs_codes["subject"]:
+            pcs_match_term += f'"taxonomy_descriptions": "{pcs_code}" \n"'
+
+        print("pcs_match_term", pcs_match_term)
+        prefix = f"""You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
+        Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
+        Guidelines for generating right elastic seach query:
+        1. Automatically determine whether to return document hits or aggregation results based on the query structure.
+        2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
+        3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
+        4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
+
+        Instructions to use `taxonomy_descriptions` field:
+        If {pcs_codes['subject']} not empty, only add the following match term:
+        Only add the following `match` term. No other `taxonomy_descriptions` fields should be added, duplicated, or modified except below.
+        - {pcs_match_term}
+
+
+        Avoid using `ntee_major_description` field in the es query
+
+        """
+        agent_exec = self.agent_factory_claude(
+            pcs_codes=state["pcs_codes"], prefix=prefix
+        )
         res = agent_exec.invoke(input_data)
         state["agent_out"] = res["output"]
 
@@ -622,13 +812,15 @@ class ElasticGraph(StateGraph):
         """
 
         # Add nodes
+        self.add_node("Context_Extraction", self.Extract_PCS_Codes)
         self.add_node("analyse", self.analyse_query)
         self.add_node("grant-index", self.grant_index_agent)
         self.add_node("org-index", self.org_index_agent)
         self.add_node("final_answer", self.final_answer)
 
         # Set entry point
-        self.set_entry_point("analyse")
+        self.set_entry_point("Context_Extraction")
+        self.add_edge("Context_Extraction", "analyse")
 
         # Add conditional edges
         self.add_conditional_edges(
ask_candid/retrieval/elastic.py
CHANGED
@@ -1,5 +1,4 @@
 from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
-from dataclasses import dataclass
 from itertools import groupby
 
 from torch.nn import functional as F
@@ -10,12 +9,14 @@ from langchain_core.documents import Document
 from elasticsearch import Elasticsearch
 
 from ask_candid.retrieval.sparse_lexical import SpladeEncoder
-from ask_candid.retrieval.sources.issuelab import IssueLabConfig
-from ask_candid.retrieval.sources.youtube import YoutubeConfig
-from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig
-from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig
-from ask_candid.retrieval.sources.candid_help import CandidHelpConfig
-from ask_candid.retrieval.sources.candid_news import CandidNewsConfig
+from ask_candid.retrieval.sources.schema import ElasticHitsResult
+from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
+from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
+from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
+from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
+from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
+from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
+
 from ask_candid.services.small_lm import CandidSLM
 from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
 from ask_candid.base.config.data import DataIndices, ALL_INDICES
@@ -23,17 +24,6 @@ from ask_candid.base.config.data import DataIndices, ALL_INDICES
 encoder = SpladeEncoder()
 
 
-@dataclass
-class ElasticHitsResult:
-    """Dataclass for Elasticsearch hits results
-    """
-    index: str
-    id: Any
-    score: float
-    source: Dict[str, Any]
-    inner_hits: Dict[str, Any]
-
-
 class RetrieverInput(BaseModel):
     """Input to the Elasticsearch retriever."""
     user_input: str = Field(description="query to look up in retriever")
@@ -101,7 +91,7 @@ def news_query_builder(query: str) -> Dict[str, Any]:
     tokens = encoder.token_expand(query)
 
     query = {
-        "_source": ["id", "link", "title", "content"],
+        "_source": ["id", "link", "title", "content", "site_name"],
         "query": {
             "bool": {
                 "filter": [
@@ -150,27 +140,27 @@ def query_builder(query: str, indices: List[DataIndices]) -> Tuple[List[Dict[str
     if index == "issuelab":
         q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
         q["_source"] = {"excludes": ["embeddings"]}
-        q["size"] =
+        q["size"] = 2
         queries.extend([{"index": IssueLabConfig.index_name}, q])
     elif index == "youtube":
         q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
         q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
-        q["size"] =
+        q["size"] = 5
         queries.extend([{"index": YoutubeConfig.index_name}, q])
     elif index == "candid_blog":
         q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
         q["_source"] = {"excludes": ["embeddings"]}
-        q["size"] =
+        q["size"] = 5
         queries.extend([{"index": CandidBlogConfig.index_name}, q])
     elif index == "candid_learning":
         q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
         q["_source"] = {"excludes": ["embeddings"]}
-        q["size"] =
+        q["size"] = 5
         queries.extend([{"index": CandidLearningConfig.index_name}, q])
     elif index == "candid_help":
         q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
         q["_source"] = {"excludes": ["embeddings"]}
-        q["size"] =
+        q["size"] = 5
         queries.extend([{"index": CandidHelpConfig.index_name}, q])
     elif index == "news":
         q = news_query_builder(query=query)
@@ -199,12 +189,18 @@ def multi_search(
 def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
     for query_group in responses:
         for h in query_group.get("hits", {}).get("hits", []):
+            inner_hits = h.get("inner_hits", {})
+
+            if not inner_hits:
+                if "news" in h.get("_index"):
+                    inner_hits = {"text": h.get("_source", {}).get("content")}
+
             yield ElasticHitsResult(
                 index=h["_index"],
                 id=h["_id"],
                 score=h["_score"],
                 source=h["_source"],
-                inner_hits=h.get("inner_hits", {})
+                inner_hits=inner_hits
             )
 
     results = []
@@ -264,6 +260,10 @@ def retrieved_text(hits: Dict[str, Any]) -> str:
 
     text = []
     for _, v in hits.items():
+        if _ == "text":
+            text.append(v)
+            continue
+
        for h in (v.get("hits", {}).get("hits") or []):
             for _, field in h.get("fields", {}).items():
                 for chunk in field:
@@ -298,7 +298,8 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
 
 def reranker(
     query_results: Iterable[ElasticHitsResult],
-    search_text: Optional[str] = None
+    search_text: Optional[str] = None,
+    max_num_results: int = 10
 ) -> Iterator[ElasticHitsResult]:
     """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
     This will shuffle results
@@ -327,58 +328,13 @@ def reranker(
         text = retrieved_text(d.inner_hits)
         texts.append(text)
 
-    if search_text and len(texts) == len(results):
-        scores = cosine_rescore(search_text, texts)
-        for r, s in zip(results, scores):
-            r.score = s
-
-    yield from sorted(results, key=lambda x: x.score, reverse=True)
+    if search_text and len(texts) == len(results):
+        # scores = cosine_rescore(search_text, texts)
+        scores = encoder.query_reranking(query=search_text, documents=texts)
+        for r, s in zip(results, scores):
+            r.score = s
+
+    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
 
 
-def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
-    """Pads the relevant chunk of text with context before and after
-
-    Parameters
-    ----------
-    field_name : str
-        a field with the long text that was chunked into pieces
-    hit : ElasticHitsResult
-    context_length : int, optional
-        length of text to add before and after the chunk, by default 1024
-
-    Returns
-    -------
-    str
-        longer chunks stuffed together
-    """
-
-    chunks = []
-    # NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
-    long_text = hit.source.get(f"{field_name}", "")
-    long_text = long_text.lower()
-    inner_hits_field = f"embeddings.{field_name}.chunks"
-    found_chunks = hit.inner_hits.get(inner_hits_field, {})
-    if found_chunks:
-        hits = found_chunks.get("hits", {}).get("hits", [])
-        for h in hits:
-            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
-
-            # cutting the middle because we may have tokenizing artifacts there
-            chunk = chunk[3: -3]
-
-            if add_context:
-                # Find the start and end indices of the chunk in the large text
-                start_index = long_text.find(chunk[:20])
-
-                # Chunk is found
-                if start_index != -1:
-                    end_index = start_index + len(chunk)
-                    pre_start_index = max(0, start_index - context_length)
-                    post_end_index = min(len(long_text), end_index + context_length)
-                    chunks.append(long_text[pre_start_index:post_end_index])
-            else:
-                chunks.append(chunk)
-    return '\n\n'.join(chunks)
 
 
 def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
@@ -394,94 +350,17 @@ def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
     """
 
     if "issuelab-elser" in hit.index:
-        combined_item_description = hit.source.get("combined_item_description", "")
-        description = hit.source.get("description", "")
-        combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
-        # we only need to process long texts
-        chunks_with_context_txt = get_context("content", hit, context_length=12)
-        doc = Document(
-            page_content='\n\n'.join([
-                combined_item_description,
-                combined_issuelab_findings,
-                description,
-                chunks_with_context_txt
-            ]),
-            metadata={
-                "title": hit.source["title"],
-                "source": "IssueLab",
-                "source_id": hit.source["resource_id"],
-                "url": hit.source.get("permalink", "")
-            }
-        )
+        doc = process_issuelab_hit(hit)
     elif "youtube" in hit.index:
-        title = hit.source.get("title", "")
-        # we only need to process long texts
-        description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
-        captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
-        doc = Document(
-            page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
-            metadata={
-                "title": title,
-                "source": "Candid YouTube",
-                "source_id": hit.source['video_id'],
-                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
-            }
-        )
+        doc = process_youtube_hit(hit)
     elif "candid-blog" in hit.index:
-        excerpt = hit.source.get("excerpt", "")
-        title = hit.source.get("title", "")
-        # we only need to process long text
-        content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
-        authors = get_context("authors_text", hit, context_length=12, add_context=False)
-        tags = hit.source.get("title_summary_tags", "")
-        doc = Document(
-            page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
-            metadata={
-                "title": title,
-                "source": "Candid Blog",
-                "source_id": hit.source["id"],
-                "url": hit.source["link"]
-            }
-        )
+        doc = process_blog_hit(hit)
     elif "candid-learning" in hit.index:
-        title = hit.source.get("title", "")
-        content_with_context_txt = get_context("content", hit, context_length=12)
-        training_topics = hit.source.get("training_topics", "")
-        staff_recommendations = hit.source.get("staff_recommendations", "")
-
-        doc = Document(
-            page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
-            metadata={
-                "title": hit.source["title"],
-                "source": "Candid Learning",
-                "source_id": hit.source["post_id"],
-                "url": hit.source.get("url", "")
-            }
-        )
+        doc = process_learning_hit(hit)
     elif "candid-help" in hit.index:
-        title = hit.source.get("title", "")
-        content_with_context_txt = get_context("content", hit, context_length=12)
-        combined_article_description = hit.source.get("combined_article_description", "")
-
-        doc = Document(
-            page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
-            metadata={
-                "title": title,
-                "source": "Candid Help",
-                "source_id": hit.source["id"],
-                "url": hit.source.get("link", "")
-            }
-        )
+        doc = process_help_hit(hit)
     elif "news" in hit.index:
-        doc = Document(
-            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
-            metadata={
-                "title": hit.source.get("title", ""),
-                "source": "Candid News",
-                "source_id": hit.source["id"],
-                "url": hit.source.get("link", "")
-            }
-        )
+        doc = process_news_hit(hit)
     else:
         doc = None
     return doc
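
In short: the per-index `size` values let more candidates through, `_msearch_response_generator` synthesizes a pseudo inner-hits entry for news documents so `retrieved_text` can see their content, and `reranker` rescores everything with the shared SPLADE encoder before truncating. A hypothetical call site for the new signature (`hits` and `user_query` stand in for the actual upstream retrieval code):

# Hypothetical call site; names other than reranker/process_hit are assumptions.
reranked = reranker(
    query_results=hits,       # ElasticHitsResult items from the msearch step
    search_text=user_query,   # triggers SPLADE rescoring when provided
    max_num_results=10,       # hard cap applied after sorting by new scores
)
docs = [process_hit(h) for h in reranked]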
ask_candid/retrieval/sources/candid_blog.py
CHANGED
@@ -1,6 +1,9 @@
 from typing import Dict, Any
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+
+from langchain_core.documents import Document
+
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
+from ask_candid.retrieval.sources.utils import get_context
 
 CandidBlogConfig = ElasticSourceConfig(
     index_name="search-semantic-candid-blog",
@@ -8,6 +11,24 @@ CandidBlogConfig = ElasticSourceConfig(
 )
 
 
+def process_blog_hit(hit: ElasticHitsResult) -> Document:
+    excerpt = hit.source.get("excerpt", "")
+    title = hit.source.get("title", "")
+    # we only need to process long text
+    content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
+    authors = get_context("authors_text", hit, context_length=12, add_context=False)
+    tags = hit.source.get("title_summary_tags", "")
+    return Document(
+        page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
+        metadata={
+            "title": title,
+            "source": "Candid Blog",
+            "source_id": hit.source["id"],
+            "url": hit.source["link"]
+        }
+    )
+
+
 def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
     url = f"{doc['link']}"
     fields = ["title", "excerpt"]
ask_candid/retrieval/sources/candid_help.py
CHANGED
@@ -1,6 +1,9 @@
 from typing import Dict, Any
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+
+from langchain_core.documents import Document
+
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
+from ask_candid.retrieval.sources.utils import get_context
 
 CandidHelpConfig = ElasticSourceConfig(
     index_name="search-semantic-candid-help-elser_ve1",
@@ -8,6 +11,22 @@ CandidHelpConfig = ElasticSourceConfig(
 )
 
 
+def process_help_hit(hit: ElasticHitsResult) -> Document:
+    title = hit.source.get("title", "")
+    content_with_context_txt = get_context("content", hit, context_length=12)
+    combined_article_description = hit.source.get("combined_article_description", "")
+
+    return Document(
+        page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
+        metadata={
+            "title": title,
+            "source": "Candid Help",
+            "source_id": hit.source["id"],
+            "url": hit.source.get("link", "")
+        }
+    )
+
+
 def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
     url = f"{doc['link']}"
     fields = ["title", "summary"]
ask_candid/retrieval/sources/candid_learning.py
CHANGED
@@ -1,5 +1,9 @@
 from typing import Dict, Any
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+
+from langchain_core.documents import Document
+
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
+from ask_candid.retrieval.sources.utils import get_context
 
 
 CandidLearningConfig = ElasticSourceConfig(
@@ -8,6 +12,23 @@ CandidLearningConfig = ElasticSourceConfig(
 )
 
 
+def process_learning_hit(hit: ElasticHitsResult) -> Document:
+    title = hit.source.get("title", "")
+    content_with_context_txt = get_context("content", hit, context_length=12)
+    training_topics = hit.source.get("training_topics", "")
+    staff_recommendations = hit.source.get("staff_recommendations", "")
+
+    return Document(
+        page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
+        metadata={
+            "title": hit.source["title"],
+            "source": "Candid Learning",
+            "source_id": hit.source["post_id"],
+            "url": hit.source.get("url", "")
+        }
+    )
+
+
 def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
     url = f"{doc['url']}"
     fields = ["title", "excerpt"]
ask_candid/retrieval/sources/candid_news.py
CHANGED
@@ -1,7 +1,20 @@
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+from langchain_core.documents import Document
 
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
 
 CandidNewsConfig = ElasticSourceConfig(
     index_name="news_1",
     text_fields=("title", "content")
 )
+
+
+def process_news_hit(hit: ElasticHitsResult) -> Document:
+    return Document(
+        page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
+        metadata={
+            "title": hit.source.get("title", ""),
+            "source": hit.source.get("site_name") or "Candid News",
+            "source_id": hit.source["id"],
+            "url": hit.source.get("link", "")
+        }
+    )
ask_candid/retrieval/sources/issuelab.py
CHANGED
@@ -1,6 +1,9 @@
 from typing import Dict, Any
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+
+from langchain_core.documents import Document
+
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
+from ask_candid.retrieval.sources.utils import get_context
 
 IssueLabConfig = ElasticSourceConfig(
     index_name="search-semantic-issuelab-elser_ve2",
@@ -8,11 +11,33 @@ IssueLabConfig = ElasticSourceConfig(
 )
 
 
+def process_issuelab_hit(hit: ElasticHitsResult) -> Document:
+    combined_item_description = hit.source.get("combined_item_description", "")  # title inside
+    description = hit.source.get("description", "")
+    combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
+    # we only need to process long texts
+    chunks_with_context_txt = get_context("content", hit, context_length=12)
+    return Document(
+        page_content='\n\n'.join([
+            combined_item_description,
+            combined_issuelab_findings,
+            description,
+            chunks_with_context_txt
+        ]),
+        metadata={
+            "title": hit.source["title"],
+            "source": "IssueLab",
+            "source_id": hit.source["resource_id"],
+            "url": hit.source.get("permalink", "")
+        }
+    )
+
+
 def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
     chunks_html = ""
     if show_chunks:
         cleaned_text = []
-        for k, v in doc["inner_hits"].items():
+        for _, v in doc["inner_hits"].items():
             hits = v["hits"]["hits"]
             for h in hits:
                 for k1, v1 in h["fields"].items():
ask_candid/retrieval/sources/schema.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Tuple, Optional
+from typing import Tuple, Dict, Optional, Any
 from dataclasses import dataclass, field
 
 
@@ -7,3 +7,14 @@ class ElasticSourceConfig:
     index_name: str
     text_fields: Tuple[str]
     excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
+
+
+@dataclass
+class ElasticHitsResult:
+    """Dataclass for Elasticsearch hits results
+    """
+    index: str
+    id: Any
+    score: float
+    source: Dict[str, Any]
+    inner_hits: Dict[str, Any]
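
The dataclass simply moved here from `retrieval/elastic.py` so the per-source modules and the retrieval layer can share it without circular imports. Constructing one by hand (illustrative values only):

# Illustrative only; real instances come from _msearch_response_generator.
hit = ElasticHitsResult(
    index="news_1",
    id="abc123",
    score=7.1,
    source={"id": "abc123", "title": "Example", "content": "Example body"},
    inner_hits={},
)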
ask_candid/retrieval/sources/utils.py
ADDED
@@ -0,0 +1,47 @@
+from ask_candid.retrieval.sources.schema import ElasticHitsResult
+
+
+def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
+    """Pads the relevant chunk of text with context before and after
+
+    Parameters
+    ----------
+    field_name : str
+        a field with the long text that was chunked into pieces
+    hit : ElasticHitsResult
+    context_length : int, optional
+        length of text to add before and after the chunk, by default 1024
+
+    Returns
+    -------
+    str
+        longer chunks stuffed together
+    """
+
+    chunks = []
+    # NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
+    long_text = hit.source.get(f"{field_name}", "")
+    long_text = long_text.lower()
+    inner_hits_field = f"embeddings.{field_name}.chunks"
+    found_chunks = hit.inner_hits.get(inner_hits_field, {})
+    if found_chunks:
+        hits = found_chunks.get("hits", {}).get("hits", [])
+        for h in hits:
+            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
+
+            # cutting the middle because we may have tokenizing artifacts there
+            chunk = chunk[3: -3]
+
+            if add_context:
+                # Find the start and end indices of the chunk in the large text
+                start_index = long_text.find(chunk[:20])
+
+                # Chunk is found
+                if start_index != -1:
+                    end_index = start_index + len(chunk)
+                    pre_start_index = max(0, start_index - context_length)
+                    post_end_index = min(len(long_text), end_index + context_length)
+                    chunks.append(long_text[pre_start_index:post_end_index])
+            else:
+                chunks.append(chunk)
+    return '\n\n'.join(chunks)
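
`get_context` is unchanged logic, relocated from `retrieval/elastic.py` so the `process_*_hit` functions can share it. Typical calls, mirroring the processors above:

# Pad each matched ELSER chunk with surrounding text from the full field:
snippet = get_context("content", hit, context_length=1024, add_context=True)

# Or keep just the raw chunks, as the blog processor does for author text:
authors = get_context("authors_text", hit, context_length=12, add_context=False)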
ask_candid/retrieval/sources/youtube.py
CHANGED
@@ -1,6 +1,9 @@
 from typing import Dict, Any
-from ask_candid.retrieval.sources.schema import ElasticSourceConfig
+
+from langchain_core.documents import Document
+
+from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
+from ask_candid.retrieval.sources.utils import get_context
 
 YoutubeConfig = ElasticSourceConfig(
     index_name="search-semantic-youtube-elser_ve1",
@@ -9,6 +12,22 @@ YoutubeConfig = ElasticSourceConfig(
 )
 
 
+def process_youtube_hit(hit: ElasticHitsResult) -> Document:
+    title = hit.source.get("title", "")
+    # we only need to process long texts
+    description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
+    captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
+    return Document(
+        page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
+        metadata={
+            "title": title,
+            "source": "Candid YouTube",
+            "source_id": hit.source['video_id'],
+            "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
+        }
+    )
+
+
 def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
     url = f"https://www.youtube.com/watch?v={doc['video_id']}"
     fields = ["title", "description_cleaned"]
ask_candid/retrieval/sparse_lexical.py
CHANGED
@@ -1,6 +1,7 @@
-from typing import Dict
+from typing import List, Dict
 
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from torch.nn import functional as F
 import torch
 
 
@@ -14,14 +15,23 @@ class SpladeEncoder:
         self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
 
     @torch.no_grad()
-    def token_expand(self, query: str) -> Dict[str, float]:
-        tokens = self.tokenizer(query, return_tensors='pt')
+    def forward(self, texts: List[str]):
+        tokens = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
         output = self.model(**tokens)
-
         vec = torch.max(
             torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
             dim=1
        )[0].squeeze()
+        return vec
+
+    def query_reranking(self, query: str, documents: List[str]):
+        vec = self.forward([query, *documents])
+        xQ = F.normalize(vec[:1], dim=-1, p=2.)
+        xD = F.normalize(vec[1:], dim=-1, p=2.)
+        return (xQ * xD).sum(dim=-1).cpu().tolist()
+
+    def token_expand(self, query: str) -> Dict[str, float]:
+        vec = self.forward([query])
         cols = vec.nonzero().squeeze().cpu().tolist()
         weights = vec[cols].cpu().tolist()
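
This is the core of the SPLADE-v3 re-ranking named in the commit message: the query and candidate documents are encoded into sparse lexical vectors and scored by cosine similarity. A usage sketch, assuming the default SpladeEncoder() constructor (the query and documents are invented):

    from ask_candid.retrieval.sparse_lexical import SpladeEncoder

    encoder = SpladeEncoder()
    scores = encoder.query_reranking(
        query="grants for after-school programs",
        documents=[
            "Foundations funding youth development and after-school enrichment.",
            "How to file IRS Form 990 electronically.",
        ],
    )
    # scores[i] is the cosine similarity between the L2-normalized SPLADE vectors
    # of the query and documents[i]; sort retrieved hits by it, descending.
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
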
ask_candid/tools/elastic/index_search_tool.py
CHANGED
@@ -40,6 +40,7 @@ class SearchToolInput(BaseModel):
 
 
 def elastic_search(
+    pcs_codes: dict,
     index_name: str,
     query: str,
     from_: int = 0,
@@ -107,9 +108,15 @@ def elastic_search(
     return msg
 
 
-def create_search_tool():
+def create_search_tool(pcs_codes):
     return StructuredTool.from_function(
-        elastic_search,
+        func=lambda index_name, query, from_, size: elastic_search(
+            pcs_codes=pcs_codes,
+            index_name=index_name,
+            query=query,
+            from_=from_,
+            size=size,
+        ),
         name="elastic_index_search_tool",
        description=(
            """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
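
The factory now closes over pcs_codes so the agent never has to supply it as a tool argument. A hypothetical wiring (the diff does not show what pcs_codes contains, so an empty dict stands in; the index name is borrowed from the YouTube source config above):

    from ask_candid.tools.elastic.index_search_tool import create_search_tool

    pcs_codes = {}  # placeholder; populated upstream with PCS taxonomy filters
    search_tool = create_search_tool(pcs_codes)
    # The tool input matches the lambda's remaining parameters.
    result = search_tool.run({
        "index_name": "search-semantic-youtube-elser_ve1",
        "query": "youth mentoring",
        "from_": 0,
        "size": 5,
    })
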
ask_candid/tools/question_reformulation.py
CHANGED
@@ -1,55 +1,55 @@
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.language_models.llms import LLM
 
+from ask_candid.agents.schema import AgentState
 
-def reformulate_question_using_history(state, llm, focus_on_recommendations=False):
-    """
-    Transform the query to produce a better query with details from previous messages and emphasize
-    aspects important for recommendations if needed.
-    """
+
+def reformulate_question_using_history(
+    state: AgentState,
+    llm: LLM,
+    focus_on_recommendations: bool = False
+) -> AgentState:
+    """Transform the query to produce a better query with details from previous messages and emphasize aspects important
+    for recommendations if needed.
+
+    Parameters
+    ----------
+    state : AgentState
+        The current state
+    llm : LLM
+    focus_on_recommendations : bool, optional
+        Flag to determine if the reformulation should emphasize recommendation-relevant aspects such as geographies,
+        cause areas, etc., by default False
+
+    Returns
+    -------
+    AgentState
+        The updated state
+    """
+
     print("---REFORMULATE THE USER INPUT---")
     messages = state["messages"]
     question = messages[-1].content
 
-    if len(messages) > 1:
+    if len(messages[:-1]) > 1:  # need to skip the system message
         if focus_on_recommendations:
-            prompt_text = """Given a chat history and the latest user input
-            {chat_history}
-            \n ------- \n
-            User input:
-            \n ------- \n
-            {question}
-            \n ------- \n
+            prompt_text = """Given a chat history and the latest user input which might reference context in the chat
+            history, especially geographic locations, cause areas and/or population groups, formulate a standalone input
+            which can be understood without the chat history.
+            Chat history: ```{chat_history}```
+            User input: ```{question}```
+
             Reformulate the question without adding implications or assumptions about the user's needs or intentions.
             Focus solely on clarifying any contextual details present in the original input."""
         else:
-            prompt_text = """Given a chat history and the latest user input
-            which might reference context in the chat history,
-            Chat history:
-            {chat_history}
-            \n ------- \n
-            User input:
-            \n ------- \n
-            {question}
-            \n ------- \n
-            Do NOT answer the question, \
-            just reformulate it if needed and otherwise return it as is.
+            prompt_text = """Given a chat history and the latest user input which might reference context in the chat
+            history, formulate a standalone input which can be understood without the chat history. Include hints as to
+            what the user is getting at given the context in the chat history.
+            Chat history: ```{chat_history}```
+            User input: ```{question}```
+
+            Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
             """
 
         contextualize_q_prompt = ChatPromptTemplate([
@@ -58,7 +58,11 @@ def reformulate_question_using_history(state, llm, focus_on_recommendations=False):
         ])
 
         rag_chain = contextualize_q_prompt | llm | StrOutputParser()
-        new_question = rag_chain.invoke({"chat_history": messages, "question": question})
+        # new_question = rag_chain.invoke({"chat_history": messages, "question": question})
+        new_question = rag_chain.invoke({
+            "chat_history": '\n'.join(f"{m.type.upper()}: {m.content}" for m in messages[1:]),
+            "question": question
+        })
         print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
         return {"messages": [new_question], "user_input" : question}
     return {"messages": [question], "user_input" : question}
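
The key change in the invoke call is that chat history is now serialized into typed text turns instead of passing raw message objects into the prompt. A small sketch of what the model now sees (the conversation is invented; the system message is skipped, matching messages[1:]):

    from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

    messages = [
        SystemMessage(content="You are Candid's assistant."),
        HumanMessage(content="What funders support food banks in Ohio?"),
        AIMessage(content="Several community foundations do, for example..."),
        HumanMessage(content="What about in Michigan?"),
    ]
    # Same serialization as in the diff: message type in caps, one turn per line.
    history = '\n'.join(f"{m.type.upper()}: {m.content}" for m in messages[1:])
    print(history)
    # HUMAN: What funders support food banks in Ohio?
    # AI: Several community foundations do, for example...
    # HUMAN: What about in Michigan?
    # Given this history, the reformulated standalone question would read along
    # the lines of "What funders support food banks in Michigan?".
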