janbanot commited on
Commit
a23d50e
·
1 Parent(s): 7d2afe0

fix: refactor + interface change

Browse files
Files changed (1) hide show
  1. app.py +25 -16
app.py CHANGED
@@ -20,8 +20,8 @@ else:
20
  print("CUDA is not available. Using CPU.")
21
 
22
  quantization_config = BitsAndBytesConfig(
23
- load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
24
- )
25
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
26
  tokenizer.pad_token = tokenizer.eos_token
27
  model = AutoModelForCausalLM.from_pretrained(
@@ -33,13 +33,17 @@ model = AutoModelForCausalLM.from_pretrained(
33
 
34
 
35
  @spaces.GPU
36
- def test(prompt):
37
- max_tokens = 5000
38
- temperature = 0
39
- top_k = 0
40
- top_p = 0
41
-
42
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
43
  system = "Jesteś chatboem udzielającym odpowiedzi na pytania w języku polskim"
44
  messages = []
45
 
@@ -54,9 +58,7 @@ def test(prompt):
54
 
55
  if torch.cuda.is_available():
56
  model_input_ids = tokenizer_output.input_ids.to(device)
57
-
58
  model_attention_mask = tokenizer_output.attention_mask.to(device)
59
-
60
  else:
61
  model_input_ids = tokenizer_output.input_ids
62
  model_attention_mask = tokenizer_output.attention_mask
@@ -65,10 +67,11 @@ def test(prompt):
65
  "input_ids": model_input_ids,
66
  "attention_mask": model_attention_mask,
67
  "streamer": streamer,
68
- "max_new_tokens": max_tokens,
69
  "do_sample": True if temperature else False,
70
  "temperature": temperature,
 
71
  "top_k": top_k,
 
72
  "top_p": top_p,
73
  }
74
 
@@ -78,17 +81,23 @@ def test(prompt):
78
  partial_response = ""
79
  for new_token in streamer:
80
  partial_response += new_token
81
- # Stop if we hit any of the special tokens
82
  if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
83
  break
84
  yield partial_response
85
 
86
 
87
  demo = gr.Interface(
88
- fn=test,
89
- inputs=gr.Textbox(label="Your question", placeholder="Type your question here..."),
 
 
 
 
 
 
 
90
  outputs=gr.Textbox(label="Answer", lines=5),
91
  title="Polish Chatbot",
92
- description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model"
93
  )
94
  demo.launch()
 
20
  print("CUDA is not available. Using CPU.")
21
 
22
  quantization_config = BitsAndBytesConfig(
23
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
24
+ )
25
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
26
  tokenizer.pad_token = tokenizer.eos_token
27
  model = AutoModelForCausalLM.from_pretrained(
 
33
 
34
 
35
  @spaces.GPU
36
+ def generate(
37
+ prompt,
38
+ temperature,
39
+ max_tokens,
40
+ top_k,
41
+ repetition_penalty,
42
+ top_p,
43
+ ):
44
+ streamer = TextIteratorStreamer(
45
+ tokenizer, skip_prompt=True, skip_special_tokens=True
46
+ )
47
  system = "Jesteś chatboem udzielającym odpowiedzi na pytania w języku polskim"
48
  messages = []
49
 
 
58
 
59
  if torch.cuda.is_available():
60
  model_input_ids = tokenizer_output.input_ids.to(device)
 
61
  model_attention_mask = tokenizer_output.attention_mask.to(device)
 
62
  else:
63
  model_input_ids = tokenizer_output.input_ids
64
  model_attention_mask = tokenizer_output.attention_mask
 
67
  "input_ids": model_input_ids,
68
  "attention_mask": model_attention_mask,
69
  "streamer": streamer,
 
70
  "do_sample": True if temperature else False,
71
  "temperature": temperature,
72
+ "max_new_tokens": max_tokens,
73
  "top_k": top_k,
74
+ "repetition_penalty": repetition_penalty,
75
  "top_p": top_p,
76
  }
77
 
 
81
  partial_response = ""
82
  for new_token in streamer:
83
  partial_response += new_token
 
84
  if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
85
  break
86
  yield partial_response
87
 
88
 
89
  demo = gr.Interface(
90
+ fn=generate,
91
+ inputs=[
92
+ gr.Textbox(label="Your question", placeholder="Type your question here..."),
93
+ gr.Slider(0, 1, 0.6, label="Temperature"),
94
+ gr.Slider(128, 4096, 1024, label="Max new tokens"),
95
+ gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
96
+ gr.Slider(0, 2, 1.1, label="Repetition penalty"),
97
+ gr.Slider(0, 1, 0.95, label="Top P sampling"),
98
+ ],
99
  outputs=gr.Textbox(label="Answer", lines=5),
100
  title="Polish Chatbot",
101
+ description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model",
102
  )
103
  demo.launch()