lanzhiwang commited on
Commit
33d25b3
·
1 Parent(s): 5a50038
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -10,31 +10,29 @@ def erzeuge(prompt):
10
  return pipeline(prompt).images # [0]
11
 
12
 
13
- def erzeuge_komplex(prompt):
14
- scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
15
- model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
16
- scheduler.set_timesteps(50)
17
 
18
- sample_size = model.config.sample_size
19
- noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
20
- input = noise
21
 
22
- for t in scheduler.timesteps:
23
- with torch.no_grad():
24
- noisy_residual = model(input, t).sample
25
- prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
26
- input = prev_noisy_sample
27
 
28
- image = (input / 2 + 0.5).clamp(0, 1)
29
- image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
30
- image = Image.fromarray((image * 255).round().astype("uint8"))
31
- return image
32
 
33
 
34
- # pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
35
  # pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cat-256")
36
  pipeline = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256")
37
-
38
  # pipeline.to("cuda")
39
 
40
 
 
10
  return pipeline(prompt).images # [0]
11
 
12
 
13
+ # def erzeuge_komplex(prompt):
14
+ # scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
15
+ # model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
16
+ # scheduler.set_timesteps(50)
17
 
18
+ # sample_size = model.config.sample_size
19
+ # noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
20
+ # input = noise
21
 
22
+ # for t in scheduler.timesteps:
23
+ # with torch.no_grad():
24
+ # noisy_residual = model(input, t).sample
25
+ # prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
26
+ # input = prev_noisy_sample
27
 
28
+ # image = (input / 2 + 0.5).clamp(0, 1)
29
+ # image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
30
+ # image = Image.fromarray((image * 255).round().astype("uint8"))
31
+ # return image
32
 
33
 
 
34
  # pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cat-256")
35
  pipeline = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256")
 
36
  # pipeline.to("cuda")
37
 
38