{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "24106202", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "!python settings.py" ] }, { "cell_type": "code", "execution_count": null, "id": "0086aabe", "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from datasets import Dataset\n", "from tqdm.autonotebook import tqdm\n", "\n", "from sentence_transformers import (\n", " SentenceTransformer,\n", " SentenceTransformerTrainer,\n", " SentenceTransformerTrainingArguments,\n", " SentenceTransformerModelCardData,\n", ")\n", "from sentence_transformers.readers import InputExample\n", "from sentence_transformers.models import Transformer, Pooling\n", "from sentence_transformers.losses import CachedMultipleNegativesRankingLoss\n", "from sentence_transformers.training_args import BatchSamplers\n", "\n", "from settings import MODEL_ID, MODEL_NAME, CACHE_DIR, OUTPUT_DIR, MAX_SEQ_LEN, EPOCHS, LR, BATCH_SIZE, DEVICE\n", "os.environ['WANDB_DISABLED'] = 'true'" ] }, { "cell_type": "code", "execution_count": 3, "id": "3a5cc53d", "metadata": {}, "outputs": [], "source": [ "data = {\n", " 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n", " 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n", " 'test' : pd.read_parquet('data/processed/test_data.parquet')\n", "}\n", "for split in ['train', 'test']:\n", " data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n", " data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())\n", " \n", "examples = {'train': [], 'test': []}" ] }, { "cell_type": "code", "execution_count": 4, "id": "30ebbd40", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | question | \n", "context_list | \n", "qid | \n", "cid | \n", "
---|---|---|---|---|
0 | \n", "Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... | \n", "[“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... | \n", "72600 | \n", "[142820] | \n", "
1 | \n", "Tên hợp tác xã bị rơi vào trường hợp cấm thì c... | \n", "[\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... | \n", "147562 | \n", "[27817, 72117] | \n", "
2 | \n", "Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... | \n", "[\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... | \n", "142107 | \n", "[33215, 56201] | \n", "
3 | \n", "Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... | \n", "[BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... | \n", "77353 | \n", "[148158] | \n", "
4 | \n", "Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... | \n", "[Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... | \n", "113090 | \n", "[188132] | \n", "
Step | \n", "Training Loss | \n", "
---|---|
100 | \n", "1.882700 | \n", "
200 | \n", "0.442800 | \n", "
300 | \n", "0.356400 | \n", "
400 | \n", "0.285600 | \n", "
500 | \n", "0.244500 | \n", "
600 | \n", "0.224100 | \n", "
700 | \n", "0.193800 | \n", "
800 | \n", "0.189400 | \n", "
900 | \n", "0.143200 | \n", "
1000 | \n", "0.143200 | \n", "
1100 | \n", "0.134100 | \n", "
1200 | \n", "0.131100 | \n", "
1300 | \n", "0.124900 | \n", "
1400 | \n", "0.122700 | \n", "
1500 | \n", "0.124100 | \n", "
1600 | \n", "0.102800 | \n", "
1700 | \n", "0.085200 | \n", "
1800 | \n", "0.085000 | \n", "
1900 | \n", "0.082000 | \n", "
2000 | \n", "0.080000 | \n", "
2100 | \n", "0.082400 | \n", "
2200 | \n", "0.080200 | \n", "
2300 | \n", "0.082200 | \n", "
2400 | \n", "0.063300 | \n", "
2500 | \n", "0.061500 | \n", "
2600 | \n", "0.061200 | \n", "
2700 | \n", "0.058000 | \n", "
2800 | \n", "0.056600 | \n", "
2900 | \n", "0.052100 | \n", "
3000 | \n", "0.054800 | \n", "
3100 | \n", "0.054700 | \n", "
3200 | \n", "0.047900 | \n", "
3300 | \n", "0.044900 | \n", "
3400 | \n", "0.044000 | \n", "
3500 | \n", "0.043900 | \n", "
3600 | \n", "0.044400 | \n", "
3700 | \n", "0.045700 | \n", "
3800 | \n", "0.046100 | \n", "
"
],
"text/plain": [
"