brainsqueeze committed on
Commit
f86d7f2
·
verified ·
1 Parent(s): cc80c3d

Smarter document context retrieval

Browse files

* Retrieved documents re-ranking w/ SPLADE-v3
* Enable news as a default source

app.py CHANGED
@@ -113,7 +113,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
113
  with gr.Accordion(label="Advanced settings", open=False):
114
  es_indices = gr.CheckboxGroup(
115
  choices=list(ALL_INDICES),
116
- value=[idx for idx in ALL_INDICES if "news" not in idx],
 
117
  label="Sources to include",
118
  interactive=True,
119
  )
 
113
  with gr.Accordion(label="Advanced settings", open=False):
114
  es_indices = gr.CheckboxGroup(
115
  choices=list(ALL_INDICES),
116
+ # value=[idx for idx in ALL_INDICES if "news" not in idx],
117
+ value=list(ALL_INDICES),
118
  label="Sources to include",
119
  interactive=True,
120
  )
ask_candid/agents/elastic.py CHANGED
@@ -2,6 +2,9 @@ from typing import TypedDict, List
2
  from functools import partial
3
  import json
4
  import ast
 
 
 
5
  from pydantic import BaseModel, Field
6
 
7
  from langchain_core.runnables import RunnableSequence
@@ -24,10 +27,118 @@ from ask_candid.tools.elastic.index_search_tool import create_search_tool
24
  tools = [
25
  IndexShowDataTool(),
26
  IndexDetailsTool(),
27
- create_search_tool(),
28
  ]
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class GraphState(TypedDict):
32
  query: str = Field(
33
  ..., description="The user's query to be processed by the system."
@@ -47,6 +158,7 @@ class GraphState(TypedDict):
47
  ...,
48
  description="The Elasticsearch query result generated or used by the agent.",
49
  )
 
50
 
51
 
52
  class AnalysisResult(BaseModel):
@@ -334,8 +446,6 @@ def build_compute_graph(llm: LLM) -> StateGraph:
334
 
335
 
336
  class ElasticGraph(StateGraph):
337
- """Elastic Seach Agent State Graph"""
338
-
339
  llm: LLM
340
  tools: List[Tool]
341
 
@@ -345,6 +455,41 @@ class ElasticGraph(StateGraph):
345
  self.tools = tools
346
  self.construct_graph()
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  def agent_factory(self) -> AgentExecutor:
349
  """
350
  Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
@@ -387,7 +532,7 @@ class ElasticGraph(StateGraph):
387
  return_intermediate_steps=True,
388
  )
389
 
390
- def agent_factory_claude(self) -> AgentExecutor:
391
  """
392
  Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
393
 
@@ -400,15 +545,6 @@ class ElasticGraph(StateGraph):
400
  AgentExecutor: Configured agent ready to execute tasks with specified tools,
401
  providing detailed intermediate steps for transparency.
402
  """
403
- prefix = """
404
- You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
405
- Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
406
- Guidelines for generating right elastic seach query:
407
- 1. Automatically determine whether to return document hits or aggregation results based on the query structure.
408
- 2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
409
- 3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
410
- 4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
411
- """
412
  prompt = ChatPromptTemplate.from_messages(
413
  [
414
  ("system", f"You are a helpful elasticsearch assistant. {prefix}"),
@@ -418,9 +554,19 @@ class ElasticGraph(StateGraph):
418
  ]
419
  )
420
 
421
- agent = create_tool_calling_agent(self.llm, self.tools, prompt)
 
 
 
 
 
 
 
422
  agent_executor = AgentExecutor.from_agent_and_tools(
423
- agent=agent, tools=self.tools, verbose=True, return_intermediate_steps=True
 
 
 
424
  )
425
  # Create the agent
426
  return agent_executor
@@ -467,6 +613,8 @@ class ElasticGraph(StateGraph):
467
 
468
  def grant_index_agent(self, state: GraphState) -> GraphState:
469
  print("> Grant Index Agent")
 
 
470
  input_data = {
471
  "input": f"""
472
  You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
@@ -479,52 +627,51 @@ class ElasticGraph(StateGraph):
479
  Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
480
  by searching for relevant keywords.
481
  6. Present the response in a clear and natural language format, addressing the user's question directly.
482
-
483
-
484
  Description of some of the fields in the index but rest of the fields which are not here should be easy to understand:
485
- fiscal_year: Year when grantmaker allocates budget for funding and grants. format YYYY
486
- text: Objectives,mission, program and funding related information
487
- Program_area: program area where organization is working on
488
- Title: the title of the funding
489
- pcs_v3: PCS is taxonomy, describing the work of grantmakers, recipient organizations and the philanthropic transactions between those entities.
490
- The facets of the PCS illuminate the work and answer the following questions about philanthropy:
491
- Who? = Population Served
492
- What? = Subject and Organization Type
493
- How? = Support Strategy and Transaction Type
494
- the Facets:
495
- Subjects: Describes WHAT is being supported. Example: Elementary education or Clean water supply.
496
- Populations: Describes WHO is being supported. Example: Girls or People with disabilities.
497
- Organization Type: Describes WHAT type of organization is providing or receiving support.
498
- Transaction Type: Describes HOW support is being provided.
499
- Support Strategies: Describes HOW activities are being implemented.
500
-
501
- pcs_v3 itself is in a json format:
502
- key - subject
503
- value: it is a list of dictionary so might need to loop around to find the particular aspect
504
- hierarchy: (it is a list having subject name)
505
- [
506
- {{
507
- 'name':
508
- }},
509
- {{
510
- 'name':
511
- }}
512
- ]
513
- Before Writing elastic search query think through which field to use
514
-
515
- Note: first you should focus on query `text` then look into pcs_v3. Make sure you pick the right size for the query
516
 
 
517
  User's query:
518
  ```{state["query"]}```
519
  """
520
  }
521
- agent_exec = self.agent_factory_claude()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  res = agent_exec.invoke(input_data)
523
  state["agent_out"] = res["output"]
524
-
525
  es_queries, es_results = {}, {}
526
  for i, action in enumerate(res.get("intermediate_steps", []), start=1):
527
  if action[0].tool == "elastic_index_search_tool":
 
528
  es_queries[f"query_{i}"] = json.loads(
529
  action[0].tool_input.get("query") or "{}"
530
  )
@@ -550,6 +697,18 @@ class ElasticGraph(StateGraph):
550
  """
551
 
552
  print("> Org Index Agent")
 
 
 
 
 
 
 
 
 
 
 
 
553
  input_data = {
554
  "input": f"""
555
  You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
@@ -557,14 +716,45 @@ class ElasticGraph(StateGraph):
557
  1. Understand the user query to determine the required information.
558
  2. Query the indices in the Elasticsearch database.
559
  3. Retrieve the mappings and field names relevant to the query.
560
- 4. Use the `organization_qa_2` index to extract the necessary data.
561
  5. Present the response in a clear and natural language format, addressing the user's question directly.
 
 
 
 
 
562
 
563
- User's quer:
564
  ```{state["query"]}```
565
  """
566
  }
