hvlgo commited on
Commit
e04dafa
·
verified ·
1 Parent(s): 510ebf2

Update flow_loss.py

Browse files
Files changed (1) hide show
  1. flow_loss.py +1 -1
flow_loss.py CHANGED
@@ -39,7 +39,7 @@ class FlowLoss(nn.Module):
39
 
40
  def sample(self, z, num_samples=1):
41
  z = z.repeat(num_samples, 1)
42
- noise = torch.randn(z.shape[0], self.in_channels).cuda()
43
  x = noise
44
  dt = 1.0 / self.num_sampling_steps
45
  for i in range(self.num_sampling_steps):
 
39
 
40
  def sample(self, z, num_samples=1):
41
  z = z.repeat(num_samples, 1)
42
+ noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
43
  x = noise
44
  dt = 1.0 / self.num_sampling_steps
45
  for i in range(self.num_sampling_steps):