from typing import List, Dict
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn import functional as F
import torch


class SpladeEncoder:
    batch_size = 4

    def __init__(self):
        model_id = "naver/splade-v3"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
        # Reverse vocabulary lookup: token id -> token string.
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

    def forward(self, texts: List[str]):
        vectors = []
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Re-ranking"):
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            output = self.model(**tokens)
            # SPLADE pooling: log-saturated ReLU of the MLM logits, masked by the
            # attention mask and max-pooled over the sequence dimension.
            vec = torch.max(
                torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
                dim=1
            )[0].squeeze()
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: List[str]):
        # Encode the query and documents together, then score with cosine similarity.
        vec = self.forward([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.)
        xD = F.normalize(vec[1:], dim=-1, p=2.)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> Dict[str, float]:
        # Map the non-zero dimensions of the query vector back to vocabulary tokens,
        # sorted by weight in descending order.
        vec = self.forward([query]).squeeze()
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()
        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
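For reference, a minimal usage sketch follows: instantiate the encoder once, call query_reranking to score candidate documents against a query, and call token_expand to inspect the sparse expansion of the query. The query and document strings are made up for illustration and are not part of the original source.

# Minimal usage sketch; illustrative strings only.
encoder = SpladeEncoder()

query = "what causes rainbows"
documents = [
    "Rainbows appear when sunlight is refracted and reflected inside water droplets.",
    "The stock market closed higher today after a volatile session.",
]

# Cosine similarity between the query and each document (higher = more relevant).
scores = encoder.query_reranking(query, documents)
for doc, score in sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True):
    print(f"{score:.3f}  {doc}")

# Vocabulary tokens activated for the query, with their SPLADE weights.
print(encoder.token_expand(query))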