janbanot committed
Commit daeec09 · 1 Parent(s): 6856dc0

fix: another test

Files changed (1)
  1. app.py +64 -162
app.py CHANGED
@@ -1,184 +1,86 @@
- import subprocess
- from threading import Thread
-
- import random
  import torch
  import spaces
- import gradio as gr
  from transformers import (
      AutoModelForCausalLM,
      AutoTokenizer,
      BitsAndBytesConfig,
-     TextIteratorStreamer,
  )

- subprocess.run(
-     "pip install flash-attn --no-build-isolation",
-     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-     shell=True,
- )

- MODEL_ID = "speakleash/Bielik-7B-Instruct-v0.1"
- CHAT_TEMPLATE = "ChatML"
- MODEL_NAME = MODEL_ID.split("/")[-1]
- CONTEXT_LENGTH = 1024

- # Load model
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
- )
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
- tokenizer.pad_token = tokenizer.eos_token
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     device_map="auto",
-     torch_dtype="auto",
-     attn_implementation="flash_attention_2",
- )


- @spaces.GPU()
- def generate(
-     instruction,
-     stop_tokens,
-     temperature,
-     max_new_tokens,
-     top_k,
-     repetition_penalty,
-     top_p,
- ):
-     streamer = TextIteratorStreamer(
-         tokenizer, skip_prompt=True, skip_special_tokens=True
      )
-     enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
-     input_ids, attention_mask = enc.input_ids, enc.attention_mask

-     if input_ids.shape[1] > CONTEXT_LENGTH:
-         input_ids = input_ids[:, -CONTEXT_LENGTH:]

-     generate_kwargs = dict(
-         {
-             "input_ids": input_ids.to(device),
-             "attention_mask": attention_mask.to(device),
-         },
          streamer=streamer,
          do_sample=True if temperature else False,
          temperature=temperature,
-         max_new_tokens=max_new_tokens,
          top_k=top_k,
-         repetition_penalty=repetition_penalty,
          top_p=top_p,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-     outputs = []
-     for new_token in streamer:
-         outputs.append(new_token)
-         if new_token in stop_tokens:
-             break
-         yield "".join(outputs)
-
-
- def predict(
-     message,
-     history,
-     system_prompt,
-     temperature,
-     max_new_tokens,
-     top_k,
-     repetition_penalty,
-     top_p,
- ):
-     repetition_penalty = float(repetition_penalty)
-     print(
-         "LLL",
-         [
-             message,
-             history,
-             system_prompt,
-             temperature,
-             max_new_tokens,
-             top_k,
-             repetition_penalty,
-             top_p,
-         ],
-     )
-     # Format history with a given chat template
-     if CHAT_TEMPLATE == "ChatML":
-         stop_tokens = ["<|endoftext|>", "<|im_end|>"]
-         instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
-         for human, assistant in history:
-             instruction += (
-                 "<|im_start|>user\n"
-                 + human
-                 + "\n<|im_end|>\n<|im_start|>assistant\n"
-                 + assistant
-             )
-         instruction += (
-             "\n<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
-         )
-     elif CHAT_TEMPLATE == "Mistral Instruct":
-         stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
-         instruction = "<s>[INST] " + system_prompt
-         for human, assistant in history:
-             instruction += human + " [/INST] " + assistant + "</s>[INST]"
-         instruction += " " + message + " [/INST]"
-     elif CHAT_TEMPLATE == "Bielik":
-         stop_tokens = ["</s>"]
-         prompt_builder = ["<s>[INST] "]
-         if system_prompt:
-             prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
-         for human, assistant in history:
-             prompt_builder.append(f"{human} [/INST] {assistant}</s>[INST] ")
-         prompt_builder.append(f"{message} [/INST]")
-         instruction = "".join(prompt_builder)
-     else:
-         raise Exception(
-             "Incorrect chat template, select 'ChatML' or 'Mistral Instruct'"
-         )
-     print(instruction)
-
-     for output_text in generate(
-         instruction,
-         stop_tokens,
-         temperature,
-         max_new_tokens,
-         top_k,
-         repetition_penalty,
-         top_p,
-     ):
-         yield output_text
-
-
- # Create Gradio interface
- def update_examples():
-     exs = [["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]]
-     random.shuffle(exs)
-     return gr.Dataset(samples=exs)
-
-
- with gr.Blocks() as demo:
-     chatbot = gr.Chatbot(label="Chatbot", render=False)
-     chat = gr.ChatInterface(
-         predict,
-         chatbot=chatbot,
-         title="Online chat demo",
-         description="Opis testowy",
-         examples=[["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]],
-         additional_inputs_accordion=gr.Accordion(
-             label="⚙️ Parameters", open=False, render=False
-         ),
-         additional_inputs=[
-             gr.Textbox("", label="System prompt", render=False),
-             gr.Slider(0, 1, 0.6, label="Temperature", render=False),
-             gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
-             gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
-             gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
-             gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
-         ],
-         theme=gr.themes.Soft(),
-     )
-     demo.load(update_examples, None, chat.examples_handler.dataset)

- demo.queue(max_size=20).launch()
+ import gradio as gr
  import torch
  import spaces
  from transformers import (
      AutoModelForCausalLM,
      AutoTokenizer,
      BitsAndBytesConfig,
+     TextStreamer,
  )


+ @spaces.GPU
+ def test():
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         print("Using GPU:", torch.cuda.get_device_name(0))
+     else:
+         device = torch.device("cpu")
+         print("CUDA is not available. Using CPU.")

+     device = "cuda"
+     model_name = "speakleash/Bielik-11B-v2.3-Instruct"
+
+     max_tokens = 5000
+     temperature = 0
+     top_k = 0
+     top_p = 0
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         quantization_config=quantization_config,
+         low_cpu_mem_usage=True,
+     )
+
+     model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+     prompt = "Kim jesteś?"
+     system = "Jesteś chatboem udzielającym odpowiedzi na pytania w języku polskim"
+     messages = []
+
+     if system:
+         messages.append({"role": "system", "content": system})

+     messages.append({"role": "user", "content": prompt})

+     tokenizer_output = tokenizer.apply_chat_template(
+         messages, return_tensors="pt", return_dict=True
      )

+     if torch.cuda.is_available():
+         model_input_ids = tokenizer_output.input_ids.to(device)
+
+         model_attention_mask = tokenizer_output.attention_mask.to(device)
+
+     else:
+         model_input_ids = tokenizer_output.input_ids
+         model_attention_mask = tokenizer_output.attention_mask

+     outputs = model.generate(
+         model_input_ids,
+         attention_mask=model_attention_mask,
          streamer=streamer,
+         max_new_tokens=max_tokens,
          do_sample=True if temperature else False,
          temperature=temperature,
          top_k=top_k,
          top_p=top_p,
      )

+     answer = tokenizer.batch_decode(outputs, skip_special_tokens=False)
+
+     return answer
+
+
+ demo = gr.Interface(fn=test, inputs=None, outputs=gr.Text())
+ demo.launch()