"""Gradio chat demo for speakleash/Bielik-7B-Instruct-v0.1.

Formats the chat history with a prompt template (CHAT_TEMPLATE), then streams
the model's reply token-by-token into a ``gr.ChatInterface``.
"""

import os
import random
import subprocess
import sys
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Best-effort install of flash-attn without compiling its CUDA kernels.
# Run pip through the current interpreter and *extend* the environment:
# passing a bare dict as ``env`` would drop PATH, HOME, etc. for the child.
# check=False keeps install failures non-fatal.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

MODEL_ID = "speakleash/Bielik-7B-Instruct-v0.1"
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
# Maximum number of prompt tokens fed to the model; older tokens are dropped.
CONTEXT_LENGTH = 1024
# Optional UI customisation via environment variables.
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI", "")
DESCRIPTION = os.environ.get("DESCRIPTION")
# Example prompts (Polish: "Who are you?", "How much is 9+2-1?",
# "Write me something nice.").
EXAMPLES = [["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]]

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this 4-bit config is built but not passed to from_pretrained
# below, so it has no effect. Add quantization_config=quantization_config to
# the from_pretrained call if 4-bit loading is intended.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
)


def _truncate_at_stop(text, stop_tokens):
    """Cut ``text`` just before the earliest occurrence of any stop string.

    Empty stop strings are ignored (they would match everywhere).

    Returns:
        ``(prefix, found)`` -- ``prefix`` is ``text`` up to the first stop
        string, and ``found`` says whether any stop string occurred. When none
        occurs, ``text`` is returned unchanged with ``found=False``.
    """
    hits = [text.find(stop) for stop in stop_tokens if stop and stop in text]
    if not hits:
        return text, False
    return text[: min(hits)], True


@spaces.GPU()
def generate(
    instruction,
    stop_tokens,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    """Stream a completion for an already-formatted prompt.

    Yields the cumulative generated text after each streamed chunk, so the
    caller replaces (rather than appends to) what it displays.

    Args:
        instruction: Prompt string with the chat template already applied.
        stop_tokens: Strings that end the reply; output is cut just before the
            first one found. Empty strings are ignored.
        temperature: Sampling temperature; a falsy value (0) selects greedy
            decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_k: Top-k sampling cutoff (used only when sampling).
        repetition_penalty: Passed through to ``model.generate``.
        top_p: Nucleus sampling threshold (used only when sampling).
    """
    # Special tokens are stripped from the streamed text, so stop strings that
    # are special tokens for this tokenizer will not show up below; generation
    # then ends on EOS or max_new_tokens instead.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    enc = tokenizer([instruction], return_tensors="pt")
    # Keep only the most recent CONTEXT_LENGTH tokens. The attention mask must
    # be sliced identically or generate() receives tensors of different
    # lengths. Tokenizer-side truncation is not used: by default it cuts from
    # the right, which would drop the newest user message.
    input_ids = enc.input_ids[:, -CONTEXT_LENGTH:]
    attention_mask = enc.attention_mask[:, -CONTEXT_LENGTH:]

    do_sample = bool(temperature)
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )
    if do_sample:
        # These only affect sampling; greedy decoding ignores them.
        generate_kwargs.update(temperature=temperature, top_k=top_k, top_p=top_p)

    # model.generate blocks until done, so it runs in a worker thread while
    # this generator consumes the streamer.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    text = ""
    for chunk in streamer:
        text += chunk
        # Match on the accumulated text, not the single chunk, so a stop
        # string split across chunks is still caught.
        visible, stopped = _truncate_at_stop(text, stop_tokens)
        yield visible
        if stopped:
            # NOTE(review): this stops streaming only; model.generate keeps
            # running in its thread until EOS or max_new_tokens.
            break


def _build_prompt(message, history, system_prompt):
    """Render the conversation using the template selected by CHAT_TEMPLATE.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs.
        system_prompt: System instruction text (may be empty).

    Returns:
        ``(instruction, stop_tokens)`` -- the prompt, ending where the
        assistant's reply should start, and the strings that end that reply.

    Raises:
        ValueError: If CHAT_TEMPLATE is not a supported template name.
    """
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        parts = ["<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"]
        for human, assistant in history:
            # Close every turn -- including past assistant replies -- with
            # <|im_end|> so consecutive turns do not run together.
            parts.append(
                f"<|im_start|>user\n{human or ''}\n<|im_end|>\n"
                f"<|im_start|>assistant\n{assistant or ''}\n<|im_end|>\n"
            )
        parts.append(
            "<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
        )
        return "".join(parts), stop_tokens

    if CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = "[INST] " + system_prompt
        for human, assistant in history:
            # </s> closes each completed assistant turn before the next [INST].
            instruction += f"{human} [/INST] {assistant or ''}</s>[INST]"
        instruction += " " + message + " [/INST]"
        return instruction, stop_tokens

    if CHAT_TEMPLATE == "Bielik":
        stop_tokens = ["</s>"]
        prompt_builder = ["[INST] "]
        if system_prompt:
            # Llama-2-style system block inside the first [INST].
            prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
        for human, assistant in history:
            prompt_builder.append(f"{human} [/INST] {assistant or ''}</s>[INST] ")
        prompt_builder.append(f"{message} [/INST]")
        return "".join(prompt_builder), stop_tokens

    raise ValueError(
        f"Incorrect chat template {CHAT_TEMPLATE!r}, "
        "select 'ChatML', 'Mistral Instruct' or 'Bielik'"
    )


def predict(
    message,
    history,
    system_prompt,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    """``gr.ChatInterface`` callback: stream the assistant's reply.

    ``message`` and ``history`` come from the chat widget. The remaining
    parameters come, in order, from ``additional_inputs`` (system prompt
    textbox, then the sampling sliders). Yields the cumulative reply text.
    """
    repetition_penalty = float(repetition_penalty)
    print(
        "LLL",
        [
            message,
            history,
            system_prompt,
            temperature,
            max_new_tokens,
            top_k,
            repetition_penalty,
            top_p,
        ],
    )
    # Slider values may arrive as floats; token counts and top_k must be ints.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    instruction, stop_tokens = _build_prompt(message, history, system_prompt)
    print(instruction)
    yield from generate(
        instruction,
        stop_tokens,
        temperature,
        max_new_tokens,
        top_k,
        repetition_penalty,
        top_p,
    )


# Create Gradio interface
def update_examples():
    """Return the example prompts, shuffled, as a ``gr.Dataset`` update."""
    shuffled = [list(example) for example in EXAMPLES]
    random.shuffle(shuffled)
    return gr.Dataset(samples=shuffled)


# Use the theme's default hue when COLOR is unset (a None hue is invalid).
theme = gr.themes.Soft(primary_hue=COLOR) if COLOR else gr.themes.Soft()

# The outer Blocks is the app that gets launched, so the theme is set there;
# it is also passed to ChatInterface.
with gr.Blocks(theme=theme) as demo:
    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
    chat = gr.ChatInterface(
        predict,
        chatbot=chatbot,
        title=f"{EMOJI} {MODEL_NAME} - online chat demo".strip(),
        description=DESCRIPTION,
        examples=EXAMPLES,
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Parameters", open=False, render=False
        ),
        additional_inputs=[
            gr.Textbox("", label="System prompt", render=False),
            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
            gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
        ],
        theme=theme,
    )
    # Reshuffle the example prompts on every page load.
    demo.load(update_examples, None, chat.examples_handler.dataset)

demo.queue(max_size=20).launch()