matthoffner committed on
Commit
ffb1641
·
1 Parent(s): e38650f

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +57 -0
utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from zipfile import ZipFile
4
+ from presets import *
5
+
6
def refresh_json_list(plain=False):
    """List the base names of all ``.json`` files found under ``./index``.

    The directory tree rooted at ``./index`` (relative to the current working
    directory) is walked recursively. Only the file name without its
    extension is kept, so files with the same stem in different
    sub-directories appear more than once. A missing ``./index`` directory
    yields an empty list.

    Args:
        plain: When true, return the names as a plain list. Otherwise wrap
            them in a ``gr.Dropdown.update`` so the result can be used
            directly as a Gradio event output.

    Returns:
        A sorted list of file stems, or a Gradio dropdown update whose
        ``choices`` is that list.
    """
    json_list = []
    for _root, _dirs, files in os.walk("./index"):
        for file_name in files:
            stem, extension = os.path.splitext(file_name)
            if extension == ".json":
                json_list.append(stem)
    # os.walk yields entries in arbitrary filesystem order; sort so the
    # choices are stable across runs and platforms.
    json_list.sort()
    if plain:
        return json_list
    return gr.Dropdown.update(choices=json_list)
15
+
16
def upload_file(file_obj):
    """Describe the members of an uploaded zip archive.

    Args:
        file_obj: Object whose ``name`` attribute is the path of a zip file.

    Returns:
        A list with one ``{"name": <member filename>}`` dict per archive
        member, in the order the archive lists them.
    """
    with ZipFile(file_obj.name) as archive:
        return [{"name": member_name} for member_name in archive.namelist()]
26
+
27
def reset_textbox():
    """Return a Gradio update that sets a component's value to the empty string."""
    cleared_value = ""
    return gr.update(value=cleared_value)
29
+
30
def change_prompt_tmpl(tmpl_select):
    """Return a Gradio update carrying the prompt template for *tmpl_select*.

    The template is looked up in ``prompt_tmpl_dict``, which is not defined
    in this module (presumably provided by the ``presets`` star import --
    verify there).

    Args:
        tmpl_select: Key identifying the template in ``prompt_tmpl_dict``.

    Returns:
        ``gr.update`` with ``value`` set to the looked-up template.
    """
    selected_template = prompt_tmpl_dict[tmpl_select]
    return gr.update(value=selected_template)
33
+
34
def change_refine_tmpl(refine_select):
    """Return a Gradio update carrying the refine template for *refine_select*.

    The template is looked up in ``refine_tmpl_dict``, which is not defined
    in this module (presumably provided by the ``presets`` star import --
    verify there).

    Args:
        refine_select: Key identifying the template in ``refine_tmpl_dict``.

    Returns:
        ``gr.update`` with ``value`` set to the looked-up template.
    """
    selected_template = refine_tmpl_dict[refine_select]
    return gr.update(value=selected_template)
37
+
38
def lock_params(index_type):
    """Enable or disable the two index-parameter sliders for *index_type*.

    Args:
        index_type: Name of the selected index type.

    Returns:
        A pair of ``gr.Slider.update`` objects: first the child-node-count
        slider, then the keywords-per-chunk slider. A slider that does not
        apply to the index type is made non-interactive and its label says
        so. Returns ``None`` when *index_type* is not one of the four
        recognised names.
    """
    # Decide which slider applies to the chosen index type.
    if index_type in ("GPTVectorStoreIndex", "GPTListIndex"):
        child_enabled, keyword_enabled = False, False
    elif index_type == "GPTTreeIndex":
        child_enabled, keyword_enabled = True, False
    elif index_type == "GPTKeywordTableIndex":
        child_enabled, keyword_enabled = False, True
    else:
        return None

    if child_enabled:
        child_label = "子节点数量"
    else:
        child_label = "子节点数量(当前索引类型不可用)"

    if keyword_enabled:
        keyword_label = "每段关键词数量"
    else:
        keyword_label = "每段关键词数量(当前索引类型不可用)"

    child_update = gr.Slider.update(interactive=child_enabled, label=child_label)
    keyword_update = gr.Slider.update(interactive=keyword_enabled, label=keyword_label)
    return child_update, keyword_update
45
+
46
def add_space(text):
    """Append a space after each listed punctuation mark in *text*.

    Every occurrence is affected, whether or not a space already follows it.

    Args:
        text: The string to process.

    Returns:
        A copy of *text* with a space inserted after each listed mark.
    """
    # One-pass character table; equivalent to replacing each mark in turn
    # because no replacement introduces a different listed mark.
    spacing_table = str.maketrans({
        ',': ', ',
        '。': '。 ',
        '?': '? ',
        '!': '! ',
        ':': ': ',
        ';': '; ',
    })
    return text.translate(spacing_table)
51
+
52
+ ## create a test for parse_text
53
def parse_text(text):
    """Join the non-empty lines of *text* into one string with no separator.

    Args:
        text: A string, possibly containing ``"\\n"`` line breaks.

    Returns:
        The lines of *text* concatenated directly, with empty lines dropped.
    """
    return "".join(filter(None, text.split("\n")))