567
- agent_exec = self.agent_factory_claude()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  res = agent_exec.invoke(input_data)
569
  state["agent_out"] = res["output"]
570
 
@@ -622,13 +812,15 @@ class ElasticGraph(StateGraph):
622
  """
623
 
624
  # Add nodes
 
625
  self.add_node("analyse", self.analyse_query)
626
  self.add_node("grant-index", self.grant_index_agent)
627
  self.add_node("org-index", self.org_index_agent)
628
  self.add_node("final_answer", self.final_answer)
629
 
630
  # Set entry point
631
- self.set_entry_point("analyse")
 
632
 
633
  # Add conditional edges
634
  self.add_conditional_edges(
 
2
  from functools import partial
3
  import json
4
  import ast
5
+ from ask_candid.base.api_base import BaseAPI
6
+ import os
7
+ import pandas as pd
8
  from pydantic import BaseModel, Field
9
 
10
  from langchain_core.runnables import RunnableSequence
 
27
  tools = [
28
  IndexShowDataTool(),
29
  IndexDetailsTool(),
30
+ create_search_tool(pcs_codes={}),
31
  ]
32
 
33
 
34
class AutocodingAPI(BaseAPI):
    """Thin HTTP client for the autocoding taxonomy service.

    Connection details come from the environment:
    ``AUTOCODING_API_URL`` (endpoint) and ``AUTOCODING_API_KEY`` (auth header).
    """

    def __init__(self):
        # Build headers first, then hand everything to the shared base client.
        request_headers = {
            "x-api-key": os.getenv("AUTOCODING_API_KEY"),
            "Content-Type": "application/json",
        }
        super().__init__(url=os.getenv("AUTOCODING_API_URL"), headers=request_headers)

    def __call__(self, text: str, taxonomy: str = "pcs-v3"):
        """Issue a GET request coding *text* against *taxonomy* (default ``pcs-v3``)."""
        return self.get(text=text, taxonomy=taxonomy)
47
+
48
+
49
def find_subject_levels(filtered_df, subject_level_i, target_value):
    """Collect parent subject-level labels for a target value in a taxonomy table.

    Starting from the first row where the column ``Subject Level {subject_level_i}``
    equals *target_value*, walk upward through the lower-numbered level columns,
    taking the nearest non-NaN entry above as each level's label.

    Parameters
    ----------
    filtered_df : pd.DataFrame
        Taxonomy rows with columns "Subject Level 1" .. "Subject Level 4".
        NOTE(review): the index is assumed to be contiguous integers starting
        at 0 (``.loc`` is used with positional arithmetic) — confirm callers
        reset the index after filtering.
    subject_level_i : int
        Level to anchor on; must be between 1 and 4.
    target_value : str
        Value searched for (compared against stripped string form of the column).

    Returns
    -------
    tuple[dict, pd.DataFrame]
        Mapping of "Subject Level i" -> label for each resolved level, and the
        row window from the first resolved ancestor down to the target row
        (empty DataFrame when the target is absent or the window is empty).

    Raises
    ------
    ValueError
        If *subject_level_i* is outside 1..4.
    """
    if not 1 <= subject_level_i <= 4:
        raise ValueError("subject_level_i should be between 1 and 4")

    anchor_col = f"Subject Level {subject_level_i}"
    # Compare on stripped string form so whitespace differences don't matter.
    normalized = filtered_df[anchor_col].astype(str).str.strip()
    matches = filtered_df[normalized == target_value].index
    if matches.empty:
        # Target value not present at this level — nothing to resolve.
        return {}, pd.DataFrame()

    anchor = matches[0]
    collected = {anchor_col: target_value}
    cursor = anchor

    # Walk levels i-1 .. 1, each time scanning upward past NaNs to find the
    # last row carrying a label for that level.
    for level in range(subject_level_i - 1, 0, -1):
        column = f"Subject Level {level}"
        probe = cursor - 1
        while probe >= 0 and pd.isna(filtered_df.loc[probe, column]):
            probe -= 1
        # One past the labelled row marks where this level's block begins.
        cursor = probe + 1
        if cursor in filtered_df.index:
            collected[column] = filtered_df.loc[cursor - 1, column]

    # Slice from the outermost resolved ancestor through the target row.
    if cursor < anchor:
        return collected, filtered_df.loc[cursor:anchor]
    return collected, pd.DataFrame()
115
+
116
+
117
def extract_heirarchy(full_code, target_value, taxonomy_path=None):
    """Build a " : "-joined subject hierarchy string for a PCS subject.

    Parameters
    ----------
    full_code : str
        PCS code; only its first two characters are used to pre-filter rows
        to the same top-level family.
    target_value : str
        Subject description to locate in the "Subject Level 1..4" columns.
    taxonomy_path : str, optional
        Path to the PCS taxonomy definitions spreadsheet. Defaults to the
        ``PCS_TAXONOMY_PATH`` environment variable, falling back to the
        previous hard-coded location for backward compatibility.

    Returns
    -------
    str
        Levels joined as "Level 1 : Level 2 : ...". Empty string when the
        target value cannot be resolved.
    """
    if taxonomy_path is None:
        # NOTE(review): the original hard-coded a developer-specific absolute
        # path; keep it only as a last-resort default and prefer configuring
        # PCS_TAXONOMY_PATH in the environment.
        taxonomy_path = os.getenv(
            "PCS_TAXONOMY_PATH",
            r"C:\Users\siqi.deng\Downloads\PCS_Taxonomy_Definitions_2024.xlsx",
        )
    df = pd.read_excel(taxonomy_path)

    # Restrict to rows in the same top-level PCS family as full_code.
    filtered_df = df[df["PCS Code"].str.startswith(full_code[:2], na=False)]

    # Find the shallowest subject level containing the target description.
    # If no column matches, `level` falls through as 4 (preserving the
    # original behavior) and find_subject_levels returns an empty mapping.
    for level in range(1, 5):
        if (df[f"Subject Level {level}"].str.strip() == target_value).any():
            break

    subject_level_values, filtered_df = find_subject_levels(
        filtered_df, level, target_value
    )

    # Order collected labels from Level 1 upward before joining.
    ordered = [
        value
        for _, value in sorted(
            subject_level_values.items(), key=lambda kv: int(kv[0].split()[-1])
        )
    ]
    return " : ".join(ordered)
140
+
141
+
142
  class GraphState(TypedDict):
143
  query: str = Field(
144
  ..., description="The user's query to be processed by the system."
 
158
  ...,
159
  description="The Elasticsearch query result generated or used by the agent.",
160
  )
161
+ pcs_codes: dict = Field(..., description="pcs codes")
162
 
163
 
164
  class AnalysisResult(BaseModel):
 
446
 
447
 
448
  class ElasticGraph(StateGraph):
 
 
449
  llm: LLM
450
  tools: List[Tool]
451
 
 
455
  self.tools = tools
456
  self.construct_graph()
457
 
458
+ def Extract_PCS_Codes(self, state):
459
+ """Todo: Add Subject heirarchies, Population, Geo"""
460
+ print("query", state["query"])
461
+ autocoding_api = AutocodingAPI()
462
+ autocoding_response = autocoding_api(text=state["query"]).get("data", {})
463
+ # population_served = autocoding_response.get("population", {})
464
+ subjects = autocoding_response.get("subject", {})
465
+ descriptions = []
466
+ heirarchy_string = []
467
+ if subjects and isinstance(subjects, list) and "description" in subjects[0]:
468
+ for subject in subjects:
469
+ # if subject['description'] in subjects_list:
470
+ descriptions.append(subject["description"])
471
+ heirarchy_string.append(
472
+ extract_heirarchy(subject["full_code"], subject["description"])
473
+ )
474
+ print("descriptions", descriptions)
475
+
476
+ populations = autocoding_response.get("population", {})
477
+ population_dict = []
478
+ if (
479
+ populations
480
+ and isinstance(populations, list)
481
+ and "description" in populations[0]
482
+ ):
483
+ for population in populations:
484
+ population_dict.append(population["description"])
485
+ state["pcs_codes"] = {
486
+ "subject": descriptions,
487
+ "heirarchy_string": heirarchy_string,
488
+ "population": population_dict,
489
+ }
490
+ print("pcs_codes_new", state["pcs_codes"])
491
+ return state
492
+
493
  def agent_factory(self) -> AgentExecutor:
494
  """
