LPX55 commited on
Commit
a84487d
·
verified ·
1 Parent(s): 3301022

Update mini.py

Browse files
Files changed (1) hide show
  1. mini.py +7 -7
mini.py CHANGED
@@ -147,7 +147,7 @@ def process_image(image):
147
  # Move Qwen2VL models to GPU
148
  logger.info("Moving Qwen2VL models to GPU...")
149
  # qwen2vl.to(DEVICE)
150
- connector.to(DEVICE)
151
 
152
  message = [
153
  {
@@ -179,8 +179,8 @@ def process_image(image):
179
  result = (image_hidden_state.cpu(), image_grid_thw)
180
 
181
  # Move models back to CPU
182
- qwen2vl.cpu()
183
- connector.cpu()
184
  torch.cuda.empty_cache()
185
 
186
  return result
@@ -221,7 +221,7 @@ def compute_t5_text_embeddings(prompt):
221
 
222
  prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
223
  prompt_embeds = t5_context_embedder.to(DEVICE)(prompt_embeds)
224
- t5_context_embedder.cpu()
225
 
226
  return prompt_embeds
227
 
@@ -268,11 +268,11 @@ def generate(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28,
268
  # Move Transformer and VAE to GPU
269
  logger.info("Moving Transformer and VAE to GPU...")
270
  # transformer.to(DEVICE)
271
- vae.to(DEVICE)
272
 
273
  # Update pipeline models
274
- pipeline.transformer = transformer
275
- pipeline.vae = vae
276
  logger.info("Models moved to GPU")
277
 
278
  # Get dimensions
 
147
  # Move Qwen2VL models to GPU
148
  logger.info("Moving Qwen2VL models to GPU...")
149
  # qwen2vl.to(DEVICE)
150
+ # connector.to(DEVICE)
151
 
152
  message = [
153
  {
 
179
  result = (image_hidden_state.cpu(), image_grid_thw)
180
 
181
  # Move models back to CPU
182
+ # qwen2vl.cpu()
183
+ # connector.cpu()
184
  torch.cuda.empty_cache()
185
 
186
  return result
 
221
 
222
  prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
223
  prompt_embeds = t5_context_embedder.to(DEVICE)(prompt_embeds)
224
+ # t5_context_embedder.cpu()
225
 
226
  return prompt_embeds
227
 
 
268
  # Move Transformer and VAE to GPU
269
  logger.info("Moving Transformer and VAE to GPU...")
270
  # transformer.to(DEVICE)
271
+ # vae.to(DEVICE)
272
 
273
  # Update pipeline models
274
+ # pipeline.transformer = transformer
275
+ # pipeline.vae = vae
276
  logger.info("Models moved to GPU")
277
 
278
  # Get dimensions