YuITC committed on
Commit
226ff74
·
1 Parent(s): f122677

refactor: update data loading and organization in main.py

Browse files
Files changed (2) hide show
  1. .gitignore +1 -1
  2. main.py +26 -21
.gitignore CHANGED
@@ -1,8 +1,8 @@
1
  __pycache__/
2
  .gradio/
3
  cache/
4
- data/original/
5
  models/
6
  data/
7
  tmp/
 
8
  .env
 
1
  __pycache__/
2
  .gradio/
3
  cache/
 
4
  models/
5
  data/
6
  tmp/
7
+ demo/
8
  .env
main.py CHANGED
@@ -4,35 +4,40 @@ import pandas as pd
4
  import gradio as gr
5
 
6
  import faiss
 
7
  from sentence_transformers import SentenceTransformer
8
- from settings import OUTPUT_DIR, DEVICE
9
- os.environ['WANDB_DISABLED'] = 'true'
10
 
11
-
12
- fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
13
- passages = pd.read_parquet('data/processed/corpus_data.parquet')['text'].tolist()
14
- legal_index = faiss.read_index('data/retrieval/legal_faiss.index')
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def retrieval(emb_model, query, index, top_k=10):
17
  q_emb = emb_model.encode(
18
  query,
19
- convert_to_numpy=True,
20
- normalize_embeddings=True,
21
  ).astype(np.float32).reshape(1, -1)
22
 
23
- scores, indices = index.search(q_emb, top_k) # shape: (1, top_k)
24
-
25
- cand_idxs = indices[0]
26
- cand_scores = scores[0]
27
- cand_texts = [passages[i] for i in cand_idxs]
28
 
29
- results = [{
30
- 'index': int(cand_idxs[i]),
31
- 'score': float(cand_scores[i]),
32
- 'text': cand_texts[i]
33
- } for i in range(len(cand_idxs))]
34
-
35
- return results
36
 
37
  def get_results(query, top_k):
38
  hits = retrieval(fine_tuned_model, query, legal_index, top_k=top_k)
@@ -43,8 +48,8 @@ def get_results(query, top_k):
43
  return result
44
 
45
 
 
46
  demo = gr.Interface(
47
- 'huggingface/YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs',
48
  fn=get_results,
49
  inputs=[
50
  gr.Textbox(lines=2, placeholder='Nhập câu hỏi pháp lý của bạn...', label='Câu hỏi'),
 
4
  import gradio as gr
5
 
6
  import faiss
7
+ from datasets import load_dataset
8
  from sentence_transformers import SentenceTransformer
 
 
9
 
 
 
 
 
10
 
11
+ # ===== Prepare model & data =====
12
+ passages_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data',
13
+ filename='corpus_data.parquet', repo_type='dataset',
14
+ local_dir='demo')
15
+
16
+ index_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data',
17
+ filename='legal_faiss.index', repo_type='dataset',
18
+ local_dir='demo')
19
+
20
+ emb_model = SentenceTransformer('YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs')
21
+ passages = pd.read_parquet(passages_path)
22
+ legal_index = faiss.read_index(index_path)
23
+
24
+
25
+ # ===== Utility functions =====
26
  def retrieval(emb_model, query, index, top_k=10):
27
  q_emb = emb_model.encode(
28
  query,
29
+ convert_to_numpy=True, normalize_embeddings=True,
 
30
  ).astype(np.float32).reshape(1, -1)
31
 
32
+ scores, indices = index.search(q_emb, top_k)
33
+ cand_idxs = indices[0]
34
+ cand_scores = scores[0]
35
+ cand_texts = [passages[i] for i in cand_idxs]
 
36
 
37
+ return [{'index': int(cand_idxs[i]),
38
+ 'score': float(cand_scores[i]),
39
+ 'text' : cand_texts[i]
40
+ } for i in range(len(cand_idxs))]
 
 
 
41
 
42
  def get_results(query, top_k):
43
  hits = retrieval(fine_tuned_model, query, legal_index, top_k=top_k)
 
48
  return result
49
 
50
 
51
+ # ===== Gradio UI =====
52
  demo = gr.Interface(
 
53
  fn=get_results,
54
  inputs=[
55
  gr.Textbox(lines=2, placeholder='Nhập câu hỏi pháp lý của bạn...', label='Câu hỏi'),