495
  Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
 
532
  return_intermediate_steps=True,
533
  )
534
 
535
+ def agent_factory_claude(self, pcs_codes, prefix) -> AgentExecutor:
536
  """
537
  Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
538
 
 
545
  AgentExecutor: Configured agent ready to execute tasks with specified tools,
546
  providing detailed intermediate steps for transparency.
547
  """
 
 
 
 
 
 
 
 
 
548
  prompt = ChatPromptTemplate.from_messages(
549
  [
550
  ("system", f"You are a helpful elasticsearch assistant. {prefix}"),
 
554
  ]
555
  )
556
 
557
+ tools = [
558
+ # ListIndicesTool(),
559
+ IndexShowDataTool(),
560
+ IndexDetailsTool(),
561
+ create_search_tool(pcs_codes=pcs_codes),
562
+ ]
563
+ agent = create_tool_calling_agent(self.llm, tools, prompt)
564
+
565
  agent_executor = AgentExecutor.from_agent_and_tools(
566
+ agent=agent,
567
+ tools=tools,
568
+ verbose=True,
569
+ return_intermediate_steps=True,
570
  )
571
  # Create the agent
572
  return agent_executor
 
613
 
614
  def grant_index_agent(self, state: GraphState) -> GraphState:
615
  print("> Grant Index Agent")
616
+ # autocoding test
617
+
618
  input_data = {
619
  "input": f"""
620
  You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
 
627
  Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
628
  by searching for relevant keywords.
629
  6. Present the response in a clear and natural language format, addressing the user's question directly.
630
+
 
631
  Description of some of the fields in the index but rest of the fields which are not here should be easy to understand:
632
+ *fiscal_year: Year when grantmaker allocates budget for funding and grants. format YYYY
633
+ *recipient_state: is abbreviated for eg. NY, FL, CA
634
+ *recipient_city - Full Name of the City e.g, New York City, Boston
635
+ *recipient_country - Country Abbreviation of the recipient organization e.g USA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
+ Note: Do not include `title`, `program_area`, `text` field in the elastic search query
638
  User's query:
639
  ```{state["query"]}```
640
  """
641
  }
642
+ pcs_codes = state["pcs_codes"]
643
+ pcs_match_term = ""
644
+ for pcs_code in pcs_codes["subject"]:
645
+ if pcs_code != "Philanthropy":
646
+ pcs_match_term += f"*'pcs_v3.subject.value.name': {pcs_code}* \n"
647
+
648
+ for pcs_code in pcs_codes["population"]:
649
+ if pcs_code != "Other population":
650
+ pcs_match_term += f"*'pcs_v3.population.value.name': {pcs_code}* \n"
651
+ print("pcs_match_term", pcs_match_term)
652
+ prefix = f"""
653
+ You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
654
+ Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
655
+ Guidelines for generating right elastic seach query:
656
+ 1. Automatically determine whether to return document hits or aggregation results based on the query structure.
657
+ 2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
658
+ 3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
659
+ 4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
660
+
661
+ Instruction for pcs_v3 Field-
662
+ If {pcs_codes['subject']} not empty:
663
+ Only include all of the following match terms. No other pcs_v3 fields should be added, duplicated, or altered except for those listed below.
664
+ - {pcs_match_term}
665
+ """
666
+ agent_exec = self.agent_factory_claude(
667
+ pcs_codes=state["pcs_codes"], prefix=prefix
668
+ )
669
  res = agent_exec.invoke(input_data)
670
  state["agent_out"] = res["output"]
 
671
  es_queries, es_results = {}, {}
672
  for i, action in enumerate(res.get("intermediate_steps", []), start=1):
673
  if action[0].tool == "elastic_index_search_tool":
674
+ print("query", action[0].tool_input.get("query"))
675
  es_queries[f"query_{i}"] = json.loads(
676
  action[0].tool_input.get("query") or "{}"
677
  )
 
697
  """
698
 
699
  print("> Org Index Agent")
700
+ mapping_description = """
701
+ "admin1_code": "state abbreviation"
702
+ "admin1_description": "Full name/label of the state"
703
+ "city": Full Name of the city with 1st letter being capital for e.g. New York City
704
+ "assets": "The assets value of the most recent fiscals available for the organization."
705
+ "country_code": "Country abbreviation"
706
+ "country_name": "Country name"
707
+ "fiscal_year": "The year of the most recent fiscals available for the organization. (YYYY format)"
708
+ "mission_statement": "The mission statement of the organization."
709
+ "roles": "grantmaker: Indicates the organization gives grants., recipient: Indicates the organization receives grants., company: Indicates the organization is a company/corporation."
710
+
711
+ """
712
  input_data = {
713
  "input": f"""
714
  You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
 
716
  1. Understand the user query to determine the required information.
717
  2. Query the indices in the Elasticsearch database.
718
  3. Retrieve the mappings and field names relevant to the query.
719
+ 4. Use the `organization_qa_ds1` index to extract the necessary data.
720
  5. Present the response in a clear and natural language format, addressing the user's question directly.
721
+
722
+
723
+ Given Below is mapping description of some of the fields
724
+ ```{mapping_description}```
725
+
726
 
727
+ User's query:
728
  ```{state["query"]}```
