{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "b41fd227", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "!python settings.py" ] }, { "cell_type": "code", "execution_count": 2, "id": "b5fd917b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "π¦ PyTorch version: 2.5.1\n", "π CUDA available : True\n", "π§ GPU Name : NVIDIA RTX A4000\n", "π¦ FAISS version : 1.9.0\n", "π FAISS is using GPU β \n" ] } ], "source": [ "import torch\n", "\n", "print(\"π¦ PyTorch version:\", torch.__version__)\n", "print(\"π CUDA available :\", torch.cuda.is_available())\n", "if torch.cuda.is_available():\n", " print(\"π§ GPU Name :\", torch.cuda.get_device_name(0))\n", " \n", "import faiss\n", "\n", "print(\"π¦ FAISS version :\", faiss.__version__)\n", "\n", "# Kiα»m tra module FAISS-GPU cΓ³ hoαΊ‘t Δα»ng khΓ΄ng\n", "try:\n", " res = faiss.StandardGpuResources() # NαΊΏu khΓ΄ng lα»i lΓ cΓ³ GPU\n", " print(\"π FAISS is using GPU β \")\n", "except Exception as e:\n", " print(\"β FAISS is NOT using GPU:\", str(e))" ] }, { "cell_type": "code", "execution_count": 3, "id": "030016c2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Administrator\\AppData\\Local\\Temp\\2\\ipykernel_648\\3951191562.py:5: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", " from tqdm.autonotebook import tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "import os\n", "import json\n", "import pandas as pd\n", "from pprint import pprint\n", "from tqdm.autonotebook import tqdm\n", "\n", "from sentence_transformers import SentenceTransformer\n", "from mteb import MTEB\n", "from mteb.abstasks.TaskMetadata import TaskMetadata\n", "from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval\n", "\n", "from settings import MODEL_NAME, OUTPUT_DIR, DEVICE, BATCH_SIZE\n", "\n", "os.environ['WANDB_DISABLED'] = 'true'" ] }, { "cell_type": "code", "execution_count": 4, "id": "dd3f53a3", "metadata": {}, "outputs": [], "source": [ "data = {\n", " 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n", " 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n", " 'test' : pd.read_parquet('data/processed/test_data.parquet')\n", "}\n", "for split in ['train', 'test']:\n", " data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n", " data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())" ] }, { "cell_type": "code", "execution_count": 5, "id": "41ffd5ce", "metadata": {}, "outputs": [], "source": [ "class BKAILegalDocRetrievalTask(AbsTaskRetrieval):\n", " # Metadata definition used by MTEB benchmark\n", " metadata = TaskMetadata(name='BKAILegalDocRetrieval',\n", " description='',\n", " reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',\n", " type='Retrieval',\n", " category='s2p',\n", " modalities=['text'],\n", " eval_splits=['test'],\n", " eval_langs=['vi'],\n", " main_score='ndcg_at_10',\n", " other_scores=['recall_at_10', 'precision_at_10', 'map'],\n", " dataset={\n", " 'path' : 'data',\n", " 'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',\n", " },\n", " date=('2012-01-01', '2020-01-01'),\n", " form='Written',\n", " domains=['Academic', 'Non-fiction'],\n", " task_subtypes=['Scientific Reranking'],\n", " 
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41ffd5ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BKAILegalDocRetrievalTask(AbsTaskRetrieval):\n",
    "    # Metadata definition used by the MTEB benchmark\n",
    "    metadata = TaskMetadata(name='BKAILegalDocRetrieval',\n",
    "                            description='Retrieval of relevant Vietnamese legal documents for a given question (BKAI legal document retrieval dataset).',\n",
    "                            reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',\n",
    "                            type='Retrieval',\n",
    "                            category='s2p',\n",
    "                            modalities=['text'],\n",
    "                            eval_splits=['test'],\n",
    "                            eval_langs=['vi'],\n",
    "                            main_score='ndcg_at_10',\n",
    "                            other_scores=['recall_at_10', 'precision_at_10', 'map'],\n",
    "                            dataset={\n",
    "                                'path'    : 'data',\n",
    "                                'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',\n",
    "                            },\n",
    "                            date=('2012-01-01', '2020-01-01'),\n",
    "                            form='Written',\n",
    "                            domains=['Legal'],\n",
    "                            task_subtypes=['Article retrieval'],\n",
    "                            license='cc-by-nc-4.0',\n",
    "                            annotations_creators='derived',\n",
    "                            dialect=[],\n",
    "                            text_creation='found',\n",
    "                            bibtex_citation=''\n",
    "                            )\n",
    "\n",
    "    data_loaded = True  # tell MTEB the data is already loaded, so load_data() is skipped\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        self.corpus = {}\n",
    "        self.queries = {}\n",
    "        self.relevant_docs = {}\n",
    "\n",
    "        # One corpus shared by both splits; document ids are prefixed with 'c'\n",
    "        shared_corpus = {}\n",
    "        for _, row in data['corpus'].iterrows():\n",
    "            shared_corpus[f\"c{row['cid']}\"] = {\n",
    "                'text': row['text'],\n",
    "                '_id' : row['cid']\n",
    "            }\n",
    "\n",
    "        for split in ['train', 'test']:\n",
    "            self.corpus[split] = shared_corpus\n",
    "            self.queries[split] = {}\n",
    "            self.relevant_docs[split] = {}\n",
    "\n",
    "        # Queries ('q' prefix) and binary relevance judgments per split\n",
    "        for split in ['train', 'test']:\n",
    "            for _, row in data[split].iterrows():\n",
    "                qid, cids = row['qid'], row['cid']\n",
    "\n",
    "                qid_str = f'q{qid}'\n",
    "                cids_str = [f'c{cid}' for cid in cids]\n",
    "\n",
    "                self.queries[split][qid_str] = row['question']\n",
    "\n",
    "                if qid_str not in self.relevant_docs[split]:\n",
    "                    self.relevant_docs[split][qid_str] = {}\n",
    "\n",
    "                for cid_str in cids_str:\n",
    "                    self.relevant_docs[split][qid_str][cid_str] = 1\n",
    "\n",
    "        self.data_loaded = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8c212fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)"
   ]
  },
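  {
   "cell_type": "markdown",
   "id": "a9b8c7d6",
   "metadata": {},
   "source": [
    "Before the full MTEB run, a small retrieval smoke test can confirm that the fine-tuned model and FAISS work together. This is a sketch, not part of the benchmark: the 1000-document sample and `k=5` are arbitrary choices, and inner product over L2-normalized embeddings is used as cosine similarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b8c7d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: embed a corpus sample and one training question, then run\n",
    "# exact nearest-neighbour search with a flat inner-product index.\n",
    "sample_texts = data['corpus']['text'].head(1000).tolist()\n",
    "doc_emb = fine_tuned_model.encode(sample_texts, batch_size=BATCH_SIZE,\n",
    "                                  normalize_embeddings=True)\n",
    "\n",
    "index = faiss.IndexFlatIP(doc_emb.shape[1])  # inner product == cosine on unit vectors\n",
    "index.add(doc_emb)\n",
    "\n",
    "query = data['train'].iloc[0]['question']\n",
    "q_emb = fine_tuned_model.encode([query], normalize_embeddings=True)\n",
    "sims, ids = index.search(q_emb, 5)\n",
    "print('top-5 sample rows:', ids[0])\n",
    "print('top-5 similarity :', sims[0])"
   ]
  },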
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aae09322",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "───────────────────────────────────────── Selected tasks ─────────────────────────────────────────\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Retrieval\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "    - BKAILegalDocRetrieval, s2p\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
"\n", "\n", "\n" ], "text/plain": [ "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "53778754caf4456f8e140cfa58b60709", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/233 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f9b27ae885fc46ad83f332f222a76381", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/391 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b6e38b0d54a4b429db05158604d24a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/391 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20ec5df7261c43a7921abc968cc5e3a6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/391 [00:02, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5f365f06d3de4becb965adb801aeee60", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/391 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a43b764ac83e43aeb754c1e60771fd5c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/391 [00:02, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae46c8f76bc64eac8ca475d13f312875", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/91 [00:02, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[TaskResult(task_name=BKAILegalDocRetrieval, scores=...)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "custom_task = BKAILegalDocRetrievalTask()\n", "evaluation = MTEB(tasks=[custom_task])\n", "evaluation.run(fine_tuned_model, batch_size=BATCH_SIZE)" ] }, { "cell_type": "code", "execution_count": 8, "id": "004e6930", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Main Evaluation Metrics (Top-K = 10):\n", "{'evaluation_time (s)': 3061.7869832515717,\n", " 'main_score': 0.60389,\n", " 'mrr@10': 0.555102,\n", " 'precision@10': 0.08587,\n", " 'recall@10': 0.79407}\n" ] } ], "source": [ "file_path = f\"results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json\"\n", "\n", "with open(file_path, 'r', encoding='utf-8') as f:\n", " eval_data = json.load(f)\n", "\n", "scores = eval_data[\"scores\"][\"test\"][0]\n", "main_metrics = {\n", " 'main_score' : scores.get('ndcg_at_10'),\n", " 'recall@10' : scores.get('recall_at_10'),\n", " 'precision@10' : scores.get('precision_at_10'),\n", " 'mrr@10' : scores.get('mrr_at_10'),\n", " 'evaluation_time (s)': eval_data.get('evaluation_time')\n", "}\n", "\n", "print('Main Evaluation Metrics (Top-K = 10):')\n", "pprint(main_metrics)" ] }, { "cell_type": "code", "execution_count": 9, "id": "672ebc32", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Evaluation Scores by K:\n", "metric map mrr ndcg precision recall\n", "k \n", "1 0.4033 0.4242 0.4242 0.4242 0.4033\n", "3 0.5031 0.5247 0.5394 0.2215 0.6232\n", "5 0.5230 0.5434 0.5739 
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "004e6930",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Main Evaluation Metrics (Top-K = 10):\n",
      "{'evaluation_time (s)': 3061.7869832515717,\n",
      " 'main_score': 0.60389,\n",
      " 'mrr@10': 0.555102,\n",
      " 'precision@10': 0.08587,\n",
      " 'recall@10': 0.79407}\n"
     ]
    }
   ],
   "source": [
    "file_path = \"results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json\"\n",
    "\n",
    "with open(file_path, 'r', encoding='utf-8') as f:\n",
    "    eval_data = json.load(f)\n",
    "\n",
    "scores = eval_data[\"scores\"][\"test\"][0]\n",
    "main_metrics = {\n",
    "    'main_score'         : scores.get('ndcg_at_10'),\n",
    "    'recall@10'          : scores.get('recall_at_10'),\n",
    "    'precision@10'       : scores.get('precision_at_10'),\n",
    "    'mrr@10'             : scores.get('mrr_at_10'),\n",
    "    'evaluation_time (s)': eval_data.get('evaluation_time')\n",
    "}\n",
    "\n",
    "print('Main Evaluation Metrics (Top-K = 10):')\n",
    "pprint(main_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "672ebc32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluation Scores by K:\n",
      "metric     map     mrr    ndcg  precision  recall\n",
      "k                                                 \n",
      "1       0.4033  0.4242  0.4242     0.4242  0.4033\n",
      "3       0.5031  0.5247  0.5394     0.2215  0.6232\n",
      "5       0.5230  0.5434  0.5739     0.1512  0.7047\n",
      "10      0.5361  0.5551  0.6039     0.0859  0.7941\n",
      "20      0.5414  0.5596  0.6216     0.0469  0.8611\n",
      "100     0.5442  0.5617  0.6389     0.0104  0.9480\n",
      "1000    0.5444  0.5619  0.6444     0.0011  0.9879\n"
     ]
    }
   ],
   "source": [
    "# Keep only the per-k metrics (e.g. 'ndcg_at_10'), dropping the nAUC diagnostics\n",
    "metrics = {k: v for k, v in scores.items() if '_at_' in k and not k.startswith('nauc')}\n",
    "\n",
    "parsed_metrics = []\n",
    "for key, value in metrics.items():\n",
    "    metric, at_k = key.split('_at_')\n",
    "    parsed_metrics.append({'metric': metric, 'k': int(at_k), 'score': value})\n",
    "\n",
    "df_metrics = pd.DataFrame(parsed_metrics).pivot(index='k', columns='metric', values='score')\n",
    "df_metrics = df_metrics.sort_index()\n",
    "\n",
    "print(\"\\nEvaluation Scores by K:\")\n",
    "print(df_metrics.round(4))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "legal_doc_retrieval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}