from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from dataclasses import dataclass
from functools import partial
from itertools import groupby

from torch.nn import functional as F
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from langchain_core.tools import Tool
from elasticsearch import Elasticsearch

from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES


@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hit results."""
index: str
id: Any
score: float
source: Dict[str, Any]
inner_hits: Dict[str, Any]


class RetrieverInput(BaseModel):
    """Input schema for the Elasticsearch retriever."""
    user_input: str = Field(description="query to look up in retriever")


def build_text_expansion_query(
    query: str,
    fields: Tuple[str, ...],
    model_id: str = ".elser_model_2_linux-x86_64"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str, ...]
        Semantic text field names
    model_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser_model_2_linux-x86_64"

    Returns
    -------
    Dict[str, Any]
        Boolean `should` query with one nested `text_expansion` clause per field
    """
output = []
for f in fields:
output.append({
"nested": {
"path": f"embeddings.{f}.chunks",
"query": {
"text_expansion": {
f"embeddings.{f}.chunks.vector": {
"model_id": model_id,
"model_text": query,
"boost": 1 / len(fields)
}
}
},
"inner_hits": {
"_source": False,
"size": 2,
"fields": [f"embeddings.{f}.chunks.chunk"]
}
}
})
return {"query": {"bool": {"should": output}}}


def query_builder(query: str, indices: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Builds an Elasticsearch multi-search query payload.

    Parameters
    ----------
    query : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None (search all indices)

    Returns
    -------
    List[Dict[str, Any]]
        Alternating header/body entries in the flat format expected by `msearch`
    """
queries = []
if indices is None:
indices = list(ALL_INDICES)
for index in indices:
if index == "issuelab":
q = build_text_expansion_query(
query=query,
fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 1
queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
elif index == "youtube":
q = build_text_expansion_query(
query=query,
fields=("captions_cleaned", "description_cleaned", "title")
)
# text_cleaned duplicates captions_cleaned
q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
q["size"] = 2
queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
elif index == "candid_blog":
q = build_text_expansion_query(
query=query,
fields=("content", "title")
)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 2
queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
elif index == "candid_learning":
q = build_text_expansion_query(
query=query,
fields=("content", "title", "training_topics", "staff_recommendations")
)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 2
queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
elif index == "candid_help":
q = build_text_expansion_query(
query=query,
fields=("content", "combined_article_description")
)
q["_source"] = {"excludes": ["embeddings"]}
q["size"] = 2
queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])
return queries
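
# Note: `Elasticsearch.msearch` expects a flat list alternating header and body entries,
# which is exactly what the `queries.extend(...)` calls above build, e.g. (index names
# illustrative; the real ones come from ElasticIndexMapping):
# [{"index": "issuelab-elser"}, {"query": ..., "_source": ..., "size": 1},
#  {"index": "youtube-elser"}, {"query": ..., "_source": ..., "size": 2}, ...]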


def multi_search(queries: List[Dict[str, Any]]) -> List[ElasticHitsResult]:
    """Runs a multi-search query against the Candid Elasticsearch cluster.

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload

    Returns
    -------
    List[ElasticHitsResult]
        Flattened hits from all response groups
    """
results = []
with Elasticsearch(
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
api_key=SEMANTIC_ELASTIC_QA.api_key,
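        # NOTE: certificate verification is disabled below; this assumes a trusted, managed cluster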
verify_certs=False,
request_timeout=60 * 3
) as es:
        for query_group in es.msearch(body=queries).get("responses", []):
            for hit in query_group.get("hits", {}).get("hits", []):
                results.append(ElasticHitsResult(
                    index=hit["_index"],
                    id=hit["_id"],
                    score=hit["_score"],
                    source=hit["_source"],
                    inner_hits=hit.get("inner_hits", {})
                ))
return results


def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None (search all indices)

    Returns
    -------
    List[ElasticHitsResult]
    """
queries = query_builder(query=search_text, indices=indices)
return multi_search(queries)


def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries, for the purpose
    of re-scoring by a secondary language model.

    Parameters
    ----------
    hits : Dict[str, Any]
        `inner_hits` payload from a single Elasticsearch hit

    Returns
    -------
    str
        Newline-joined chunk texts
    """
text = []
    for v in hits.values():
        for h in (v.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
return '\n'.join(text)
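
# Illustrative shape of the `inner_hits` payload this function walks (field names are examples):
# {
#     "embeddings.content.chunks": {
#         "hits": {"hits": [
#             {"fields": {"embeddings.content.chunks": [{"chunk": ["...retrieved text..."]}]}}
#         ]}
#     }
# }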


def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine scores between retrieved contexts and the original query, to re-score results by their
    overall relevance to the original query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, ordered by document retrieved from the original multi-search query

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """
nlp = CandidSLM()
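    # encode the query together with all contexts in one batch, then L2-normalize the
    # vectors so the dot product below computes exact cosine similarity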
X = nlp.encode([query, *contexts]).vectors
X = F.normalize(X, dim=-1, p=2.)
cosine = X[1:] @ X[:1].T
return cosine.flatten().cpu().numpy().tolist()


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries, which may have scores on different
    scales. Scores are min-max normalized per index, so results are interleaved across indices rather than
    kept in their original order.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits from a multi-search query, grouped by index
    search_text : Optional[str], optional
        Original search string, used to collect context texts for optional re-scoring, by default None

    Yields
    ------
    ElasticHitsResult
        Hits sorted by normalized score, descending
    """
results: List[ElasticHitsResult] = []
texts: List[str] = []
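    # NOTE: hits from `multi_search` arrive already grouped by index, which is what
    # `groupby` requires (it only groups consecutive items)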
for _, data in groupby(query_results, key=lambda x: x.index):
data = list(data)
max_score = max(data, key=lambda x: x.score).score
min_score = min(data, key=lambda x: x.score).score
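        # min-max normalize scores within each index group (the epsilon guards division by zero)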
for d in data:
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
results.append(d)
if search_text:
text = retrieved_text(d.inner_hits)
texts.append(text)
# if search_text and len(texts) == len(results):
# scores = cosine_rescore(search_text, texts)
# for r, s in zip(results, scores):
# r.score = s
yield from sorted(results, key=lambda x: x.score, reverse=True)
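
# Example (illustrative): raw scores [12.0, 7.0, 2.0] within one index normalize to
# [1.0, 0.5, 0.0], making them comparable with normalized scores from other indices.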


def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """
    output: List[Document] = []
    page_content: List[str] = []
    content = "Search didn't return any Candid sources"
results = get_query_results(search_text=user_input, indices=indices)
if results:
output = get_reranked_results(results, search_text=user_input)
for doc in output:
page_content.append(doc.page_content)
content = "\n\n".join(page_content)
# for the tool we need to return a tuple for content_and_artifact type
return content, output
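
# Example (illustrative; runs a live search against the cluster when executed):
# content, docs = get_results("youth mentoring grants", indices=["candid_blog"])
# `content` holds the concatenated page text; `docs` are the Document artifacts.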


def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024) -> str:
    """Pads each relevant chunk of text with context before and after it from the full document text.

    Parameters
    ----------
    field_name : str
        A field holding the long text that was chunked into pieces
    hit : ElasticHitsResult
        A single Elasticsearch hit
    context_length : int, optional
        Length of text to add before and after the chunk, by default 1024

    Returns
    -------
    str
        Expanded chunks joined together
    """
chunks_with_context = []
    long_text = hit.source.get(field_name, "")
inner_hits_field = f"embeddings.{field_name}.chunks"
found_chunks = hit.inner_hits.get(inner_hits_field, {})
if found_chunks:
hits = found_chunks.get("hits", {}).get("hits", [])
for h in hits:
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
# cutting the middle because we may have tokenizing artifacts there
chunk = chunk[3: -3]
# Find the start and end indices of the chunk in the large text
start_index = long_text.find(chunk)
if start_index != -1: # Chunk is found
end_index = start_index + len(chunk)
pre_start_index = max(0, start_index - context_length)
post_end_index = min(len(long_text), end_index + context_length)
chunks_with_context.append(long_text[pre_start_index:post_end_index])
return '\n\n'.join(chunks_with_context)
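
# Example (illustrative): with context_length=5 and long_text "abcdefghijklmnop",
# a matched chunk "ghij" expands to "bcdefghijklmno" (5 characters either side,
# clipped to the document bounds).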


def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parses an Elasticsearch hit into the data structure handled by the RAG pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult
        A single Elasticsearch hit

    Returns
    -------
    Union[Document, None]
        A Document, or None if the hit comes from an unrecognized index
    """
if "issuelab-elser" in hit.index:
combined_item_description = hit.source.get("combined_item_description", "") # title inside
description = hit.source.get("description", "")
combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
# we only need to process long texts
chunks_with_context_txt = get_context("content", hit, context_length=12)
doc = Document(
page_content='\n\n'.join([
combined_item_description,
combined_issuelab_findings,
description,
chunks_with_context_txt
]),
metadata={
"title": hit.source["title"],
"source": "IssueLab",
"source_id": hit.source["resource_id"],
"url": hit.source.get("permalink", "")
}
)
elif "youtube" in hit.index:
title = hit.source.get("title", "")
# we only need to process long texts
description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
doc = Document(
page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
metadata={
"title": title,
"source": "Candid YouTube",
"source_id": hit.source['video_id'],
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
}
)
elif "candid-blog" in hit.index:
excerpt = hit.source.get("excerpt", "")
title = hit.source.get("title", "")
# we only need to process long texts
content_with_context_txt = get_context("content", hit, context_length=12)
doc = Document(
page_content='\n\n'.join([title, excerpt, content_with_context_txt]),
metadata={
"title": title,
"source": "Candid Blog",
"source_id": hit.source["id"],
"url": hit.source["link"]
}
)
elif "candid-learning" in hit.index:
title = hit.source.get("title", "")
content_with_context_txt = get_context("content", hit, context_length=12)
training_topics = hit.source.get("training_topics", "")
staff_recommendations = hit.source.get("staff_recommendations", "")
doc = Document(
page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
metadata={
"title": hit.source["title"],
"source": "Candid Learning",
"source_id": hit.source["post_id"],
"url": hit.source.get("url", "")
}
)
elif "candid-help" in hit.index:
title = hit.source.get("title", "")
content_with_context_txt = get_context("content", hit, context_length=12)
combined_article_description = hit.source.get("combined_article_description", "")
doc = Document(
page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
metadata={
"title": title,
"source": "Candid Help",
"source_id": hit.source["id"],
"url": hit.source.get("link", "")
}
)
else:
doc = None
return doc


def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Runs data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
        Hits from a multi-search query
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
        Documents sorted by re-ranked relevance
    """
output = []
for r in reranker(results, search_text=search_text):
hit = process_hit(r)
if hit is not None:
output.append(hit)
return output


def retriever_tool(indices: List[str]) -> Tool:
    """Tool component for use in conditional-edge building for the RAG execution graph.
    `create_retriever_tool` cannot be used here because it returns only the page content, dropping all of the
    metadata along the way:
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tool
    """
return Tool(
name="retrieve_social_sector_information",
func=partial(get_results, indices=indices),
        description=(
            "Return additional information about the social and philanthropic sector, "
            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
        ),
args_schema=RetrieverInput,
response_format="content_and_artifact"
)
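

if __name__ == "__main__":
    # Minimal usage sketch (assumes valid cluster credentials in SEMANTIC_ELASTIC_QA);
    # the index names and query below are illustrative.
    tool = retriever_tool(indices=["candid_blog", "candid_help"])
    text, documents = get_results(user_input="What is a letter of inquiry?", indices=["candid_blog", "candid_help"])
    print(text[:500])
    for d in documents:
        print(d.metadata.get("title"), "->", d.metadata.get("url"))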