729
  """
730
  }
731
+
732
+ pcs_codes = state["pcs_codes"]
733
+ pcs_match_term = ""
734
+ for pcs_code in pcs_codes["subject"]:
735
+ pcs_match_term += f'"taxonomy_descriptions": "{pcs_code}" \n"'
736
+
737
+ print("pcs_match_term", pcs_match_term)
738
+ prefix = f"""You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
739
+ Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
740
+ Guidelines for generating right elastic seach query:
741
+ 1. Automatically determine whether to return document hits or aggregation results based on the query structure.
742
+ 2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
743
+ 3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
744
+ 4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
745
+
746
+ Instructions to use `taxonomy_descriptions` field:
747
+ If {pcs_codes['subject']} not empty, only add the following match term:
748
+ Only add the following `match` term, No other `taxonomy_descriptions` fields should be added, duplicated, or modified except belowIf {pcs_codes['subject']} not empty,
749
+ - {pcs_match_term}
750
+
751
+
752
+ Avoid using `ntee_major_description` field in the es query
753
+
754
+ """
755
+ agent_exec = self.agent_factory_claude(
756
+ pcs_codes=state["pcs_codes"], prefix=prefix
757
+ )
758
  res = agent_exec.invoke(input_data)
759
  state["agent_out"] = res["output"]
760
 
 
812
  """
813
 
814
  # Add nodes
815
+ self.add_node("Context_Extraction", self.Extract_PCS_Codes)
816
  self.add_node("analyse", self.analyse_query)
817
  self.add_node("grant-index", self.grant_index_agent)
818
  self.add_node("org-index", self.org_index_agent)
819
  self.add_node("final_answer", self.final_answer)
820
 
821
  # Set entry point
822
+ self.set_entry_point("Context_Extraction")
823
+ self.add_edge("Context_Extraction", "analyse")
824
 
825
  # Add conditional edges
