YuITC commited on
Commit
b3c55d5
·
1 Parent(s): 500f44b

fix: correct passage loading and retrieval function call in main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -3
main.py CHANGED
@@ -19,7 +19,7 @@ index_path = hf_hub_download(repo_id='YuITC/Vietnamese-Legal-Doc-Retrieval-Data'
19
  local_dir='demo')
20
 
21
  emb_model = SentenceTransformer('YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs')
22
- passages = pd.read_parquet(passages_path)
23
  legal_index = faiss.read_index(index_path)
24
 
25
 
@@ -41,7 +41,7 @@ def retrieval(emb_model, query, index, top_k=10):
41
  } for i in range(len(cand_idxs))]
42
 
43
  def get_results(query, top_k):
44
- hits = retrieval(fine_tuned_model, query, legal_index, top_k=top_k)
45
 
46
  result = ""
47
  for rank, h in enumerate(hits, start=1):
@@ -68,4 +68,4 @@ demo = gr.Interface(
68
  )
69
 
70
  if __name__ == '__main__':
71
- demo.launch(share=True)
 
19
  local_dir='demo')
20
 
21
  emb_model = SentenceTransformer('YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs')
22
+ passages = pd.read_parquet(passages_path)['text'].tolist()
23
  legal_index = faiss.read_index(index_path)
24
 
25
 
 
41
  } for i in range(len(cand_idxs))]
42
 
43
  def get_results(query, top_k):
44
+ hits = retrieval(emb_model, query, legal_index, top_k=top_k)
45
 
46
  result = ""
47
  for rank, h in enumerate(hits, start=1):
 
68
  )
69
 
70
  if __name__ == '__main__':
71
+ demo.launch()