File size: 2,490 Bytes
0063d17
 
 
 
 
500f44b
0063d17
 
 
226ff74
 
 
 
 
 
 
 
 
 
b3c55d5
226ff74
 
 
 
0063d17
 
 
226ff74
0063d17
 
226ff74
 
 
 
0063d17
226ff74
 
 
 
0063d17
 
b3c55d5
0063d17
 
 
 
 
 
 
226ff74
0063d17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3c55d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import pandas as pd
import gradio as gr

import faiss
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer


# ===== Prepare model & data =====
passages_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data',
                                filename='corpus_data.parquet', repo_type='dataset', 
                                local_dir='demo')

index_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data', 
                             filename='legal_faiss.index', repo_type='dataset', 
                             local_dir='demo')

emb_model   = SentenceTransformer('YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs')
passages    = pd.read_parquet(passages_path)['text'].tolist()
legal_index = faiss.read_index(index_path)


# ===== Utility functions =====
def retrieval(emb_model, query, index, top_k=10):
    q_emb = emb_model.encode(
        query, 
        convert_to_numpy=True, normalize_embeddings=True,
    ).astype(np.float32).reshape(1, -1)
    
    scores, indices = index.search(q_emb, top_k)
    cand_idxs       = indices[0]
    cand_scores     = scores[0]
    cand_texts      = [passages[i] for i in cand_idxs]

    return [{'index': int(cand_idxs[i]),
             'score': float(cand_scores[i]),
             'text' : cand_texts[i]
            } for i in range(len(cand_idxs))]

def get_results(query, top_k):
    hits = retrieval(emb_model, query, legal_index, top_k=top_k)
    
    result = ""
    for rank, h in enumerate(hits, start=1):
        result += f"[Kết quả {rank} - Độ tin cậy={h['score']:.4f}]\n\n{h['text']}\n{'-'*100}\n"
    return result

    
# ===== Gradio UI =====
demo = gr.Interface(
    fn=get_results,
    inputs=[
        gr.Textbox(lines=2, placeholder='Nhập câu hỏi pháp lý của bạn...', label='Câu hỏi'),
        gr.Slider(minimum=5, maximum=20, value=10, step=1, label='Số lượng kết quả'),
    ],
    outputs=gr.Textbox(lines=20, label='Kết quả'),
    title='Vietnamese Legal Document Retrieval System',
    description='🔍 Nhập câu hỏi pháp lý của bạn bằng tiếng Việt để nhận các đoạn văn bản pháp luật liên quan.',
    examples=[
        ['Tội xúc phạm danh dự?'],
        ['Quyền lợi của người lao động?'],
        ['Thủ tục đăng ký kết hôn?'],
    ],
    flagging_mode='never'
)

if __name__ == '__main__':
    demo.launch()