|
import numpy as np |
|
import pandas as pd |
|
import gradio as gr |
|
|
|
import faiss |
|
from huggingface_hub import hf_hub_download |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
passages_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data', |
|
filename='corpus_data.parquet', repo_type='dataset', |
|
local_dir='demo') |
|
|
|
index_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data', |
|
filename='legal_faiss.index', repo_type='dataset', |
|
local_dir='demo') |
|
|
|
emb_model = SentenceTransformer('YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs') |
|
passages = pd.read_parquet(passages_path)['text'].tolist() |
|
legal_index = faiss.read_index(index_path) |
|
|
|
|
|
|
|
def retrieval(emb_model, query, index, top_k=10): |
|
q_emb = emb_model.encode( |
|
query, |
|
convert_to_numpy=True, normalize_embeddings=True, |
|
).astype(np.float32).reshape(1, -1) |
|
|
|
scores, indices = index.search(q_emb, top_k) |
|
cand_idxs = indices[0] |
|
cand_scores = scores[0] |
|
cand_texts = [passages[i] for i in cand_idxs] |
|
|
|
return [{'index': int(cand_idxs[i]), |
|
'score': float(cand_scores[i]), |
|
'text' : cand_texts[i] |
|
} for i in range(len(cand_idxs))] |
|
|
|
def get_results(query, top_k): |
|
hits = retrieval(emb_model, query, legal_index, top_k=top_k) |
|
|
|
result = "" |
|
for rank, h in enumerate(hits, start=1): |
|
result += f"[Kết quả {rank} - Độ tin cậy={h['score']:.4f}]\n\n{h['text']}\n{'-'*100}\n" |
|
return result |
|
|
|
|
|
|
|
demo = gr.Interface( |
|
fn=get_results, |
|
inputs=[ |
|
gr.Textbox(lines=2, placeholder='Nhập câu hỏi pháp lý của bạn...', label='Câu hỏi'), |
|
gr.Slider(minimum=5, maximum=20, value=10, step=1, label='Số lượng kết quả'), |
|
], |
|
outputs=gr.Textbox(lines=20, label='Kết quả'), |
|
title='Vietnamese Legal Document Retrieval System', |
|
description='🔍 Nhập câu hỏi pháp lý của bạn bằng tiếng Việt để nhận các đoạn văn bản pháp luật liên quan.', |
|
examples=[ |
|
['Tội xúc phạm danh dự?'], |
|
['Quyền lợi của người lao động?'], |
|
['Thủ tục đăng ký kết hôn?'], |
|
], |
|
flagging_mode='never' |
|
) |
|
|
|
if __name__ == '__main__': |
|
demo.launch() |