from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
from dataclasses import dataclass
from functools import partial
from itertools import groupby

from torch.nn import functional as F
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from langchain_core.tools import Tool
from elasticsearch import Elasticsearch

from ask_candid.services.small_lm import CandidSLM
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES


@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hit results."""

    index: str
    id: Any
    score: float
    source: Dict[str, Any]
    inner_hits: Dict[str, Any]


class RetrieverInput(BaseModel):
    """Input to the Elasticsearch retriever."""

    user_input: str = Field(description="query to look up in retriever")


def build_text_expansion_query(
    query: str,
    fields: Tuple[str, ...],
    model_id: str = ".elser_model_2_linux-x86_64"
) -> Dict[str, Any]:
    """Builds a valid Elasticsearch text expansion query payload.

    Parameters
    ----------
    query : str
        Search context string
    fields : Tuple[str, ...]
        Semantic text field names
    model_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser_model_2_linux-x86_64"

    Returns
    -------
    Dict[str, Any]
    """

    output = []
    for f in fields:
        output.append({
            "nested": {
                "path": f"embeddings.{f}.chunks",
                "query": {
                    "text_expansion": {
                        f"embeddings.{f}.chunks.vector": {
                            "model_id": model_id,
                            "model_text": query,
                            # distribute the total boost evenly across the queried fields
                            "boost": 1 / len(fields)
                        }
                    }
                },
                "inner_hits": {
                    "_source": False,
                    "size": 2,
                    "fields": [f"embeddings.{f}.chunks.chunk"]
                }
            }
        })
    return {"query": {"bool": {"should": output}}}


def query_builder(query: str, indices: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Builds an Elasticsearch multi-search query payload.

    Parameters
    ----------
    query : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None (searches all indices)

    Returns
    -------
    List[Dict[str, Any]]
    """

    queries = []
    if indices is None:
        indices = list(ALL_INDICES)

    for index in indices:
        if index == "issuelab":
            q = build_text_expansion_query(
                query=query,
                fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 1
            queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
        elif index == "youtube":
            q = build_text_expansion_query(
                query=query,
                fields=("captions_cleaned", "description_cleaned", "title")
            )
            # text_cleaned duplicates captions_cleaned
            q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
        elif index == "candid_blog":
            q = build_text_expansion_query(
                query=query,
                fields=("content", "title")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
        elif index == "candid_learning":
            q = build_text_expansion_query(
                query=query,
                fields=("content", "title", "training_topics", "staff_recommendations")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
        elif index == "candid_help":
            q = build_text_expansion_query(
                query=query,
                fields=("content", "combined_article_description")
            )
            q["_source"] = {"excludes": ["embeddings"]}
            q["size"] = 2
            queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])
    return queries
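
# The returned payload interleaves msearch header and body lines, one pair per
# index. A sketch for two indices (actual index names come from
# ElasticIndexMapping; placeholders shown here):
#
# >>> query_builder("capacity building", indices=["candid_blog", "candid_help"])
# [
#     {"index": "<CANDID_BLOG_INDEX_ELSER>"},
#     {"query": {...}, "_source": {"excludes": ["embeddings"]}, "size": 2},
#     {"index": "<CANDID_HELP_INDEX_ELSER>"},
#     {"query": {...}, "_source": {"excludes": ["embeddings"]}, "size": 2},
# ]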


def multi_search(queries: List[Dict[str, Any]]) -> List[ElasticHitsResult]:
    """Runs a multi-search query.

    Parameters
    ----------
    queries : List[Dict[str, Any]]
        Pre-built multi-search query payload

    Returns
    -------
    List[ElasticHitsResult]
    """

    results = []
    with Elasticsearch(
        cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
        api_key=SEMANTIC_ELASTIC_QA.api_key,
        verify_certs=False,
        request_timeout=60 * 3
    ) as es:
        for query_group in es.msearch(body=queries).get("responses", []):
            for hit in query_group.get("hits", {}).get("hits", []):
                results.append(ElasticHitsResult(
                    index=hit["_index"],
                    id=hit["_id"],
                    score=hit["_score"],
                    source=hit["_source"],
                    inner_hits=hit.get("inner_hits", {})
                ))
    return results
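
# The msearch response groups hits per sub-query; each hit carries the fields
# consumed above. A trimmed sketch of the shape (values are illustrative):
#
# {
#     "responses": [{
#         "hits": {"hits": [{
#             "_index": "issuelab-elser-...",
#             "_id": "12345",
#             "_score": 8.7,
#             "_source": {...},
#             "inner_hits": {"embeddings.content.chunks": {...}}
#         }]}
#     }]
# }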


def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
    """Builds and executes Elasticsearch data queries from a search string.

    Parameters
    ----------
    search_text : str
        Search context string
    indices : Optional[List[str]], optional
        Semantic index names to search over, by default None

    Returns
    -------
    List[ElasticHitsResult]
    """

    queries = query_builder(query=search_text, indices=indices)
    return multi_search(queries)


def retrieved_text(hits: Dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries, for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : Dict[str, Any]

    Returns
    -------
    str
    """

    text = []
    for v in hits.values():
        for h in (v.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
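
# A sketch of the inner_hits structure this function walks (the key and chunk
# text are illustrative). Each matched nested chunk surfaces as a list under
# "fields", so `chunk["chunk"]` is a list of strings and `extend` is correct:
#
# {
#     "embeddings.content.chunks": {
#         "hits": {"hits": [
#             {"fields": {"embeddings.content.chunks": [{"chunk": ["...chunk text..."]}]}}
#         ]}
#     }
# }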


def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
    """Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
    relevance to the original query.

    Parameters
    ----------
    query : str
        Search context string
    contexts : List[str]
        Semantic field sub-texts, ordered by document retrieved from the original multi-search query.

    Returns
    -------
    List[float]
        Scores in the same order as the input document contexts
    """

    nlp = CandidSLM()
    X = nlp.encode([query, *contexts]).vectors
    X = F.normalize(X, dim=-1, p=2.)
    cosine = X[1:] @ X[:1].T
    return cosine.flatten().cpu().numpy().tolist()
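
# A minimal, self-contained sketch of the rescoring math using plain torch
# (CandidSLM is an internal service; random vectors stand in for embeddings):
#
# import torch
# from torch.nn import functional as F
# X = torch.randn(4, 8)               # row 0 = query, rows 1..3 = contexts
# X = F.normalize(X, dim=-1, p=2.)    # unit-normalize each row
# cosine = X[1:] @ X[:1].T            # dot products of unit vectors = cosine similarity
# scores = cosine.flatten().tolist()  # one score per context, in input order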


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: Optional[str] = None
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    This will shuffle results.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
    search_text : Optional[str], optional
        Search context string, by default None

    Yields
    ------
    Iterator[ElasticHitsResult]
    """

    results: List[ElasticHitsResult] = []
    texts: List[str] = []
    # msearch returns hits grouped per sub-query, so hits from the same index are contiguous
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score
        for d in data:
            # min-max normalize scores within each index group onto [0, 1]
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = retrieved_text(d.inner_hits)
                texts.append(text)

    # if search_text and len(texts) == len(results):
    #     scores = cosine_rescore(search_text, texts)
    #     for r, s in zip(results, scores):
    #         r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)
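
# A worked sketch of the per-index min-max normalization (numbers illustrative):
#   raw scores from one index: [2.0, 5.0, 8.0]
#   min = 2.0, max = 8.0
#   normalized: [(2-2)/6, (5-2)/6, (8-2)/6] ≈ [0.0, 0.5, 1.0]
# Scores from another index, say [0.1, 0.9], map onto the same [0, 1] scale, so
# hits from differently-scaled indices can be sorted together. The 1e-9 term
# guards against division by zero when all scores in a group are equal.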


def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
    """End-to-end search and re-rank function.

    Parameters
    ----------
    user_input : str
        Search context string
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tuple[str, List[Document]]
        (concatenated text from search results, documents list)
    """

    output: List[Document] = []
    page_content = []
    content = "Search didn't return any Candid sources"
    results = get_query_results(search_text=user_input, indices=indices)
    if results:
        output = get_reranked_results(results, search_text=user_input)
        for doc in output:
            page_content.append(doc.page_content)
        content = "\n\n".join(page_content)

    # for the tool we need to return a tuple for the content_and_artifact type
    return content, output
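
# Usage sketch (query and index names are illustrative):
#
# >>> content, docs = get_results("youth mentorship grants", indices=["candid_blog"])
# >>> content                              # concatenated page_content, ready for an LLM prompt
# >>> [d.metadata["url"] for d in docs]    # metadata survives for citation rendering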


def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024) -> str:
    """Pads the relevant chunk of text with context before and after it.

    Parameters
    ----------
    field_name : str
        A field with the long text that was chunked into pieces
    hit : ElasticHitsResult
    context_length : int, optional
        Length of text to add before and after the chunk, by default 1024

    Returns
    -------
    str
        Longer chunks stuffed together
    """

    chunks_with_context = []
    long_text = hit.source.get(field_name, "")
    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {})
    if found_chunks:
        hits = found_chunks.get("hits", {}).get("hits", [])
        for h in hits:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # trimming the chunk edges because we may have tokenizing artifacts there
            chunk = chunk[3:-3]

            # find the start and end indices of the chunk in the long text
            start_index = long_text.find(chunk)
            if start_index != -1:  # chunk is found
                end_index = start_index + len(chunk)
                pre_start_index = max(0, start_index - context_length)
                post_end_index = min(len(long_text), end_index + context_length)
                chunks_with_context.append(long_text[pre_start_index:post_end_index])
    return '\n\n'.join(chunks_with_context)
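
# A worked sketch of the padding (toy values): given
#   long_text = "abcdefghijklmnopqrstuvwxyz", chunk found at indices 10..15,
#   context_length = 4
# the returned window spans max(0, 10 - 4) .. min(26, 15 + 4), i.e.
# long_text[6:19], so the matched chunk comes back with four characters of
# surrounding context on each side (clamped at the text boundaries).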


def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
    """Parses Elasticsearch hit results into data structures handled by the RAG pipeline.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Union[Document, None]
    """

    if "issuelab-elser" in hit.index:
        combined_item_description = hit.source.get("combined_item_description", "")  # title inside
        description = hit.source.get("description", "")
        combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
        # we only need to process long texts
        chunks_with_context_txt = get_context("content", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([
                combined_item_description,
                combined_issuelab_findings,
                description,
                chunks_with_context_txt
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    elif "youtube" in hit.index:
        title = hit.source.get("title", "")
        # we only need to process long texts
        description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
        captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid YouTube",
                "source_id": hit.source["video_id"],
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    elif "candid-blog" in hit.index:
        excerpt = hit.source.get("excerpt", "")
        title = hit.source.get("title", "")
        # we only need to process long texts
        content_with_context_txt = get_context("content", hit, context_length=12)
        doc = Document(
            page_content='\n\n'.join([title, excerpt, content_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source["link"]
            }
        )
    elif "candid-learning" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        training_topics = hit.source.get("training_topics", "")
        staff_recommendations = hit.source.get("staff_recommendations", "")
        doc = Document(
            page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid Learning",
                "source_id": hit.source["post_id"],
                "url": hit.source.get("url", "")
            }
        )
    elif "candid-help" in hit.index:
        title = hit.source.get("title", "")
        content_with_context_txt = get_context("content", hit, context_length=12)
        combined_article_description = hit.source.get("combined_article_description", "")
        doc = Document(
            page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
            metadata={
                "title": title,
                "source": "Candid Help",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        doc = None
    return doc
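
# A sketch of the Document produced for a blog hit (values illustrative):
#
# Document(
#     page_content="<title>\n\n<excerpt>\n\n<content chunks with context>",
#     metadata={"title": "...", "source": "Candid Blog", "source_id": 42,
#               "url": "https://blog.candid.org/..."}
# )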


def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
    """Runs data re-ranking and document building for tool usage.

    Parameters
    ----------
    results : List[ElasticHitsResult]
    search_text : Optional[str], optional
        Search context string, by default None

    Returns
    -------
    List[Document]
    """

    output = []
    for r in reranker(results, search_text=search_text):
        doc = process_hit(r)
        if doc is not None:
            output.append(doc)
    return output


def retriever_tool(indices: List[str]) -> Tool:
    """Tool component for use in conditional edge building for the RAG execution graph.
    Cannot use `create_retriever_tool` because it only returns the page content, losing all metadata along the way.
    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution

    Parameters
    ----------
    indices : List[str]
        Semantic index names to search over

    Returns
    -------
    Tool
    """

    return Tool(
        name="retrieve_social_sector_information",
        func=partial(get_results, indices=indices),
        description=(
            "Return additional information about the social and philanthropic sector, "
            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
        ),
        args_schema=RetrieverInput,
        response_format="content_and_artifact"
    )
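
# Usage sketch (index list is illustrative; wiring the tool into a LangGraph
# node is left out):
#
# >>> tool = retriever_tool(indices=["issuelab", "candid_blog"])
# >>> content, documents = tool.func("how do I find grant opportunities?")
#
# `content` feeds the LLM, while `documents` is surfaced as the tool artifact
# because response_format="content_and_artifact".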