import gradio as gr
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
from threading import Thread

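# Polish instruction-tuned Bielik model, loaded from the Hugging Face Hub.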
MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
MODEL_NAME = MODEL_ID.split("/")[-1]

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

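# Load the 11B model in 4-bit precision (bitsandbytes) so it fits on a single GPU; compute dtype is bfloat16.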
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
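# The tokenizer defines no dedicated pad token, so the end-of-sequence token is reused for padding.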
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)


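# @spaces.GPU requests a GPU from Hugging Face Spaces (ZeroGPU) for the duration of each call.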
@spaces.GPU
def generate(
    user_input,
    temperature,
    max_tokens,
    top_k,
    repetition_penalty,
    top_p,
    prompt_style="",
):
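    # Stream generated tokens back to the UI as they are produced; skip the prompt echo and special tokens.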
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    system = f"""Jesteś pomocnym botem udzielającym odpowiedzi na pytania w języku polskim.
    Odpowiadaj krótko i zwięźle, unikaj zbyt skomplikowanych odpowiedzi.
    {prompt_style}
    """
    messages = []

    if system:
        messages.append({"role": "system", "content": system})

    messages.append({"role": "user", "content": user_input})

    tokenizer_output = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    )

    if torch.cuda.is_available():
        model_input_ids = tokenizer_output.input_ids.to(device)
        model_attention_mask = tokenizer_output.attention_mask.to(device)
    else:
        model_input_ids = tokenizer_output.input_ids
        model_attention_mask = tokenizer_output.attention_mask

    generate_kwargs = {
        "input_ids": model_input_ids,
        "attention_mask": model_attention_mask,
        "streamer": streamer,
        "do_sample": True if temperature else False,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
    }

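    # Run generation in a background thread so tokens can be yielded from the streamer as they arrive.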
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
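        # Stop early if an end-of-turn marker slips through despite skip_special_tokens=True.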
        if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
            break
        # Strip leading whitespace and newlines
        cleaned_response = partial_response.lstrip()
        yield cleaned_response


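# System-prompt additions used by the style-rewriting ("Zmiana stylu") tool.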
STYLE_PROMPTS = {
    "Formalny": """Przekształć poniższy tekst na bardziej formalny, zachowując jego oryginalne znaczenie i klarowność.""",  # noqa
    "Nieformalny": """Przekształć poniższy tekst na luźniejszy i bardziej nieformalny, tak żeby brzmiał swobodnie i naturalnie.""",  # noqa
    "Neutralny": """Przekształć poniższy tekst na bardziej neutralny, eliminując zbyt formalne lub potoczne sformułowania.""",  # noqa
}


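# Gradio UI: three tool panels (Simple Question, style rewriting, Judge placeholder) toggled by the buttons at the top.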
with gr.Blocks(
    css="""
    .gradio-container { max-width: 1600px; margin: 20px; padding: 10px; }
    #style-dropdown { flex: 3; }
    #generate-btn, #clear-btn { flex: 1; max-width: 100px; }
    .same-height { height: 60px; }
"""
) as demo:
    gr.Markdown("# Bielik Tools - narzędzia dla modelu Bielik v2.3")

    with gr.Column(elem_id="main-content"):
        with gr.Row():
            simple_question_btn = gr.Button("Zadaj Pytanie", variant="primary")
            formalizer_btn = gr.Button("Zmiana stylu", variant="secondary")
            judge_btn = gr.Button("Sędzia", interactive=False)

        # Function to switch tool visibility and update button styles based on the active tool
        def switch_tool(tool):
            print(f"Switched to {tool}")
            return [
                gr.Button(variant="primary" if tool == "Formalizer" else "secondary"),
                gr.Button(variant="primary" if tool == "Judge" else "secondary"),
                gr.Button(
                    variant="primary" if tool == "Simple Question" else "secondary"
                ),
                gr.update(visible=(tool == "Formalizer")),
                gr.update(visible=(tool == "Judge")),
                gr.update(visible=(tool == "Simple Question")),
            ]

        # Simple Question content column
        with gr.Column(visible=True) as simple_question_column:
            input_text_sq = gr.Textbox(
                label="Twoje pytanie",
                placeholder="Zadaj swoje pytanie tutaj...",
                lines=5,
            )
            with gr.Row():
                generate_btn_sq = gr.Button("Generuj odpowiedź", interactive=False)
                clear_btn_sq = gr.Button("Wyczyść", interactive=False)
            output_text_sq = gr.Textbox(label="Odpowiedź", interactive=False, lines=5)

            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_sq = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_sq = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_sq = gr.Slider(1, 80, 40, step=1, label="Top K")
                repetition_penalty_sq = gr.Slider(
                    0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_sq = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

            # Update button states based on input and output text changes for interactivity
            def update_button_states_sq(input_text, output_text):
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )

            output_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )

            # Event handlers for button actions to process and clear text
            generate_btn_sq.click(
                fn=generate,
                inputs=[
                    input_text_sq,
                    temperature_sq,
                    max_tokens_sq,
                    top_k_sq,
                    repetition_penalty_sq,
                    top_p_sq,
                ],
                outputs=output_text_sq,
            )

            clear_btn_sq.click(
                fn=lambda: ("", ""),
                inputs=None,
                outputs=[input_text_sq, output_text_sq],
            )

        with gr.Column(visible=False) as formalizer_column:
            input_text = gr.Textbox(
                placeholder="Wpisz tekst tutaj...", label="Twój tekst", lines=5
            )
            with gr.Row():
                gr.Text(
                    "Wybierz styl:",
                    elem_id="style-label",
                    show_label=False,
                    elem_classes="same-height",
                )
                style_dropdown = gr.Dropdown(
                    choices=["Formalny", "Nieformalny", "Neutralny"],
                    value="Neutralny",  # Set a default value
                    elem_id="style-dropdown",
                    show_label=False,
                    elem_classes="same-height",
                )
                generate_btn = gr.Button(
                    "Generuj",
                    interactive=False,
                    elem_id="generate-btn",
                    elem_classes="same-height",
                )
                clear_btn = gr.Button(
                    "Wyczyść",
                    interactive=False,
                    elem_id="clear-btn",
                    elem_classes="same-height",
                )
            output_text = gr.Textbox(label="Wynik", interactive=False, lines=5)

            # Update button states based on input and output text changes for interactivity
            def update_button_states(input_text, output_text):
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )

            output_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )

            # Event handlers for button actions to process and clear text
            def format_with_style(text, style):
                return generate(
                    text,
                    temperature=0.3,
                    max_tokens=1024,
                    top_k=40,
                    repetition_penalty=1.1,
                    top_p=0.95,
                    prompt_style=STYLE_PROMPTS[style]
                )

            generate_btn.click(
                fn=format_with_style,
                inputs=[input_text, style_dropdown],
                outputs=output_text,
            )

            clear_btn.click(
                fn=lambda: ("", ""), inputs=None, outputs=[input_text, output_text]
            )

        # Placeholder for Judge content column, initially hidden
        with gr.Column(visible=False) as judge_column:
            gr.Markdown("Judge tool content goes here.")

            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_jg = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_jg = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_jg = gr.Slider(1, 80, 40, step=1, label="Top K")
                repetition_penalty_jg = gr.Slider(
                    0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_jg = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

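        # Wire the top buttons to switch_tool so the matching panel is shown and the active button is highlighted.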
        formalizer_btn.click(
            lambda: switch_tool("Formalizer"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )
        judge_btn.click(
            lambda: switch_tool("Judge"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )
        simple_question_btn.click(
            lambda: switch_tool("Simple Question"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )


demo.queue().launch()