# %%
# Jupyter shell escape, kept as a commented magic (jupytext "percent" convention).
# In a live kernel it runs settings.py as a script; the stored cell output shows it
# printing the selected device.
# !python settings.py

# %%
import os
import json
import pandas as pd
from pprint import pprint
from tqdm.autonotebook import tqdm

from sentence_transformers import SentenceTransformer

from mteb import MTEB
from mteb.abstasks.TaskMetadata import TaskMetadata
from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval

from settings import MODEL_NAME, OUTPUT_DIR, DEVICE, BATCH_SIZE
os.environ['WANDB_DISABLED'] = 'true'

# %%
# Load the processed corpus and the train/test query frames.
data = {
    'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),
    'train' : pd.read_parquet('data/processed/train_data.parquet'),
    'test'  : pd.read_parquet('data/processed/test_data.parquet')
}
# The list-valued columns are converted element-wise with `.tolist()`
# (presumably they come back from parquet as numpy arrays -- TODO confirm).
for split in ['train', 'test']:
    data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())
    data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())

# %%
class BKAILegalDocRetrievalTask(AbsTaskRetrieval):
    """Custom MTEB retrieval task built from the notebook-level ``data`` frames.

    The constructor fills three per-split mappings:

    * ``corpus[split]``        -- ``{'c<cid>': {'text': ..., '_id': 'c<cid>'}}``
    * ``queries[split]``       -- ``{'q<qid>': question}``
    * ``relevant_docs[split]`` -- ``{'q<qid>': {'c<cid>': 1, ...}}``

    Everything is read from the module-level ``data`` dict (keys ``'corpus'``,
    ``'train'``, ``'test'``), so that dict must exist before instantiation.
    """

    # Metadata definition used by MTEB benchmark
    metadata = TaskMetadata(
        name='BKAILegalDocRetrieval',
        description='',
        reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',
        type='Retrieval',
        category='s2p',
        modalities=['text'],
        eval_splits=['test'],
        eval_langs=['vi'],
        main_score='ndcg_at_10',
        other_scores=['recall_at_10', 'precision_at_10', 'map'],
        dataset={
            'path'    : 'data',
            'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',
        },
        date=('2012-01-01', '2020-01-01'),
        form='Written',
        # Legal-document retrieval: describe it as such instead of the
        # academic / scientific-reranking labels it was previously given.
        domains=['Legal'],
        task_subtypes=['Article retrieval'],
        license='cc-by-nc-4.0',
        annotations_creators='derived',
        dialect=[],
        text_creation='found',
        bibtex_citation=''
    )

    # Flag set at class level; presumably tells MTEB the data is already
    # available so it does not try to fetch `dataset['path']` itself
    # -- TODO confirm against AbsTaskRetrieval.load_data.
    data_loaded = True

    def __init__(self, **kwargs):
        """Build corpus, queries and relevance judgements for both splits.

        Args:
            **kwargs: forwarded unchanged to ``AbsTaskRetrieval.__init__``.
        """
        super().__init__(**kwargs)

        corpus_frame = data['corpus']
        # Iterate column-wise rather than with iterrows(): no per-row Series is
        # built and the id values keep their column dtype. `_id` mirrors the
        # prefixed dictionary key so both refer to the same identifier.
        shared_corpus = {
            f'c{cid}': {'text': text, '_id': f'c{cid}'}
            for cid, text in zip(corpus_frame['cid'], corpus_frame['text'])
        }

        self.corpus = {}
        self.queries = {}
        self.relevant_docs = {}

        for split in ('train', 'test'):
            frame = data[split]
            queries = {}
            qrels = {}
            for qid, question, cids in zip(frame['qid'], frame['question'], frame['cid']):
                qid_str = f'q{qid}'
                queries[qid_str] = question
                # A qid seen more than once accumulates all of its relevant docs.
                qrels.setdefault(qid_str, {}).update((f'c{cid}', 1) for cid in cids)

            # Every split points at the same corpus object (shared, not copied).
            self.corpus[split] = shared_corpus
            self.queries[split] = queries
            self.relevant_docs[split] = qrels

        self.data_loaded = True

# %%
# Fine-tuned model is loaded from the directory configured in settings.py.
fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
# %%
# Run the custom task through MTEB. The batch size goes through `encode_kwargs`
# because the `batch_size` argument of `run` is reported as deprecated by MTEB.
custom_task = BKAILegalDocRetrievalTask()
evaluation = MTEB(tasks=[custom_task])
evaluation.run(fine_tuned_model, encode_kwargs={'batch_size': BATCH_SIZE})

# %%
# Result file of the run above. The `no_model_name_available` and
# `no_revision_available` folder names are presumably MTEB's fallbacks for a
# model without metadata -- TODO confirm if the model gains a name/revision.
file_path = "results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json"

with open(file_path, 'r', encoding='utf-8') as f:
    eval_data = json.load(f)

# First (only) entry of the 'test' split scores.
scores = eval_data["scores"]["test"][0]
main_metrics = {
    'main_score'         : scores.get('ndcg_at_10'),
    'recall@10'          : scores.get('recall_at_10'),
    'precision@10'       : scores.get('precision_at_10'),
    'mrr@10'             : scores.get('mrr_at_10'),
    'evaluation_time (s)': eval_data.get('evaluation_time')
}

print('Main Evaluation Metrics (Top-K = 10):')
pprint(main_metrics)

# %%
def metrics_by_k(scores):
    """Pivot flat ``<metric>_at_<k>`` scores into a table indexed by ``k``.

    Args:
        scores: mapping of score name to value. Only keys containing ``'_at_'``
            that do not start with ``'nauc'`` are used.

    Returns:
        DataFrame with one row per integer ``k`` (ascending) and one column per
        metric name.
    """
    records = []
    for key, value in scores.items():
        if '_at_' not in key or key.startswith('nauc'):
            continue
        # Split on the last '_at_' so the trailing token is always the cut-off.
        metric, at_k = key.rsplit('_at_', 1)
        records.append({'metric': metric, 'k': int(at_k), 'score': value})

    table = pd.DataFrame(records).pivot(index='k', columns='metric', values='score')
    return table.sort_index()

# %%
df_metrics = metrics_by_k(scores)

print("\nEvaluation Scores by K:")
print(df_metrics.round(4))