janbanot committed
Commit 524b722 · 1 Parent(s): a2ced42

fix: try again

Files changed (1)
  1. app.py +178 -20
app.py CHANGED
@@ -1,30 +1,188 @@
-import gradio as gr
-import spaces
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM

-MODEL_NAME = "speakleash/Bielik-11B-v2.3-Instruct-GGUF"
-MODEL_FILE = "Bielik-11B-v2.3-Instruct.Q4_K_M.gguf"


-@spaces.GPU
-def test():
-    device = torch.device("cuda")
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        model_file=MODEL_FILE,
-        model_type="mistral", gpu_layers=50, hf=True).to(device)

-    inputs = tokenizer("Cześć Bielik, jak się masz?", return_tensors="pt").to(device)

-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id
         )

-    return tokenizer.decode(outputs[0], skip_special_tokens=True)


-demo = gr.Interface(fn=test, inputs=None, outputs=gr.Text())
-demo.launch()
 
+import os
+import subprocess
+from threading import Thread
+
+import random
 import torch
+import spaces
+import gradio as gr
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TextIteratorStreamer,
+)
+
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
+MODEL_ID = "speakleash/Bielik-7B-Instruct-v0.1"
+CHAT_TEMPLATE = "ChatML"
+MODEL_NAME = MODEL_ID.split("/")[-1]
+CONTEXT_LENGTH = 1024
+COLOR = os.environ.get("COLOR")
+EMOJI = os.environ.get("EMOJI")
+DESCRIPTION = os.environ.get("DESCRIPTION")

+# Load model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    device_map="auto",
+    torch_dtype="auto",
+    attn_implementation="flash_attention_2",
+)


+@spaces.GPU()
+def generate(
+    instruction,
+    stop_tokens,
+    temperature,
+    max_new_tokens,
+    top_k,
+    repetition_penalty,
+    top_p,
+):
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
+    input_ids, attention_mask = enc.input_ids, enc.attention_mask

+    if input_ids.shape[1] > CONTEXT_LENGTH:
+        input_ids = input_ids[:, -CONTEXT_LENGTH:]

+    generate_kwargs = dict(
+        {
+            "input_ids": input_ids.to(device),
+            "attention_mask": attention_mask.to(device),
+        },
+        streamer=streamer,
+        do_sample=True if temperature else False,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        top_p=top_p,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for new_token in streamer:
+        outputs.append(new_token)
+        if new_token in stop_tokens:
+            break
+        yield "".join(outputs)
+
+
+def predict(
+    message,
+    history,
+    system_prompt,
+    temperature,
+    max_new_tokens,
+    top_k,
+    repetition_penalty,
+    top_p,
+):
+    repetition_penalty = float(repetition_penalty)
+    print(
+        "LLL",
+        [
+            message,
+            history,
+            system_prompt,
+            temperature,
+            max_new_tokens,
+            top_k,
+            repetition_penalty,
+            top_p,
+        ],
+    )
+    # Format history with a given chat template
+    if CHAT_TEMPLATE == "ChatML":
+        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+        instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
+        for human, assistant in history:
+            instruction += (
+                "<|im_start|>user\n"
+                + human
+                + "\n<|im_end|>\n<|im_start|>assistant\n"
+                + assistant
+            )
+        instruction += (
+            "\n<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
+        )
+    elif CHAT_TEMPLATE == "Mistral Instruct":
+        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
+        instruction = "<s>[INST] " + system_prompt
+        for human, assistant in history:
+            instruction += human + " [/INST] " + assistant + "</s>[INST]"
+        instruction += " " + message + " [/INST]"
+    elif CHAT_TEMPLATE == "Bielik":
+        stop_tokens = ["</s>"]
+        prompt_builder = ["<s>[INST] "]
+        if system_prompt:
+            prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
+        for human, assistant in history:
+            prompt_builder.append(f"{human} [/INST] {assistant}</s>[INST] ")
+        prompt_builder.append(f"{message} [/INST]")
+        instruction = "".join(prompt_builder)
+    else:
+        raise Exception(
+            "Incorrect chat template, select 'ChatML' or 'Mistral Instruct'"
         )
+    print(instruction)
+
+    for output_text in generate(
+        instruction,
+        stop_tokens,
+        temperature,
+        max_new_tokens,
+        top_k,
+        repetition_penalty,
+        top_p,
+    ):
+        yield output_text
+
+
+# Create Gradio interface
+def update_examples():
+    exs = [["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]]
+    random.shuffle(exs)
+    return gr.Dataset(samples=exs)


+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
+    chat = gr.ChatInterface(
+        predict,
+        chatbot=chatbot,
+        title=EMOJI + " " + MODEL_NAME + " - online chat demo",
+        description=DESCRIPTION,
+        examples=[["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]],
+        additional_inputs_accordion=gr.Accordion(
+            label="⚙️ Parameters", open=False, render=False
+        ),
+        additional_inputs=[
+            gr.Textbox("", label="System prompt", render=False),
+            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
+            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
+            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
+            gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
+            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
+        ],
+        theme=gr.themes.Soft(primary_hue=COLOR),
+    )
+    demo.load(update_examples, None, chat.examples_handler.dataset)

+demo.queue(max_size=20).launch()
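
Note on the removed code: the old test() passed model_file, model_type, gpu_layers and hf=True to transformers' AutoModelForCausalLM.from_pretrained, which accepts none of those keywords, so loading the GGUF checkpoint that way could not have worked. Those arguments match the ctransformers loader; a minimal sketch of how they would be used there, assuming ctransformers is installed:

from ctransformers import AutoModelForCausalLM

# Hedged sketch, not part of this commit: the GGUF loader these kwargs belong to.
model = AutoModelForCausalLM.from_pretrained(
    "speakleash/Bielik-11B-v2.3-Instruct-GGUF",
    model_file="Bielik-11B-v2.3-Instruct.Q4_K_M.gguf",
    model_type="mistral",
    gpu_layers=50,  # number of layers offloaded to the GPU
    hf=True,        # wrap the result as a transformers-compatible model
)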
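
Note on the new loading code: quantization_config is built but never passed to from_pretrained, so the 4-bit BitsAndBytes settings have no effect and the model loads at the dtype picked by torch_dtype="auto". If 4-bit loading was the intent (an assumption, not stated in the commit), wiring it in would look roughly like:

# Hedged sketch: pass the already-defined quantization_config so
# load_in_4bit actually takes effect; torch_dtype is then redundant
# because bnb_4bit_compute_dtype fixes the compute dtype.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)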
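
Note on generate(): when a prompt exceeds CONTEXT_LENGTH, only input_ids is truncated while attention_mask keeps its full length, so model.generate receives tensors of mismatched shape and will raise on overlong prompts. A minimal fix sketch, truncating both together:

# Hedged sketch: keep input_ids and attention_mask the same length.
if input_ids.shape[1] > CONTEXT_LENGTH:
    input_ids = input_ids[:, -CONTEXT_LENGTH:]
    attention_mask = attention_mask[:, -CONTEXT_LENGTH:]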
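
Note on predict(): the fallback error message lists only 'ChatML' and 'Mistral Instruct' even though a 'Bielik' branch exists, and the prompt is assembled by hand-concatenating template strings. If the MODEL_ID tokenizer ships a chat template (an assumption, not verified here), transformers' apply_chat_template could replace all three branches:

# Hedged alternative sketch, assuming tokenizer.chat_template is defined
# for speakleash/Bielik-7B-Instruct-v0.1; otherwise the manual branches stand.
messages = [{"role": "system", "content": system_prompt}]
for human, assistant in history:
    messages.append({"role": "user", "content": human})
    messages.append({"role": "assistant", "content": assistant})
messages.append({"role": "user", "content": message})
instruction = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)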