import gradio as gr
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
from threading import Thread
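
# Polish instruction-tuned model served by this Space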
MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
MODEL_NAME = MODEL_ID.split("/")[-1]
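
# Pick the device used for the input tensors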
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
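
# Reuse the EOS token for padding in case the tokenizer defines no pad token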
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
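
# Load the quantized model (bitsandbytes 4-bit loading expects a CUDA device)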
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)
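
# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call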
@spaces.GPU
def test(prompt):
    max_tokens = 5000
    temperature = 0  # 0 switches generation to greedy decoding below
    top_k = 0
    top_p = 0
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # "You are a chatbot answering questions in Polish"
    system = "Jesteś chatbotem udzielającym odpowiedzi na pytania w języku polskim"
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    tokenizer_output = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    )
    # `device` is CPU or CUDA as detected above; .to() handles both cases
    model_input_ids = tokenizer_output.input_ids.to(device)
    model_attention_mask = tokenizer_output.attention_mask.to(device)
    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": model_input_ids,
        "attention_mask": model_attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
    }
    # Pass sampling parameters only when sampling; transformers warns when
    # temperature/top_k/top_p are set together with do_sample=False
    if do_sample:
        generate_kwargs["temperature"] = temperature
        generate_kwargs["top_k"] = top_k
        generate_kwargs["top_p"] = top_p
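
    # Run generation on a background thread; the streamer yields decoded text as it is produced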
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        # Stop if we hit any of the special tokens
        if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
            break
        yield partial_response
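
# gr.Interface streams the generator's partial outputs into the answer box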
demo = gr.Interface(
    fn=test,
    inputs=gr.Textbox(label="Your question", placeholder="Type your question here..."),
    outputs=gr.Textbox(label="Answer", lines=5),
    title="Polish Chatbot",
    description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model",
)
demo.launch()