{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "24106202", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "!python settings.py" ] }, { "cell_type": "code", "execution_count": null, "id": "0086aabe", "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from datasets import Dataset\n", "from tqdm.autonotebook import tqdm\n", "\n", "from sentence_transformers import (\n", " SentenceTransformer,\n", " SentenceTransformerTrainer,\n", " SentenceTransformerTrainingArguments,\n", " SentenceTransformerModelCardData,\n", ")\n", "from sentence_transformers.readers import InputExample\n", "from sentence_transformers.models import Transformer, Pooling\n", "from sentence_transformers.losses import CachedMultipleNegativesRankingLoss\n", "from sentence_transformers.training_args import BatchSamplers\n", "\n", "from settings import MODEL_ID, MODEL_NAME, CACHE_DIR, OUTPUT_DIR, MAX_SEQ_LEN, EPOCHS, LR, BATCH_SIZE, DEVICE\n", "os.environ['WANDB_DISABLED'] = 'true'" ] }, { "cell_type": "code", "execution_count": 3, "id": "3a5cc53d", "metadata": {}, "outputs": [], "source": [ "data = {\n", " 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n", " 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n", " 'test' : pd.read_parquet('data/processed/test_data.parquet')\n", "}\n", "for split in ['train', 'test']:\n", " data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n", " data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())\n", " \n", "examples = {'train': [], 'test': []}" ] }, { "cell_type": "code", "execution_count": 4, "id": "30ebbd40", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
questioncontext_listqidcid
0Liên đoàn Luật sư Việt Nam là tổ chức xã hội –...[“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư...72600[142820]
1Tên hợp tác xã bị rơi vào trường hợp cấm thì c...[\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã...147562[27817, 72117]
2Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t...[\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu...142107[33215, 56201]
3Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ...[BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực...77353[148158]
4Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ...[Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm...113090[188132]
\n", "
" ], "text/plain": [ " question \\\n", "0 Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... \n", "1 Tên hợp tác xã bị rơi vào trường hợp cấm thì c... \n", "2 Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... \n", "3 Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... \n", "4 Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... \n", "\n", " context_list qid cid \n", "0 [“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... 72600 [142820] \n", "1 [\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... 147562 [27817, 72117] \n", "2 [\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... 142107 [33215, 56201] \n", "3 [BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... 77353 [148158] \n", "4 [Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... 113090 [188132] " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data['train'].head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "943bf8ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "question \n", "context_list \n", "qid \n", "cid \n", "True\n" ] } ], "source": [ "# Debug\n", "for col in data['test'].columns:\n", " print(col, type(data['test'][col][0]))\n", " \n", "print((data['test']['cid'].apply(len) == data['test']['context_list'].apply(len)).all())" ] }, { "cell_type": "code", "execution_count": null, "id": "2c751cf4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "509893cf5cfd4a8d9e18bba47561a41c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing train: 0%| | 0/89162 [00:00\n", " \n", " \n", " [3890/3890 3:32:33, Epoch 5/5]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1001.882700
2000.442800
3000.356400
4000.285600
5000.244500
6000.224100
7000.193800
8000.189400
9000.143200
10000.143200
11000.134100
12000.131100
13000.124900
14000.122700
15000.124100
16000.102800
17000.085200
18000.085000
19000.082000
20000.080000
21000.082400
22000.080200
23000.082200
24000.063300
25000.061500
26000.061200
27000.058000
28000.056600
29000.052100
30000.054800
31000.054700
32000.047900
33000.044900
34000.044000
35000.043900
36000.044400
37000.045700
38000.046100

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=3890, training_loss=0.1604946916084976, metrics={'train_runtime': 12756.5123, 'train_samples_per_second': 39.031, 'train_steps_per_second': 0.305, 'total_flos': 0.0, 'train_loss': 0.1604946916084976, 'epoch': 5.0})" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def to_frame(ex_list):\n", " rows = [(ex.texts[0], ex.texts[1]) for ex in ex_list]\n", " return pd.DataFrame(rows, columns=['text_0', 'text_1'])\n", "\n", "train_ds = Dataset.from_pandas(to_frame(examples['train']))\n", "\n", "trainer = SentenceTransformerTrainer(\n", " model=model,\n", " args=args,\n", " train_dataset=train_ds,\n", " loss=loss,\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "f47a01a1", "metadata": {}, "outputs": [], "source": [ "model.save_pretrained(OUTPUT_DIR)\n", "# model.push_to_hub(\n", "# repo_id='YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs', \n", "# commit_message='Update README.md',\n", "# exist_ok=True,\n", "# replace_model_card=False,\n", "# train_datasets=['tmnam20/BKAI-Legal-Retrieval']\n", "# )" ] } ], "metadata": { "kernelspec": { "display_name": "legal_doc_retrieval", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }