{ "cells": [
  { "cell_type": "code", "execution_count": 11, "id": "1195e917", "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ],
   "source": [
    "# `%run` executes settings.py with the kernel's own interpreter; `!python` may resolve\n",
    "# to a different environment on PATH and report a device this notebook won't use.\n",
    "%run settings.py"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "01589fc8", "metadata": {}, "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "import faiss\n",
    "from sentence_transformers import SentenceTransformer, CrossEncoder\n",
    "\n",
    "from transformers import logging\n",
    "logging.set_verbosity_error()\n",
    "\n",
    "from settings import OUTPUT_DIR, DEVICE\n",
    "os.environ['WANDB_DISABLED'] = 'true'"
   ] },
  { "cell_type": "code", "execution_count": 14, "id": "5634b72a", "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "SentenceTransformer(\n", " (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n", " (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})\n", ")" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)\n",
    "# fp16 is meant for GPU inference: on CPU it is slow and some torch ops have no\n",
    "# Half kernel, so only cast when the model actually runs on CUDA.\n",
    "if str(DEVICE).startswith('cuda'):\n",
    "    fine_tuned_model.half()\n",
    "fine_tuned_model"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "62cc0ead", "metadata": {}, "outputs": [],
   "source": [
    "passages = pd.read_parquet('data/processed/corpus_data.parquet')['text'].tolist()\n",
    "# Unit-normalised vectors make inner-product search equal to cosine similarity.\n",
    "# FAISS expects float32, so cast after encoding (the model may be running in fp16).\n",
    "corpus_embeddings = fine_tuned_model.encode(\n",
    "    passages,\n",
    "    batch_size=128,\n",
    "    convert_to_numpy=True,\n",
    "    normalize_embeddings=True,\n",
    "    show_progress_bar=True,\n",
    "    device=DEVICE,\n",
    ").astype(np.float32)\n",
    "assert corpus_embeddings.shape[0] == len(passages), 'one embedding row per passage expected'"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "465e8d2a", "metadata": {}, "outputs": [],
"source": [ "d = corpus_embeddings.shape[1] # 768\n", "cpu_index = faiss.IndexFlatIP(d)\n", "\n", "res = faiss.StandardGpuResources()\n", "gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)\n", "gpu_index.add(corpus_embeddings)" ] }, { "cell_type": "code", "execution_count": null, "id": "af365371", "metadata": {}, "outputs": [], "source": [ "final_cpu_index = faiss.index_gpu_to_cpu(gpu_index)\n", "faiss.write_index(final_cpu_index, 'data/retrieval/legal_faiss.index')" ] }, { "cell_type": "code", "execution_count": null, "id": "9251d0db", "metadata": {}, "outputs": [], "source": [ "legal_index = faiss.read_index('data/retrieval/legal_faiss.index')" ] }, { "cell_type": "code", "execution_count": 19, "id": "9f54c596", "metadata": {}, "outputs": [], "source": [ "def retrieval(emb_model, query, index, top_k=10):\n", " q_emb = emb_model.encode(\n", " query, \n", " convert_to_numpy=True, \n", " normalize_embeddings=True,\n", " ).astype(np.float32).reshape(1, -1)\n", " \n", " scores, indices = index.search(q_emb, top_k) # shape: (1, top_k)\n", " \n", " cand_idxs = indices[0]\n", " cand_scores = scores[0]\n", " cand_texts = [passages[i] for i in cand_idxs]\n", "\n", " results = [{\n", " 'index': int(cand_idxs[i]),\n", " 'score': float(cand_scores[i]),\n", " 'text': cand_texts[i]\n", " } for i in range(len(cand_idxs))]\n", " \n", " return results" ] }, { "cell_type": "code", "execution_count": null, "id": "ece21ef6", "metadata": {}, "outputs": [], "source": [ "query = 'Tội xúc phạm danh dự'\n", "hits = retrieval(fine_tuned_model, query, legal_index, top_k=10)\n", "\n", "for h in hits:\n", " print(f\"[Rank {hits.index(h)+1}] - index={h['index']}, score={h['score']:.4f}]\")\n", " print(f\"{h['text']}\")\n", " print('-' * 100)" ] } ], "metadata": { "kernelspec": { "display_name": "legal_doc_retrieval", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", 
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }