{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "29a91458",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"!python settings.py"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97c0ec5c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import zipfile\n",
"import requests\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f7b1ed51",
"metadata": {},
"outputs": [],
"source": [
"# Download the dataset\n",
"url = 'https://huggingface.co/datasets/tmnam20/BKAI-Legal-Retrieval/resolve/main/archive.zip'\n",
"zip_path = 'data/original/archive.zip'\n",
"\n",
"response = requests.get(url)\n",
"with open(zip_path, 'wb') as f:\n",
" f.write(response.content)\n",
"\n",
"with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
" zip_ref.extractall('data/original')\n",
" \n",
"os.remove(zip_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4fe0c4f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train split data: 89592\n",
"Test split data : 29864\n"
]
}
],
"source": [
"corpus_data = pd.read_csv('data/original/corpus.csv')\n",
"train_split = pd.read_csv('data/original/train_split.csv')\n",
"test_split = pd.read_csv('data/original/val_split.csv')\n",
"\n",
"print(f\"Train split data: {len(train_split)}\")\n",
"print(f\"Test split data : {len(test_split)}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6e3fbd6e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" cid | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Thông tư này hướng dẫn tuần tra, canh gác bảo ... | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 1. Hàng năm trước mùa mưa, lũ, Ủy ban nhân dân... | \n",
" 1 | \n",
"
\n",
" \n",
" 2 | \n",
" Tiêu chuẩn của các thành viên thuộc lực lượng ... | \n",
" 2 | \n",
"
\n",
" \n",
" 3 | \n",
" Nhiệm vụ của lực lượng tuần tra, canh gác đê\\n... | \n",
" 3 | \n",
"
\n",
" \n",
" 4 | \n",
" Phù hiệu của lực lượng tuần tra, canh gác đê\\n... | \n",
" 4 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text cid\n",
"0 Thông tư này hướng dẫn tuần tra, canh gác bảo ... 0\n",
"1 1. Hàng năm trước mùa mưa, lũ, Ủy ban nhân dân... 1\n",
"2 Tiêu chuẩn của các thành viên thuộc lực lượng ... 2\n",
"3 Nhiệm vụ của lực lượng tuần tra, canh gác đê\\n... 3\n",
"4 Phù hiệu của lực lượng tuần tra, canh gác đê\\n... 4"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"corpus_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3d32d13a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" question | \n",
" context | \n",
" cid | \n",
" qid | \n",
" context_list | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... | \n",
" ['“Điều 2. Địa vị pháp lý của Liên đoàn Luật s... | \n",
" [142820] | \n",
" 72600 | \n",
" [“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... | \n",
"
\n",
" \n",
" 1 | \n",
" Tên hợp tác xã bị rơi vào trường hợp cấm thì c... | \n",
" ['Tên hợp tác xã, liên hiệp hợp tác xã\\n1. Tên... | \n",
" [27817, 72117] | \n",
" 147562 | \n",
" [\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... | \n",
"
\n",
" \n",
" 2 | \n",
" Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... | \n",
" ['\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiê... | \n",
" [33215, 56201] | \n",
" 142107 | \n",
" [\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... | \n",
"
\n",
" \n",
" 3 | \n",
" Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... | \n",
" ['BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thự... | \n",
" [148158] | \n",
" 77353 | \n",
" [BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... | \n",
"
\n",
" \n",
" 4 | \n",
" Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... | \n",
" ['Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệ... | \n",
" [188132] | \n",
" 113090 | \n",
" [Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" question \\\n",
"0 Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... \n",
"1 Tên hợp tác xã bị rơi vào trường hợp cấm thì c... \n",
"2 Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... \n",
"3 Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... \n",
"4 Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... \n",
"\n",
" context cid qid \\\n",
"0 ['“Điều 2. Địa vị pháp lý của Liên đoàn Luật s... [142820] 72600 \n",
"1 ['Tên hợp tác xã, liên hiệp hợp tác xã\\n1. Tên... [27817, 72117] 147562 \n",
"2 ['\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiê... [33215, 56201] 142107 \n",
"3 ['BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thự... [148158] 77353 \n",
"4 ['Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệ... [188132] 113090 \n",
"\n",
" context_list \n",
"0 [“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... \n",
"1 [\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... \n",
"2 [\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... \n",
"3 [BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... \n",
"4 [Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 'cid' column: '[1 2 3]'\n",
"train_split['cid'] = train_split['cid'].apply(lambda x: [int(i) for i in x[1:-1].split()])\n",
"test_split['cid'] = test_split['cid'].apply(lambda x: [int(i) for i in x[1:-1].split()])\n",
"\n",
"\n",
"# Mapping from corpus \n",
"mapping = dict(zip(corpus_data['cid'], corpus_data['text']))\n",
"\n",
"def get_context_list(cid_list):\n",
" return [mapping[cid] for cid in cid_list if cid in mapping]\n",
"\n",
"train_split['context_list'] = train_split['cid'].apply(get_context_list)\n",
"test_split['context_list'] = test_split['cid'].apply(get_context_list)\n",
"\n",
"train_split.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e0450414",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"430 99 331\n",
"question \n",
"context \n",
"cid \n",
"qid \n",
"context_list \n"
]
}
],
"source": [
"# Debug\n",
"print(\n",
" len(train_split[train_split['context_list'].apply(len) != train_split['cid'].apply(len)]),\n",
" \n",
" len(\n",
" train_split[\n",
" (train_split['context_list'].apply(len) != train_split['cid'].apply(len)) &\n",
" (train_split['context_list'].apply(len) != 0)\n",
" ]\n",
" ),\n",
" \n",
" len(\n",
" train_split[\n",
" (train_split['context_list'].apply(len) != train_split['cid'].apply(len)) &\n",
" (train_split['context_list'].apply(len) == 0)\n",
" ]\n",
" )\n",
")\n",
"\n",
"for col in train_split.columns:\n",
" print(col, type(train_split[col][0]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fd1eb4a2",
"metadata": {},
"outputs": [],
"source": [
"# Drop invalid data\n",
"train_data = train_split.loc[\n",
" ~(train_split['context_list'].apply(len) != train_split['cid'].apply(len)), \n",
" ['question', 'context_list', 'qid', 'cid']\n",
"]\n",
"\n",
"test_data = test_split.loc[\n",
" ~(test_split['context_list'].apply(len) != test_split['cid'].apply(len)), \n",
" ['question', 'context_list', 'qid', 'cid']\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3661c9cb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train data saved: 89162\n",
"Test data saved : 29723\n"
]
}
],
"source": [
"# Save the processed data to parquet files\n",
"corpus_data.to_parquet('data/processed/corpus_data.parquet', index=False)\n",
"train_data.to_parquet('data/processed/train_data.parquet', index=False)\n",
"test_data.to_parquet('data/processed/test_data.parquet', index=False)\n",
"\n",
"print(f\"Train data saved: {len(train_data)}\")\n",
"print(f\"Test data saved : {len(test_data)}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6382a715",
"metadata": {},
"outputs": [],
"source": [
"# # Get demo data\n",
"# os.makedirs('data/demo', exist_ok=True)\n",
"\n",
"# demo_corpus_data = corpus_data.sample(10, random_state=42).reset_index(drop=True)\n",
"# demo_train_data = train_data.sample(10, random_state=42).reset_index(drop=True)\n",
"# demo_test_data = test_data.sample(10, random_state=42).reset_index(drop=True)\n",
"\n",
"# demo_corpus_data.to_csv('data/demo/demo_corpus_data.csv', index=False)\n",
"# demo_train_data.to_csv('data/demo/demo_train_data.csv', index=False)\n",
"# demo_test_data.to_csv('data/demo/demo_test_data.csv', index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "legal_doc_retrieval",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}