janbanot committed
Commit 7d2afe0 · 1 Parent(s): b1c28de

Revert "fix: refactor"


This reverts commit b1c28de92515add5f0b6debbd169c837aa7b9be6.

Files changed (1)
  app.py +78 -143
app.py CHANGED
@@ -1,4 +1,3 @@
- from typing import Dict, Generator, List, Optional
  import gradio as gr
  import torch
  import spaces
@@ -10,150 +9,86 @@ from transformers import (
  )
  from threading import Thread

- # Configuration
  MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
- SYSTEM_PROMPT = "Jesteś chatbotem udzielającym odpowiedzi na pytania w języku polskim"
- DEFAULT_GENERATION_PARAMS = {
-     "max_new_tokens": 5000,
-     "temperature": 0,
-     "top_k": 0,
-     "top_p": 0,
- }
-
-
- class ModelLoader:
-     """Handles model loading and device setup"""
-
-     def __init__(self, model_id: str):
-         self.device = self._get_device()
-         self.quantization_config = BitsAndBytesConfig(
-             load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
-         )
-         self.tokenizer = self._load_tokenizer(model_id)
-         self.model = self._load_model(model_id)
-
-     def _get_device(self) -> torch.device:
-         """Determine and return the appropriate device"""
-         if torch.cuda.is_available():
-             device = torch.device("cuda")
-             print(f"Using GPU: {torch.cuda.get_device_name(0)}")
-         else:
-             device = torch.device("cpu")
-             print("CUDA is not available. Using CPU.")
-         return device
-
-     def _load_tokenizer(self, model_id: str) -> AutoTokenizer:
-         """Load and configure the tokenizer"""
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         tokenizer.pad_token = tokenizer.eos_token
-         return tokenizer
-
-     def _load_model(self, model_id: str) -> AutoModelForCausalLM:
-         """Load and configure the model"""
-         return AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=torch.bfloat16,
-             quantization_config=self.quantization_config,
-             low_cpu_mem_usage=True,
-             device_map="auto",
-         )
-
-
- class ChatInterface:
-     """Handles chat interactions and response generation"""
-
-     def __init__(self, model_loader: ModelLoader):
-         self.model = model_loader.model
-         self.tokenizer = model_loader.tokenizer
-         self.device = model_loader.device
-
-     @spaces.GPU
-     def generate_response(
-         self, prompt: str, system_prompt: Optional[str] = None
-     ) -> Generator[str, None, None]:
-         """Generate streaming response for the given prompt"""
-         generation_params = DEFAULT_GENERATION_PARAMS.copy()
-         streamer = TextIteratorStreamer(
-             self.tokenizer, skip_prompt=True, skip_special_tokens=True
-         )
-
-         messages = self._build_messages(prompt, system_prompt or SYSTEM_PROMPT)
-         tokenizer_output = self._prepare_inputs(messages)
-
-         generate_kwargs = {
-             **generation_params,
-             **tokenizer_output,
-             "streamer": streamer,
-             "do_sample": bool(generation_params["temperature"]),
-         }
-
-         self._start_generation_thread(generate_kwargs)
-         yield from self._stream_response(streamer)
-
-     def _build_messages(self, prompt: str, system_prompt: str) -> List[Dict[str, str]]:
-         """Build the message structure for the model"""
-         messages = [{"role": "system", "content": system_prompt}]
-         messages.append({"role": "user", "content": prompt})
-         return messages
-
-     def _prepare_inputs(
-         self, messages: List[Dict[str, str]]
-     ) -> Dict[str, torch.Tensor]:
-         """Prepare model inputs from messages"""
-         tokenizer_output = self.tokenizer.apply_chat_template(
-             messages, return_tensors="pt", return_dict=True
-         )
-
-         # Ensure all tensors are on the correct device
-         inputs = {
-             "input_ids": tokenizer_output.input_ids.to(self.device),
-             "attention_mask": tokenizer_output.attention_mask.to(self.device),
-         }
-
-         # Move model to device if not already there
-         if self.model.device != self.device:
-             self.model.to(self.device)
-
-         return inputs
-
-     def _start_generation_thread(self, generate_kwargs: Dict):
-         """Start model generation in a separate thread"""
-         t = Thread(target=self.model.generate, kwargs=generate_kwargs)
-         t.start()
-
-     def _stream_response(
-         self, streamer: TextIteratorStreamer
-     ) -> Generator[str, None, None]:
-         """Stream the response token by token"""
-         partial_response = ""
-         for new_token in streamer:
-             partial_response += new_token
-             if any(
-                 stop_token in partial_response
-                 for stop_token in ["<|im_end|>", "<|endoftext|>"]
-             ):
-                 break
-             yield partial_response
-
-
- def create_gradio_interface(chat_interface: ChatInterface) -> gr.Interface:
-     """Create and configure the Gradio interface"""
-     return gr.Interface(
-         fn=chat_interface.generate_response,
-         inputs=gr.Textbox(
-             label="Your question", placeholder="Type your question here..."
-         ),
-         outputs=gr.Textbox(label="Answer", lines=5),
-         title="Polish Chatbot",
-         description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model",
-     )
-
-
- if __name__ == "__main__":
-     # Initialize components
-     model_loader = ModelLoader(MODEL_ID)
-     chat_interface = ChatInterface(model_loader)
-
-     # Create and launch interface
-     demo = create_gradio_interface(chat_interface)
-     demo.launch()
+ MODEL_NAME = MODEL_ID.split("/")[-1]
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print("Using GPU:", torch.cuda.get_device_name(0))
+ else:
+     device = torch.device("cpu")
+     print("CUDA is not available. Using CPU.")
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
+ )
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ tokenizer.pad_token = tokenizer.eos_token
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.bfloat16,
+     quantization_config=quantization_config,
+     low_cpu_mem_usage=True,
+ )
+
+
+ @spaces.GPU
+ def test(prompt):
+     max_tokens = 5000
+     temperature = 0
+     top_k = 0
+     top_p = 0
+
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     system = "Jesteś chatbotem udzielającym odpowiedzi na pytania w języku polskim"
+     messages = []
+
+     if system:
+         messages.append({"role": "system", "content": system})
+
+     messages.append({"role": "user", "content": prompt})
+
+     tokenizer_output = tokenizer.apply_chat_template(
+         messages, return_tensors="pt", return_dict=True
+     )
+
+     if torch.cuda.is_available():
+         model_input_ids = tokenizer_output.input_ids.to(device)
+         model_attention_mask = tokenizer_output.attention_mask.to(device)
+     else:
+         model_input_ids = tokenizer_output.input_ids
+         model_attention_mask = tokenizer_output.attention_mask
+
+     generate_kwargs = {
+         "input_ids": model_input_ids,
+         "attention_mask": model_attention_mask,
+         "streamer": streamer,
+         "max_new_tokens": max_tokens,
+         "do_sample": True if temperature else False,
+         "temperature": temperature,
+         "top_k": top_k,
+         "top_p": top_p,
+     }
+
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_response = ""
+     for new_token in streamer:
+         partial_response += new_token
+         # Stop if we hit any of the special tokens
+         if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
+             break
+         yield partial_response
+
+
+ demo = gr.Interface(
+     fn=test,
+     inputs=gr.Textbox(label="Your question", placeholder="Type your question here..."),
+     outputs=gr.Textbox(label="Answer", lines=5),
+     title="Polish Chatbot",
+     description="Ask questions in Polish to the Bielik-11B-v2.3-Instruct model"
+ )
+ demo.launch()
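
Both the reverted refactor and the restored script stream tokens the same way: model.generate runs in a background Thread while a TextIteratorStreamer yields decoded text back to the caller. A minimal, self-contained sketch of that idiom (sshleifer/tiny-gpt2 is only a stand-in checkpoint so the sketch runs quickly on CPU; the Space itself loads Bielik-11B-v2.3-Instruct in 4-bit):

    # Sketch of the Thread + TextIteratorStreamer streaming pattern used in app.py.
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "sshleifer/tiny-gpt2"  # stand-in model for illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    def stream_reply(prompt):
        inputs = tokenizer(prompt, return_tensors="pt")
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # generate() blocks until finished, so it runs in a worker thread
        # while this generator consumes tokens as they are produced
        Thread(
            target=model.generate,
            kwargs={**inputs, "streamer": streamer, "max_new_tokens": 40},
        ).start()
        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial  # each yield re-renders the Gradio output box

    for text in stream_reply("Hello"):
        print(text)

Because gr.Interface accepts a generator function, wiring stream_reply in as fn streams partial answers into the output textbox the same way test does above.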
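
A side note on the restored generation settings: with temperature = 0, the expression True if temperature else False evaluates to False, so generate performs greedy decoding and the top_k/top_p values are never consulted. Under that reading, the kwargs reduce to the equivalent of:

    # Effective configuration when temperature == 0: greedy decoding.
    # temperature, top_k and top_p only take effect once do_sample=True.
    generate_kwargs = {
        "input_ids": model_input_ids,
        "attention_mask": model_attention_mask,
        "streamer": streamer,
        "max_new_tokens": 5000,
        "do_sample": False,
    }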