826
  self.add_conditional_edges(
ask_candid/retrieval/elastic.py CHANGED
@@ -1,5 +1,4 @@
1
  from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
2
- from dataclasses import dataclass
3
  from itertools import groupby
4
 
5
  from torch.nn import functional as F
@@ -10,12 +9,14 @@ from langchain_core.documents import Document
10
  from elasticsearch import Elasticsearch
11
 
12
  from ask_candid.retrieval.sparse_lexical import SpladeEncoder
13
- from ask_candid.retrieval.sources.issuelab import IssueLabConfig
14
- from ask_candid.retrieval.sources.youtube import YoutubeConfig
15
- from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig
16
- from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig
17
- from ask_candid.retrieval.sources.candid_help import CandidHelpConfig
18
- from ask_candid.retrieval.sources.candid_news import CandidNewsConfig
 
 
19
  from ask_candid.services.small_lm import CandidSLM
20
  from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
21
  from ask_candid.base.config.data import DataIndices, ALL_INDICES
@@ -23,17 +24,6 @@ from ask_candid.base.config.data import DataIndices, ALL_INDICES
23
  encoder = SpladeEncoder()
24
 
25
 
26
- @dataclass
27
- class ElasticHitsResult:
28
- """Dataclass for Elasticsearch hits results
29
- """
30
- index: str
31
- id: Any
32
- score: float
33
- source: Dict[str, Any]
34
- inner_hits: Dict[str, Any]
35
-
36
-
37
  class RetrieverInput(BaseModel):
38
  """Input to the Elasticsearch retriever."""
39
  user_input: str = Field(description="query to look up in retriever")
@@ -101,7 +91,7 @@ def news_query_builder(query: str) -> Dict[str, Any]:
101
  tokens = encoder.token_expand(query)
102
 
103
  query = {
104
- "_source": ["id", "link", "title", "content"],
105
  "query": {
106
  "bool": {
107
  "filter": [
@@ -150,27 +140,27 @@ def query_builder(query: str, indices: List[DataIndices]) -> Tuple[List[Dict[str
150
  if index == "issuelab":
151
  q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
152
  q["_source"] = {"excludes": ["embeddings"]}
153
- q["size"] = 1
154
  queries.extend([{"index": IssueLabConfig.index_name}, q])
155
  elif index == "youtube":
156
  q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
157
  q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
158
- q["size"] = 2
159
  queries.extend([{"index": YoutubeConfig.index_name}, q])
160
  elif index == "candid_blog":
161
  q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
162
  q["_source"] = {"excludes": ["embeddings"]}
163
- q["size"] = 2
164
  queries.extend([{"index": CandidBlogConfig.index_name}, q])
165
  elif index == "candid_learning":
166
  q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
167
  q["_source"] = {"excludes": ["embeddings"]}
168
- q["size"] = 2
169
  queries.extend([{"index": CandidLearningConfig.index_name}, q])
170
  elif index == "candid_help":
171
  q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
172
  q["_source"] = {"excludes": ["embeddings"]}
173
- q["size"] = 2
174
  queries.extend([{"index": CandidHelpConfig.index_name}, q])
175
  elif index == "news":
176
  q = news_query_builder(query=query)
@@ -199,12 +189,18 @@ def multi_search(
199
  def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
200
  for query_group in responses:
201
  for h in query_group.get("hits", {}).get("hits", []):
 
 
 
 
 
 
202
  yield ElasticHitsResult(
203
  index=h["_index"],
204
  id=h["_id"],
205
  score=h["_score"],
206
  source=h["_source"],
207
- inner_hits=h.get("inner_hits", {})
208
  )
209
 
210
  results = []
@@ -264,6 +260,10 @@ def retrieved_text(hits: Dict[str, Any]) -> str:
264
 
265
  text = []
266
  for _, v in hits.items():
 
 
 
 
267
  for h in (v.get("hits", {}).get("hits") or []):
268
  for _, field in h.get("fields", {}).items():
269
  for chunk in field:
@@ -298,7 +298,8 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
298
 
299
  def reranker(
300
  query_results: Iterable[ElasticHitsResult],
301
- search_text: Optional[str] = None
 
302
  ) -> Iterator[ElasticHitsResult]:
303
  """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
304
  This will shuffle results
@@ -327,58 +328,13 @@ def reranker(
327
  text = retrieved_text(d.inner_hits)
328
  texts.append(text)
329
 
330
- # if search_text and len(texts) == len(results):
331
- # scores = cosine_rescore(search_text, texts)
332
- # for r, s in zip(results, scores):
333
- # r.score = s
334
-
335
- yield from sorted(results, key=lambda x: x.score, reverse=True)
336
 
337
-
338
- def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
339
- """Pads the relevant chunk of text with context before and after
340
-
341
- Parameters
342
- ----------
343
- field_name : str
344
- a field with the long text that was chunked into pieces
345
- hit : ElasticHitsResult
346
- context_length : int, optional
347
- length of text to add before and after the chunk, by default 1024
348
-
349
- Returns
350
- -------
351
- str
352
- longer chunks stuffed together
353
- """
354
-
355
- chunks = []
356
- # NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
357
- long_text = hit.source.get(f"{field_name}", "")
358
- long_text = long_text.lower()
359
- inner_hits_field = f"embeddings.{field_name}.chunks"
360
- found_chunks = hit.inner_hits.get(inner_hits_field, {})
361
- if found_chunks:
362
- hits = found_chunks.get("hits", {}).get("hits", [])
363
- for h in hits:
364
- chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
365
-
366
- # cutting the middle because we may have tokenizing artifacts there
367
- chunk = chunk[3: -3]
368
-
369
- if add_context:
370
- # Find the start and end indices of the chunk in the large text
371
- start_index = long_text.find(chunk[:20])
372
-
373
- # Chunk is found
374
- if start_index != -1:
375
- end_index = start_index + len(chunk)
376
- pre_start_index = max(0, start_index - context_length)
377
- post_end_index = min(len(long_text), end_index + context_length)
378
- chunks.append(long_text[pre_start_index:post_end_index])
379
- else:
380
- chunks.append(chunk)
381
- return '\n\n'.join(chunks)
382
 
383
 
384
  def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
@@ -394,94 +350,17 @@ def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
394
  """
395
 
396
  if "issuelab-elser" in hit.index:
397
- combined_item_description = hit.source.get("combined_item_description", "") # title inside
398
- description = hit.source.get("description", "")
399
- combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
400
- # we only need to process long texts
401
- chunks_with_context_txt = get_context("content", hit, context_length=12)
402
- doc = Document(
403
- page_content='\n\n'.join([
404
- combined_item_description,
405
- combined_issuelab_findings,
406
- description,
407
- chunks_with_context_txt
408
- ]),
409
- metadata={
410
- "title": hit.source["title"],
411
- "source": "IssueLab",
412
- "source_id": hit.source["resource_id"],
413
- "url": hit.source.get("permalink", "")
414
- }
415
- )
416
  elif "youtube" in hit.index:
417
- title = hit.source.get("title", "")
418
- # we only need to process long texts
419
- description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
420
- captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
421
- doc = Document(
422
- page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
423
- metadata={
424
- "title": title,
425
- "source": "Candid YouTube",
426
- "source_id": hit.source['video_id'],
427
- "url": f"https://www.youtube.com/watch?v&#61;{hit.source['video_id']}"
428
- }
429
- )
430
  elif "candid-blog" in hit.index:
431
- excerpt = hit.source.get("excerpt", "")
432
- title = hit.source.get("title", "")
433
- # we only need to process long text
434
- content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
435
- authors = get_context("authors_text", hit, context_length=12, add_context=False)
436
- tags = hit.source.get("title_summary_tags", "")
437
- doc = Document(
438
- page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
439
- metadata={
440
- "title": title,
441
- "source": "Candid Blog",
442
- "source_id": hit.source["id"],
443
- "url": hit.source["link"]
444
- }
445
- )
446
  elif "candid-learning" in hit.index:
447
- title = hit.source.get("title", "")
448
- content_with_context_txt = get_context("content", hit, context_length=12)
449
- training_topics = hit.source.get("training_topics", "")
450
- staff_recommendations = hit.source.get("staff_recommendations", "")
451
-
452
- doc = Document(
453
- page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
454
- metadata={
455
- "title": hit.source["title"],
456
- "source": "Candid Learning",
457
- "source_id": hit.source["post_id"],
458
- "url": hit.source.get("url", "")
459
- }
460
- )
461
  elif "candid-help" in hit.index:
462
- title = hit.source.get("title", "")
463
- content_with_context_txt = get_context("content", hit, context_length=12)
464
- combined_article_description = hit.source.get("combined_article_description", "")
465
-
466
- doc = Document(
467
- page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
468
- metadata={
469
- "title": title,
470
- "source": "Candid Help",
471
- "source_id": hit.source["id"],
472
- "url": hit.source.get("link", "")
473
- }
474
- )
475
  elif "news" in hit.index:
476
- doc = Document(
477
- page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
478
- metadata={
479
- "title": hit.source.get("title", ""),
480
- "source": "Candid News",
481
- "source_id": hit.source["id"],
482
- "url": hit.source.get("link", "")
483
- }
484
- )
485
  else:
486
  doc = None
487
  return doc
 
1
  from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
 
2
  from itertools import groupby
3
 
4
  from torch.nn import functional as F
 
9
  from elasticsearch import Elasticsearch
10
 
11
  from ask_candid.retrieval.sparse_lexical import SpladeEncoder
12
+ from ask_candid.retrieval.sources.schema import ElasticHitsResult
13
+ from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
14
+ from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
15
+ from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
16
+ from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
17
+ from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
18
+ from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
19
+
20
  from ask_candid.services.small_lm import CandidSLM
21
  from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
22
  from ask_candid.base.config.data import DataIndices, ALL_INDICES
 
24
  encoder = SpladeEncoder()
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  class RetrieverInput(BaseModel):
28
  """Input to the Elasticsearch retriever."""
29
  user_input: str = Field(description="query to look up in retriever")
 
91
  tokens = encoder.token_expand(query)
92
 
93
  query = {
94
+ "_source": ["id", "link", "title", "content", "site_name"],
95
  "query": {
96
  "bool": {
97
  "filter": [
 
140
  if index == "issuelab":
141
  q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
142
  q["_source"] = {"excludes": ["embeddings"]}
143
+ q["size"] = 2
144
  queries.extend([{"index": IssueLabConfig.index_name}, q])
145
  elif index == "youtube":
146
  q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
147
  q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
148
+ q["size"] = 5
149
  queries.extend([{"index": YoutubeConfig.index_name}, q])
150
  elif index == "candid_blog":
151
  q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
152
  q["_source"] = {"excludes": ["embeddings"]}
153
+ q["size"] = 5
154
  queries.extend([{"index": CandidBlogConfig.index_name}, q])
155
  elif index == "candid_learning":
156
  q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
157
  q["_source"] = {"excludes": ["embeddings"]}
158
+ q["size"] = 5
159
  queries.extend([{"index": CandidLearningConfig.index_name}, q])
160
  elif index == "candid_help":
161
  q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
162
  q["_source"] = {"excludes": ["embeddings"]}
163
+ q["size"] = 5
164
  queries.extend([{"index": CandidHelpConfig.index_name}, q])
165
  elif index == "news":
166
  q = news_query_builder(query=query)
 
189
  def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
190
  for query_group in responses:
191
  for h in query_group.get("hits", {}).get("hits", []):
192
+ inner_hits = h.get("inner_hits", {})
193
+
194
+ if not inner_hits:
195
+ if "news" in h.get("_index"):
196
+ inner_hits = {"text": h.get("_source", {}).get("content")}
197
+
198
  yield ElasticHitsResult(
199
  index=h["_index"],
200
  id=h["_id"],
201
  score=h["_score"],
202
  source=h["_source"],
203
+ inner_hits=inner_hits
204
  )
205
 
206
  results = []
 
260
 
261
  text = []
262
  for _, v in hits.items():
263
+ if _ == "text":
264
+ text.append(v)
265
+ continue
266
+
267
  for h in (v.get("hits", {}).get("hits") or []):
268
  for _, field in h.get("fields", {}).items():
269
  for chunk in field:
 
298
 
299
  def reranker(
300
  query_results: Iterable[ElasticHitsResult],
301
+ search_text: Optional[str] = None,
302
+ max_num_results: int = 10
303
  ) -> Iterator[ElasticHitsResult]:
304
  """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
305
  This will shuffle results
 
328
  text = retrieved_text(d.inner_hits)
329
  texts.append(text)
330
 
331
+ if search_text and len(texts) == len(results):
332
+ # scores = cosine_rescore(search_text, texts)
333
+ scores = encoder.query_reranking(query=search_text, documents=texts)
334
+ for r, s in zip(results, scores):
335
+ r.score = s
 
336
 
337
+ yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
 
340
  def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
 
350
  """
351
 
352
  if "issuelab-elser" in hit.index:
353
+ doc = process_issuelab_hit(hit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  elif "youtube" in hit.index:
355
+ doc = process_youtube_hit(hit)
 
 
 
 
 
 
 
 
 
 
 
 
356
  elif "candid-blog" in hit.index:
357
+ doc = process_blog_hit(hit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  elif "candid-learning" in hit.index:
359
+ doc = process_learning_hit(hit)
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  elif "candid-help" in hit.index:
361
+ doc = process_help_hit(hit)
 
 
 
 
 
 
 
 
 
 
 
 
362
  elif "news" in hit.index:
363
+ doc = process_news_hit(hit)
 
 
 
 
 
 
 
 
364
  else:
365
  doc = None
366
  return doc
ask_candid/retrieval/sources/candid_blog.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, Any
2
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
3
 
 
 
 
 
4
 
5
  CandidBlogConfig = ElasticSourceConfig(
6
  index_name="search-semantic-candid-blog",
@@ -8,6 +11,24 @@ CandidBlogConfig = ElasticSourceConfig(
8
  )
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
12
  url = f"{doc['link']}"
13
  fields = ["title", "excerpt"]
 
1
  from typing import Dict, Any
 
2
 
3
+ from langchain_core.documents import Document
4
+
5
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
6
+ from ask_candid.retrieval.sources.utils import get_context
7
 
8
  CandidBlogConfig = ElasticSourceConfig(
9
  index_name="search-semantic-candid-blog",
 
11
  )
12
 
13
 
14
+ def process_blog_hit(hit: ElasticHitsResult) -> Document:
15
+ excerpt = hit.source.get("excerpt", "")
16
+ title = hit.source.get("title", "")
17
+ # we only need to process long text
18
+ content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
19
+ authors = get_context("authors_text", hit, context_length=12, add_context=False)
20
+ tags = hit.source.get("title_summary_tags", "")
21
+ return Document(
22
+ page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
23
+ metadata={
24
+ "title": title,
25
+ "source": "Candid Blog",
26
+ "source_id": hit.source["id"],
27
+ "url": hit.source["link"]
28
+ }
29
+ )
30
+
31
+
32
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
33
  url = f"{doc['link']}"
34
  fields = ["title", "excerpt"]
ask_candid/retrieval/sources/candid_help.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, Any
2
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
3
 
 
 
 
 
4
 
5
  CandidHelpConfig = ElasticSourceConfig(
6
  index_name="search-semantic-candid-help-elser_ve1",
@@ -8,6 +11,22 @@ CandidHelpConfig = ElasticSourceConfig(
8
  )
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
12
  url = f"{doc['link']}"
13
  fields = ["title", "summary"]
 
1
  from typing import Dict, Any
 
2
 
3
+ from langchain_core.documents import Document
4
+
5
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
6
+ from ask_candid.retrieval.sources.utils import get_context
7
 
8
  CandidHelpConfig = ElasticSourceConfig(
9
  index_name="search-semantic-candid-help-elser_ve1",
 
11
  )
12
 
13
 
14
+ def process_help_hit(hit: ElasticHitsResult) -> Document:
15
+ title = hit.source.get("title", "")
16
+ content_with_context_txt = get_context("content", hit, context_length=12)
17
+ combined_article_description = hit.source.get("combined_article_description", "")
18
+
19
+ return Document(
20
+ page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
21
+ metadata={
22
+ "title": title,
23
+ "source": "Candid Help",
24
+ "source_id": hit.source["id"],
25
+ "url": hit.source.get("link", "")
26
+ }
27
+ )
28
+
29
+
30
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
31
  url = f"{doc['link']}"
32
  fields = ["title", "summary"]
ask_candid/retrieval/sources/candid_learning.py CHANGED
@@ -1,5 +1,9 @@
1
  from typing import Dict, Any
2
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
 
 
 
 
3
 
4
 
5
  CandidLearningConfig = ElasticSourceConfig(
@@ -8,6 +12,23 @@ CandidLearningConfig = ElasticSourceConfig(
8
  )
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
12
  url = f"{doc['url']}"
13
  fields = ["title", "excerpt"]
 
1
  from typing import Dict, Any
2
+
3
+ from langchain_core.documents import Document
4
+
5
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
6
+ from ask_candid.retrieval.sources.utils import get_context
7
 
8
 
9
  CandidLearningConfig = ElasticSourceConfig(
 
12
  )
13
 
14
 
15
+ def process_learning_hit(hit: ElasticHitsResult) -> Document:
16
+ title = hit.source.get("title", "")
17
+ content_with_context_txt = get_context("content", hit, context_length=12)
18
+ training_topics = hit.source.get("training_topics", "")
19
+ staff_recommendations = hit.source.get("staff_recommendations", "")
20
+
21
+ return Document(
22
+ page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
23
+ metadata={
24
+ "title": hit.source["title"],
25
+ "source": "Candid Learning",
26
+ "source_id": hit.source["post_id"],
27
+ "url": hit.source.get("url", "")
28
+ }
29
+ )
30
+
31
+
32
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
33
  url = f"{doc['url']}"
34
  fields = ["title", "excerpt"]
ask_candid/retrieval/sources/candid_news.py CHANGED
@@ -1,7 +1,20 @@
1
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
2
 
 
3
 
4
  CandidNewsConfig = ElasticSourceConfig(
5
  index_name="news_1",
6
  text_fields=("title", "content")
7
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.documents import Document
2
 
3
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
4
 
5
  CandidNewsConfig = ElasticSourceConfig(
6
  index_name="news_1",
7
  text_fields=("title", "content")
8
  )
9
+
10
+
11
+ def process_news_hit(hit: ElasticHitsResult) -> Document:
12
+ return Document(
13
+ page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
14
+ metadata={
15
+ "title": hit.source.get("title", ""),
16
+ "source": hit.source.get("site_name") or "Candid News",
17
+ "source_id": hit.source["id"],
18
+ "url": hit.source.get("link", "")
19
+ }
20
+ )
ask_candid/retrieval/sources/issuelab.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, Any
2
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
3
 
 
 
 
 
4
 
5
  IssueLabConfig = ElasticSourceConfig(
6
  index_name="search-semantic-issuelab-elser_ve2",
@@ -8,11 +11,33 @@ IssueLabConfig = ElasticSourceConfig(
8
  )
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
12
  chunks_html = ""
13
  if show_chunks:
14
  cleaned_text = []
15
- for k, v in doc["inner_hits"].items():
16
  hits = v["hits"]["hits"]
17
  for h in hits:
18
  for k1, v1 in h["fields"].items():
 
1
  from typing import Dict, Any
 
2
 
3
+ from langchain_core.documents import Document
4
+
5
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
6
+ from ask_candid.retrieval.sources.utils import get_context
7
 
8
  IssueLabConfig = ElasticSourceConfig(
9
  index_name="search-semantic-issuelab-elser_ve2",
 
11
  )
12
 
13
 
14
+ def process_issuelab_hit(hit: ElasticHitsResult) -> Document:
15
+ combined_item_description = hit.source.get("combined_item_description", "") # title inside
16
+ description = hit.source.get("description", "")
17
+ combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
18
+ # we only need to process long texts
19
+ chunks_with_context_txt = get_context("content", hit, context_length=12)
20
+ return Document(
21
+ page_content='\n\n'.join([
22
+ combined_item_description,
23
+ combined_issuelab_findings,
24
+ description,
25
+ chunks_with_context_txt
26
+ ]),
27
+ metadata={
28
+ "title": hit.source["title"],
29
+ "source": "IssueLab",
30
+ "source_id": hit.source["resource_id"],
31
+ "url": hit.source.get("permalink", "")
32
+ }
33
+ )
34
+
35
+
36
  def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
37
  chunks_html = ""
38
  if show_chunks:
39
  cleaned_text = []
40
+ for _, v in doc["inner_hits"].items():
41
  hits = v["hits"]["hits"]
42
  for h in hits:
43
  for k1, v1 in h["fields"].items():
ask_candid/retrieval/sources/schema.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Tuple, Optional
2
  from dataclasses import dataclass, field
3
 
4
 
@@ -7,3 +7,14 @@ class ElasticSourceConfig:
7
  index_name: str
8
  text_fields: Tuple[str]
9
  excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Dict, Optional, Any
2
  from dataclasses import dataclass, field
3
 
4
 
 
7
  index_name: str
8
  text_fields: Tuple[str]
9
  excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
10
+
11
+
12
+ @dataclass
13
+ class ElasticHitsResult:
14
+ """Dataclass for Elasticsearch hits results
15
+ """
16
+ index: str
17
+ id: Any
18
+ score: float
19
+ source: Dict[str, Any]
20
+ inner_hits: Dict[str, Any]
ask_candid/retrieval/sources/utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ask_candid.retrieval.sources.schema import ElasticHitsResult
2
+
3
+
4
+ def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
5
+ """Pads the relevant chunk of text with context before and after
6
+
7
+ Parameters
8
+ ----------
9
+ field_name : str
10
+ a field with the long text that was chunked into pieces
11
+ hit : ElasticHitsResult
12
+ context_length : int, optional
13
+ length of text to add before and after the chunk, by default 1024
14
+
15
+ Returns
16
+ -------
17
+ str
18
+ longer chunks stuffed together
19
+ """
20
+
21
+ chunks = []
22
+ # NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
23
+ long_text = hit.source.get(f"{field_name}", "")
24
+ long_text = long_text.lower()
25
+ inner_hits_field = f"embeddings.{field_name}.chunks"
26
+ found_chunks = hit.inner_hits.get(inner_hits_field, {})
27
+ if found_chunks:
28
+ hits = found_chunks.get("hits", {}).get("hits", [])
29
+ for h in hits:
30
+ chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
31
+
32
+ # cutting the middle because we may have tokenizing artifacts there
33
+ chunk = chunk[3: -3]
34
+
35
+ if add_context:
36
+ # Find the start and end indices of the chunk in the large text
37
+ start_index = long_text.find(chunk[:20])
38
+
39
+ # Chunk is found
40
+ if start_index != -1:
41
+ end_index = start_index + len(chunk)
42
+ pre_start_index = max(0, start_index - context_length)
43
+ post_end_index = min(len(long_text), end_index + context_length)
44
+ chunks.append(long_text[pre_start_index:post_end_index])
45
+ else:
46
+ chunks.append(chunk)
47
+ return '\n\n'.join(chunks)
ask_candid/retrieval/sources/youtube.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, Any
2
- from ask_candid.retrieval.sources.schema import ElasticSourceConfig
3
 
 
 
 
 
4
 
5
  YoutubeConfig = ElasticSourceConfig(
6
  index_name="search-semantic-youtube-elser_ve1",
@@ -9,6 +12,22 @@ YoutubeConfig = ElasticSourceConfig(
9
  )
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
13
  url = f"https://www.youtube.com/watch?v={doc['video_id']}"
14
  fields = ["title", "description_cleaned"]
 
1
  from typing import Dict, Any
 
2
 
3
+ from langchain_core.documents import Document
4
+
5
+ from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
6
+ from ask_candid.retrieval.sources.utils import get_context
7
 
8
  YoutubeConfig = ElasticSourceConfig(
9
  index_name="search-semantic-youtube-elser_ve1",
 
12
  )
13
 
14
 
15
+ def process_youtube_hit(hit: ElasticHitsResult) -> Document:
16
+ title = hit.source.get("title", "")
17
+ # we only need to process long texts
18
+ description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
19
+ captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
20
+ return Document(
21
+ page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
22
+ metadata={
23
+ "title": title,
24
+ "source": "Candid YouTube",
25
+ "source_id": hit.source['video_id'],
26
+ "url": f"https://www.youtube.com/watch?v&#61;{hit.source['video_id']}"
27
+ }
28
+ )
29
+
30
+
31
  def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
32
  url = f"https://www.youtube.com/watch?v={doc['video_id']}"
33
  fields = ["title", "description_cleaned"]
ask_candid/retrieval/sparse_lexical.py CHANGED
@@ -1,6 +1,7 @@
1
- from typing import Dict
2
 
3
  from transformers import AutoModelForMaskedLM, AutoTokenizer
 
4
  import torch
5
 
6
 
@@ -14,14 +15,23 @@ class SpladeEncoder:
14
  self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
15
 
16
  @torch.no_grad()
17
- def token_expand(self, query: str) -> Dict[str, float]:
18
- tokens = self.tokenizer(query, return_tensors='pt')
19
  output = self.model(**tokens)
20
-
21
  vec = torch.max(
22
  torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
23
  dim=1
24
  )[0].squeeze()
 
 
 
 
 
 
 
 
 
 
25
  cols = vec.nonzero().squeeze().cpu().tolist()
26
  weights = vec[cols].cpu().tolist()
27
 
 
1
+ from typing import List, Dict
2
 
3
  from transformers import AutoModelForMaskedLM, AutoTokenizer
4
+ from torch.nn import functional as F
5
  import torch
6
 
7
 
 
15
  self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
16
 
17
  @torch.no_grad()
18
+ def forward(self, texts: List[str]):
19
+ tokens = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
20
  output = self.model(**tokens)
 
21
  vec = torch.max(
22
  torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
23
  dim=1
24
  )[0].squeeze()
25
+ return vec
26
+
27
+ def query_reranking(self, query: str, documents: List[str]):
28
+ vec = self.forward([query, *documents])
29
+ xQ = F.normalize(vec[:1], dim=-1, p=2.)
30
+ xD = F.normalize(vec[1:], dim=-1, p=2.)
31
+ return (xQ * xD).sum(dim=-1).cpu().tolist()
32
+
33
+ def token_expand(self, query: str) -> Dict[str, float]:
34
+ vec = self.forward([query])
35
  cols = vec.nonzero().squeeze().cpu().tolist()
36
  weights = vec[cols].cpu().tolist()
37
 
ask_candid/tools/elastic/index_search_tool.py CHANGED
@@ -40,6 +40,7 @@ class SearchToolInput(BaseModel):
40
 
41
 
42
  def elastic_search(
 
43
  index_name: str,
44
  query: str,
45
  from_: int = 0,
@@ -107,9 +108,15 @@ def elastic_search(
107
  return msg
108
 
109
 
110
- def create_search_tool():
111
  return StructuredTool.from_function(
112
- elastic_search,
 
 
 
 
 
 
113
  name="elastic_index_search_tool",
114
  description=(
115
  """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
 
40
 
41
 
42
  def elastic_search(
43
+ pcs_codes: dict,
44
  index_name: str,
45
  query: str,
46
  from_: int = 0,
 
108
  return msg
109
 
110
 
111
+ def create_search_tool(pcs_codes):
112
  return StructuredTool.from_function(
113
+ func=lambda index_name, query, from_, size: elastic_search(
114
+ pcs_codes=pcs_codes,
115
+ index_name=index_name,
116
+ query=query,
117
+ from_=from_,
118
+ size=size,
119
+ ),
120
  name="elastic_index_search_tool",
121
  description=(
122
  """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
ask_candid/tools/question_reformulation.py CHANGED
@@ -1,55 +1,55 @@
1
  from langchain_core.prompts import ChatPromptTemplate
2
  from langchain_core.output_parsers import StrOutputParser
 
3
 
 
4
 
5
- def reformulate_question_using_history(state, llm, focus_on_recommendations=False):
6
- """
7
- Transform the query to produce a better query with details from previous messages and emphasize
8
- aspects important for recommendations if needed.
9
 
10
- Args:
11
- state (dict): The current state containing messages.
12
- llm: LLM to use for generating the reformulation.
13
- focus_on_recommendations (bool): Flag to determine if the reformulation should emphasize
14
- recommendation-relevant aspects such as geographies,
15
- cause areas, etc.
16
- Returns:
17
- dict: The updated state with re-phrased question and original user_input for UI
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  """
 
19
  print("---REFORMULATE THE USER INPUT---")
20
  messages = state["messages"]
21
  question = messages[-1].content
22
 
23
- if len(messages) > 1:
24
  if focus_on_recommendations:
25
- prompt_text = """Given a chat history and the latest user input \
26
- which might reference context in the chat history, \
27
- especially geographic locations, cause areas and/or population groups, \
28
- formulate a standalone input which can be understood without the chat history.
29
- Chat history:
30
- \n ------- \n
31
- {chat_history}
32
- \n ------- \n
33
- User input:
34
- \n ------- \n
35
- {question}
36
- \n ------- \n
37
  Reformulate the question without adding implications or assumptions about the user's needs or intentions.
38
  Focus solely on clarifying any contextual details present in the original input."""
39
  else:
40
- prompt_text = """Given a chat history and the latest user input \
41
- which might reference context in the chat history, formulate a standalone input \
42
- which can be understood without the chat history.
43
- Chat history:
44
- \n ------- \n
45
- {chat_history}
46
- \n ------- \n
47
- User input:
48
- \n ------- \n
49
- {question}
50
- \n ------- \n
51
- Do NOT answer the question, \
52
- just reformulate it if needed and otherwise return it as is.
53
  """
54
 
55
  contextualize_q_prompt = ChatPromptTemplate([
@@ -58,7 +58,11 @@ def reformulate_question_using_history(state, llm, focus_on_recommendations=Fals
58
  ])
59
 
60
  rag_chain = contextualize_q_prompt | llm | StrOutputParser()
61
- new_question = rag_chain.invoke({"chat_history": messages, "question": question})
 
 
 
 
62
  print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
63
  return {"messages": [new_question], "user_input" : question}
64
  return {"messages": [question], "user_input" : question}
 
1
  from langchain_core.prompts import ChatPromptTemplate
2
  from langchain_core.output_parsers import StrOutputParser
3
+ from langchain_core.language_models.llms import LLM
4
 
5
+ from ask_candid.agents.schema import AgentState
6
 
 
 
 
 
7
 
8
+ def reformulate_question_using_history(
9
+ state: AgentState,
10
+ llm: LLM,
11
+ focus_on_recommendations: bool = False
12
+ ) -> AgentState:
13
+ """Transform the query to produce a better query with details from previous messages and emphasize aspects important
14
+ for recommendations if needed.
15
+
16
+ Parameters
17
+ ----------
18
+ state : AgentState
19
+ The current state
20
+ llm : LLM
21
+ focus_on_recommendations : bool, optional
22
+ Flag to determine if the reformulation should emphasize recommendation-relevant aspects such as geographies,
23
+ cause areas, etc., by default False
24
+
25
+ Returns
26
+ -------
27
+ AgentState
28
+ The updated state
29
  """
30
+
31
  print("---REFORMULATE THE USER INPUT---")
32
  messages = state["messages"]
33
  question = messages[-1].content
34
 
35
+ if len(messages[:-1]) > 1: # need to skip the system message
36
  if focus_on_recommendations:
37
+ prompt_text = """Given a chat history and the latest user input which might reference context in the chat
38
+ history, especially geographic locations, cause areas and/or population groups, formulate a standalone input
39
+ which can be understood without the chat history.
40
+ Chat history: ```{chat_history}```
41
+ User input: ```{question}```
42
+
 
 
 
 
 
 
43
  Reformulate the question without adding implications or assumptions about the user's needs or intentions.
44
  Focus solely on clarifying any contextual details present in the original input."""
45
  else:
46
+ prompt_text = """Given a chat history and the latest user input which might reference context in the chat
47
+ history, formulate a standalone input which can be understood without the chat history. Include hints as to
48
+ what the user is getting at given the context in the chat history.
49
+ Chat history: ```{chat_history}```
50
+ User input: ```{question}```
51
+
52
+ Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
 
 
 
 
 
 
53
  """
54
 
55
  contextualize_q_prompt = ChatPromptTemplate([
 
58
  ])
59
 
60
  rag_chain = contextualize_q_prompt | llm | StrOutputParser()
61
+ # new_question = rag_chain.invoke({"chat_history": messages, "question": question})
62
+ new_question = rag_chain.invoke({
63
+ "chat_history": '\n'.join(f"{m.type.upper()}: {m.content}" for m in messages[1:]),
64
+ "question": question
65
+ })
66
  print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
67
  return {"messages": [new_question], "user_input" : question}
68
  return {"messages": [question], "user_input" : question}