# %% Run the settings module as a script (equivalent to the `!python settings.py` shell escape).
# Its saved output in this notebook is the line "Using device: cuda".
get_ipython().system('python settings.py')

# %% Imports and environment configuration
import os
import json
import pandas as pd
from pprint import pprint
from tqdm.autonotebook import tqdm

from sentence_transformers import SentenceTransformer

from mteb import MTEB
from mteb.abstasks.TaskMetadata import TaskMetadata
from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval

from settings import MODEL_NAME, OUTPUT_DIR, DEVICE, BATCH_SIZE
os.environ['WANDB_DISABLED'] = 'true'

# %% Load the processed corpus and the train/test query splits
data = {
    'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),
    'train' : pd.read_parquet('data/processed/train_data.parquet'),
    'test'  : pd.read_parquet('data/processed/test_data.parquet'),
}
# Each cell of these two columns exposes `.tolist()` (presumably numpy arrays
# after the parquet round-trip -- TODO confirm); convert them to plain lists.
for split in ['train', 'test']:
    for column in ['cid', 'context_list']:
        data[split][column] = data[split][column].apply(lambda x: x.tolist())


# %% Custom MTEB retrieval task
class BKAILegalDocRetrievalTask(AbsTaskRetrieval):
    """MTEB retrieval task backed by the notebook-level ``data`` frames.

    The task does not download anything: ``__init__`` reads the module-level
    ``data`` dict (keys ``'corpus'``, ``'train'``, ``'test'``) and fills
    ``self.corpus``, ``self.queries`` and ``self.relevant_docs`` for the
    ``'train'`` and ``'test'`` splits. ``data`` must therefore be defined
    before the class is instantiated.

    Identifier scheme: corpus ids are prefixed with ``c`` and query ids with
    ``q`` (``f'c{cid}'`` / ``f'q{qid}'``).
    """

    # Metadata definition used by MTEB benchmark
    # NOTE(review): `reference`, `date` and `dataset` (path/revision) look like
    # placeholders carried over from the MTEB "adding a dataset" template --
    # confirm and replace with the real provenance of this dataset.
    metadata = TaskMetadata(
        name='BKAILegalDocRetrieval',
        description='Retrieve the relevant legal passages from a shared corpus '
                    'for each Vietnamese question.',
        reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',
        type='Retrieval',
        category='s2p',
        modalities=['text'],
        eval_splits=['test'],
        eval_langs=['vi'],
        main_score='ndcg_at_10',
        other_scores=['recall_at_10', 'precision_at_10', 'map'],
        dataset={
            'path'    : 'data',
            'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',
        },
        date=('2012-01-01', '2020-01-01'),
        form='Written',
        # Was ['Academic', 'Non-fiction'] / ['Scientific Reranking'], i.e. the
        # values of the template's scientific reranking example; this is a
        # legal document retrieval task.
        domains=['Legal'],
        task_subtypes=['Article retrieval'],
        license='cc-by-nc-4.0',
        annotations_creators='derived',
        dialect=[],
        text_creation='found',
        bibtex_citation=''
    )

    # Class-level flag set to True up front because the data is populated in
    # __init__ rather than by a separate loading step.
    data_loaded = True  # Flag

    def __init__(self, **kwargs):
        """Build corpus, queries and relevance judgements from ``data``.

        Args:
            **kwargs: Forwarded unchanged to ``AbsTaskRetrieval.__init__``.

        Raises:
            NameError: If the module-level ``data`` dict has not been defined.
            KeyError: If ``data`` lacks an expected split or column.
        """
        super().__init__(**kwargs)

        # Iterate column-wise instead of DataFrame.iterrows(): no per-row
        # Series is allocated and values are not coerced to a common row dtype.
        corpus_frame = data['corpus']
        shared_corpus = {
            # '_id' keeps the raw corpus id, exactly as before.
            f'c{cid}': {'text': text, '_id': cid}
            for cid, text in zip(corpus_frame['cid'], corpus_frame['text'])
        }

        self.corpus = {}
        self.queries = {}
        self.relevant_docs = {}

        for split in ['train', 'test']:
            split_frame = data[split]
            split_queries = {}
            split_qrels = {}

            for qid, cids, question in zip(split_frame['qid'],
                                           split_frame['cid'],
                                           split_frame['question']):
                qid_str = f'q{qid}'
                split_queries[qid_str] = question
                # A qid that appears in several rows accumulates all of its
                # relevant corpus ids; every listed id gets relevance 1.
                split_qrels.setdefault(qid_str, {}).update(
                    (f'c{cid}', 1) for cid in cids
                )

            # Both splits share the very same corpus dict object.
            self.corpus[split] = shared_corpus
            self.queries[split] = split_queries
            self.relevant_docs[split] = split_qrels

        self.data_loaded = True


# %% Load the fine-tuned model from OUTPUT_DIR onto the configured device
fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
───────────────────────────────────────────────── Selected tasks  ─────────────────────────────────────────────────\n",
       "
\n" ], "text/plain": [ "\u001b[38;5;235m───────────────────────────────────────────────── \u001b[0m\u001b[1mSelected tasks \u001b[0m\u001b[38;5;235m ─────────────────────────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Retrieval\n",
       "
\n" ], "text/plain": [ "\u001b[1mRetrieval\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
    - BKAILegalDocRetrieval, s2p\n",
       "
\n" ], "text/plain": [ " - BKAILegalDocRetrieval, \u001b[3;38;5;241ms2p\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "\n",
       "
\n" ], "text/plain": [ "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "53778754caf4456f8e140cfa58b60709", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/233 [00:00