In [11]:
# Shell escape: runs settings.py in a separate process, so only its printed
# output is visible here; the kernel gets the actual values via
# `from settings import ...` in the imports cell.
!python settings.py

Using device: cuda


In [None]:
import os
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm

import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder

# NOTE: this binds the name `logging` to transformers' logging helper, not the
# standard-library `logging` module.
from transformers import logging
# Only surface error-level messages from transformers.
logging.set_verbosity_error()

# Project settings: OUTPUT_DIR is passed to SentenceTransformer as the model
# location, DEVICE as the device to run on.
from settings import OUTPUT_DIR, DEVICE
# Presumably turns off Weights & Biases reporting in downstream libraries --
# TODO confirm this variable is still honoured by the installed versions.
os.environ['WANDB_DISABLED'] = 'true'

In [14]:
# Load the fine-tuned bi-encoder from OUTPUT_DIR onto the configured device.
fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)

# fp16 weights cut memory and speed up GPU inference, but half-precision
# kernels are slow or unavailable on CPU, so only cast when running on CUDA.
# str() makes the check work whether DEVICE is a string or a torch.device.
if str(DEVICE).startswith('cuda'):
    fine_tuned_model.half()

# Last expression of the cell: display the module summary.
fine_tuned_model

SentenceTransformer(
 (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
 (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [None]:
# Only the 'text' column is used, so request just that column from the parquet
# file instead of materialising the whole table.
passages = pd.read_parquet(
    'data/processed/corpus_data.parquet',
    columns=['text'],
)['text'].tolist()

# Encode every passage with the fine-tuned model. normalize_embeddings=True
# yields unit-length vectors, so an inner-product search over them ranks by
# cosine similarity. The result is cast to float32 before it is handed to FAISS
# (the model may be running in half precision).
corpus_embeddings = fine_tuned_model.encode(
    passages,
    batch_size=128,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
    device=DEVICE,
).astype(np.float32)

# Row i of the matrix must describe passages[i]; retrieval maps FAISS ids back
# to texts by position, so a mismatch here would silently corrupt results.
assert corpus_embeddings.shape[0] == len(passages), 'embeddings/passages row count mismatch'

In [None]:
# Embedding dimensionality, taken from the matrix itself (768 for the model
# shown above, per its pooling layer's word_embedding_dimension).
d = corpus_embeddings.shape[1]
# Flat (exhaustive) inner-product index; the corpus embeddings were produced
# with normalize_embeddings=True, so inner product ranks by cosine similarity.
cpu_index = faiss.IndexFlatIP(d)

# GPU resources object for FAISS; kept in a notebook-level name alongside the
# GPU index that uses it.
res = faiss.StandardGpuResources()
# Copy the (still empty) index to GPU device 0, then add all corpus vectors.
# Vector ids are assigned in insertion order, i.e. id i <-> corpus_embeddings[i].
gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
gpu_index.add(corpus_embeddings)

In [None]:
# Bring the populated index back to host memory so it can be serialised.
final_cpu_index = faiss.index_gpu_to_cpu(gpu_index)

# write_index does not create missing parent directories; make sure the target
# folder exists so a fresh checkout does not fail after the expensive encode.
os.makedirs('data/retrieval', exist_ok=True)
faiss.write_index(final_cpu_index, 'data/retrieval/legal_faiss.index')

In [None]:
# Load the serialised index back from the same path it was written to; this is
# the index object used for all searches below.
legal_index = faiss.read_index('data/retrieval/legal_faiss.index')

In [19]:
def retrieval(emb_model, query, index, top_k=10, corpus=None):
    """Return the top-k passages most similar to a single query string.

    Args:
        emb_model: Object exposing ``encode(text, convert_to_numpy=...,
            normalize_embeddings=...)`` and returning one embedding vector.
        query: A single query string. The embedding is reshaped to one row,
            so a batch of queries is not supported here.
        index: Object exposing ``search(x, k)`` that returns
            ``(scores, indices)`` arrays of shape ``(1, k)``.
        top_k: Number of neighbours to request from the index.
        corpus: Sequence mapping index ids to passage texts. Defaults to the
            notebook-level ``passages`` list when omitted.

    Returns:
        A list of dicts with keys ``'index'`` (int), ``'score'`` (float) and
        ``'text'``, in the order returned by ``index.search``. It may contain
        fewer than ``top_k`` items when the index holds fewer vectors.
    """
    if corpus is None:
        # Fall back to the corpus list defined at notebook level.
        corpus = passages

    # Cast to float32 (the model may emit half precision) and shape as a
    # single-row batch for the index search.
    q_emb = emb_model.encode(
        query,
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).astype(np.float32).reshape(1, -1)

    scores, indices = index.search(q_emb, top_k)  # each of shape (1, top_k)

    # FAISS pads missing neighbours with id -1; without the filter that id
    # would index the *last* passage and report it as a hit.
    return [
        {
            'index': int(idx),
            'score': float(score),
            'text': corpus[int(idx)],
        }
        for idx, score in zip(indices[0], scores[0])
        if idx >= 0
    ]

In [None]:
# Smoke-test the retriever with one query and print the ranked hits.
query = 'Tội xúc phạm danh dự'
hits = retrieval(fine_tuned_model, query, legal_index, top_k=10)

# enumerate yields the rank directly; hits.index(h) rescans the list on every
# iteration and would report the wrong rank for two equal result dicts.
for rank, h in enumerate(hits, start=1):
    print(f"[Rank {rank}] - index={h['index']}, score={h['score']:.4f}")
    print(f"{h['text']}")
    print('-' * 100)