janbanot committed on
Commit b1c28de
1 Parent(s): 4631bc7

fix: refactor

chore: refactor

fix: wrong parameter name

Files changed (1)
  1. app.py +143 -78
app.py CHANGED
@@ -1,3 +1,4 @@
+from typing import Dict, Generator, List, Optional
 import gradio as gr
 import torch
 import spaces
@@ -9,86 +10,150 @@ from transformers import (
 )
 from threading import Thread
 
+# Configuration
 MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
-MODEL_NAME = MODEL_ID.split("/")[-1]
-
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    print("Using GPU:", torch.cuda.get_device_name(0))
-else:
-    device = torch.device("cpu")
-    print("CUDA is not available. Using CPU.")
-
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.bfloat16,
-    quantization_config=quantization_config,
-    low_cpu_mem_usage=True,
-)
-
-
-@spaces.GPU
-def test(prompt):
-    max_tokens = 5000
-    temperature = 0
-    top_k = 0
-    top_p = 0
-
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    system = "Jesteś chatbotem udzielającym odpowiedzi na pytania w języku polskim"
-    messages = []
-
-    if system:
-        messages.append({"role": "system", "content": system})
-
-    messages.append({"role": "user", "content": prompt})
-
-    tokenizer_output = tokenizer.apply_chat_template(
-        messages, return_tensors="pt", return_dict=True
-    )
-
-    if torch.cuda.is_available():
-        model_input_ids = tokenizer_output.input_ids.to(device)
-        model_attention_mask = tokenizer_output.attention_mask.to(device)
-    else:
-        model_input_ids = tokenizer_output.input_ids
-        model_attention_mask = tokenizer_output.attention_mask
-
-    generate_kwargs = {
-        "input_ids": model_input_ids,
-        "attention_mask": model_attention_mask,
-        "streamer": streamer,
-        "max_new_tokens": max_tokens,
-        "do_sample": True if temperature else False,
-        "temperature": temperature,
-        "top_k": top_k,
-        "top_p": top_p,
-    }
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_response = ""
-    for new_token in streamer:
-        partial_response += new_token
-        # Stop if we hit any of the special tokens
-        if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
-            break
-        yield partial_response
-
-
-demo = gr.Interface(
-    fn=test,
-    inputs=gr.Textbox(label="Your question", placeholder="Type your question here..."),
-    outputs=gr.Textbox(label="Answer", lines=5),
-    title="Polish Chatbot",
-    description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model"
-)
-demo.launch()
+SYSTEM_PROMPT = "Jesteś chatbotem udzielającym odpowiedzi na pytania w języku polskim"
+DEFAULT_GENERATION_PARAMS = {
+    "max_new_tokens": 5000,
+    "temperature": 0,
+    "top_k": 0,
+    "top_p": 0,
+}
+
+
+class ModelLoader:
+    """Handles model loading and device setup"""
+
+    def __init__(self, model_id: str):
+        self.device = self._get_device()
+        self.quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
+        )
+        self.tokenizer = self._load_tokenizer(model_id)
+        self.model = self._load_model(model_id)
+
+    def _get_device(self) -> torch.device:
+        """Determine and return the appropriate device"""
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+        else:
+            device = torch.device("cpu")
+            print("CUDA is not available. Using CPU.")
+        return device
+
+    def _load_tokenizer(self, model_id: str) -> AutoTokenizer:
+        """Load and configure the tokenizer"""
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
+
+    def _load_model(self, model_id: str) -> AutoModelForCausalLM:
+        """Load and configure the model"""
+        return AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            quantization_config=self.quantization_config,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+        )
+
+
+class ChatInterface:
+    """Handles chat interactions and response generation"""
+
+    def __init__(self, model_loader: ModelLoader):
+        self.model = model_loader.model
+        self.tokenizer = model_loader.tokenizer
+        self.device = model_loader.device
+
+    @spaces.GPU
+    def generate_response(
+        self, prompt: str, system_prompt: Optional[str] = None
+    ) -> Generator[str, None, None]:
+        """Generate streaming response for the given prompt"""
+        generation_params = DEFAULT_GENERATION_PARAMS.copy()
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+
+        messages = self._build_messages(prompt, system_prompt or SYSTEM_PROMPT)
+        tokenizer_output = self._prepare_inputs(messages)
+
+        generate_kwargs = {
+            **generation_params,
+            **tokenizer_output,
+            "streamer": streamer,
+            "do_sample": bool(generation_params["temperature"]),
+        }
+
+        self._start_generation_thread(generate_kwargs)
+        yield from self._stream_response(streamer)
+
+    def _build_messages(self, prompt: str, system_prompt: str) -> List[Dict[str, str]]:
+        """Build the message structure for the model"""
+        messages = [{"role": "system", "content": system_prompt}]
+        messages.append({"role": "user", "content": prompt})
+        return messages
+
+    def _prepare_inputs(
+        self, messages: List[Dict[str, str]]
+    ) -> Dict[str, torch.Tensor]:
+        """Prepare model inputs from messages"""
+        tokenizer_output = self.tokenizer.apply_chat_template(
+            messages, return_tensors="pt", return_dict=True
+        )
+
+        # Ensure all tensors are on the correct device
+        inputs = {
+            "input_ids": tokenizer_output.input_ids.to(self.device),
+            "attention_mask": tokenizer_output.attention_mask.to(self.device),
+        }
+
+        # Move model to device if not already there
+        if self.model.device != self.device:
+            self.model.to(self.device)
+
+        return inputs
+
+    def _start_generation_thread(self, generate_kwargs: Dict):
+        """Start model generation in a separate thread"""
+        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
+        t.start()
+
+    def _stream_response(
+        self, streamer: TextIteratorStreamer
+    ) -> Generator[str, None, None]:
+        """Stream the response token by token"""
+        partial_response = ""
+        for new_token in streamer:
+            partial_response += new_token
+            if any(
+                stop_token in partial_response
+                for stop_token in ["<|im_end|>", "<|endoftext|>"]
+            ):
+                break
+            yield partial_response
+
+
+def create_gradio_interface(chat_interface: ChatInterface) -> gr.Interface:
+    """Create and configure the Gradio interface"""
+    return gr.Interface(
+        fn=chat_interface.generate_response,
+        inputs=gr.Textbox(
+            label="Your question", placeholder="Type your question here..."
+        ),
+        outputs=gr.Textbox(label="Answer", lines=5),
+        title="Polish Chatbot",
+        description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model",
+    )
+
+
+if __name__ == "__main__":
+    # Initialize components
+    model_loader = ModelLoader(MODEL_ID)
+    chat_interface = ChatInterface(model_loader)
+
+    # Create and launch interface
+    demo = create_gradio_interface(chat_interface)
+    demo.launch()
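
The refactor preserves the original streaming contract: each value yielded by generate_response is the accumulated response so far, not a token delta, which is what gr.Interface expects from a generator function. As a rough local sanity check, a sketch along the following lines could exercise the new classes outside Gradio. It is hypothetical, not part of the commit, and assumes the Bielik checkpoint and bitsandbytes 4-bit loading are available, and that the spaces.GPU decorator behaves as a no-op outside an actual Hugging Face Space.

# Hypothetical smoke test for the refactored app.py (not part of the commit).
# Assumes the checkpoint can be downloaded, 4-bit loading succeeds, and
# spaces.GPU degrades to a no-op outside a Hugging Face Space.
from app import MODEL_ID, ChatInterface, ModelLoader

loader = ModelLoader(MODEL_ID)  # prints the detected device, loads tokenizer and model
chat = ChatInterface(loader)

final = ""
for partial in chat.generate_response("Czym jest uczenie maszynowe?"):
    final = partial  # each yield is the full response so far
print(final)

One caveat: with device_map="auto" the quantized weights are already placed by accelerate, and model.device (typically cuda:0) compares unequal to torch.device("cuda"), so the self.model.to(self.device) branch in _prepare_inputs can still trigger; calling .to() on a 4-bit bitsandbytes model may warn or raise depending on the transformers version.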