# ask_candid/retrieval/sparse_lexical.py
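"""SPLADE-based sparse lexical encoding for document re-ranking and query term expansion."""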
from typing import List, Dict

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer


class SpladeEncoder:
    """Sparse lexical encoder built on the SPLADE v3 masked language model."""

    batch_size = 4

    def __init__(self):
        model_id = "naver/splade-v3"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
        # Reverse vocabulary lookup (token id -> token string) for query expansion.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    @torch.no_grad()
    def forward(self, texts: List[str]) -> torch.Tensor:
        """Encode texts into SPLADE term-weight vectors, one row per text (n_texts x vocab_size)."""
        vectors = []
        loader = DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size)
        for batch in tqdm(loader, desc="Encoding"):
            tokens = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True)
            output = self.model(**tokens)
            # SPLADE pooling: log-saturate the ReLU'd MLM logits, zero out padding
            # positions via the attention mask, then max-pool over the sequence.
            vec = torch.max(
                torch.log1p(torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
                dim=1,
            ).values
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: List[str]) -> List[float]:
        """Score each document against the query via cosine similarity of SPLADE vectors."""
        vec = self.forward([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.0)
        xD = F.normalize(vec[1:], dim=-1, p=2.0)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> Dict[str, float]:
        """Expand a query into its non-zero SPLADE terms, sorted by descending weight."""
        vec = self.forward([query]).squeeze()
        # squeeze(-1) keeps the index tensor 1-D even when a single term is non-zero.
        cols = vec.nonzero().squeeze(-1).cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {
            self.idx2token[idx]: round(weight, 3)
            for idx, weight in zip(cols, weights)
            if weight > 0
        }
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
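

if __name__ == "__main__":
    # Minimal usage sketch; the query and document strings below are
    # illustrative placeholders, not data from the original project.
    encoder = SpladeEncoder()

    # Re-rank candidate documents against a query by sparse cosine similarity.
    scores = encoder.query_reranking(
        query="grants for community health programs",
        documents=[
            "Nonprofit funding opportunities for public health initiatives.",
            "Annual report on museum attendance figures.",
        ],
    )
    print(scores)  # one similarity score per document, in input order

    # Expand a query into weighted SPLADE vocabulary terms.
    print(encoder.token_expand("community health grants"))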