Update flow_loss.py
Browse files- flow_loss.py +1 -1
flow_loss.py
CHANGED
@@ -39,7 +39,7 @@ class FlowLoss(nn.Module):
|
|
39 |
|
40 |
def sample(self, z, num_samples=1):
|
41 |
z = z.repeat(num_samples, 1)
|
42 |
-
noise = torch.randn(z.shape[0], self.in_channels).
|
43 |
x = noise
|
44 |
dt = 1.0 / self.num_sampling_steps
|
45 |
for i in range(self.num_sampling_steps):
|
|
|
39 |
|
40 |
def sample(self, z, num_samples=1):
|
41 |
z = z.repeat(num_samples, 1)
|
42 |
+
noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
|
43 |
x = noise
|
44 |
dt = 1.0 / self.num_sampling_steps
|
45 |
for i in range(self.num_sampling_steps):
|