syrinenoamen commited on
Commit
c6e5e05
·
1 Parent(s): d4e2a68

general normal distribution

Browse files
Files changed (2) hide show
  1. README.md +4 -1
  2. pipeline.py +25 -0
README.md CHANGED
@@ -14,7 +14,10 @@ Currently implemented methods:
14
  Example: `custom_pipeline.load_initial_noise_modifier(method="fixed-seed", seed=…)`
15
  - Golden Noise for Diffusion Models: A Learning Framework (Zhou et al., https://arxiv.org/abs/2411.09502).
16
  Example: `custom_pipeline.load_initial_noise_modifier(method="golden-noise", npnet_path=…)`
17
- Demo Notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-owYN8r2TbT-Je_eTEpnIMLj1nvxPYqI?usp=sharing)
 
 
 
18
 
19
  ## Citation
20
 
 
14
  Example: `custom_pipeline.load_initial_noise_modifier(method="fixed-seed", seed=…)`
15
  - Golden Noise for Diffusion Models: A Learning Framework (Zhou et al., https://arxiv.org/abs/2411.09502).
16
  Example: `custom_pipeline.load_initial_noise_modifier(method="golden-noise", npnet_path=…)`
17
+
18
+
19
+
20
+ Demo Notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-owYN8r2TbT-Je_eTEpnIMLj1nvxPYqI#scrollTo=HQS6OQ44jz66)
21
 
22
  ## Citation
23
 
pipeline.py CHANGED
@@ -85,6 +85,25 @@ class NoiseLoaderMixin:
85
  elif method == "fixed-seed":
86
  # This is for demonstration purposes to see the impact of different seeds
87
  self.seed = method_args.get("seed", 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  else:
89
  raise NotImplementedError(f"No initial noise method is implemented for the given method {method}")
90
 
@@ -117,6 +136,12 @@ class NoiseLoaderMixin:
117
  elif self.initial_noise_sampling_method == "fixed-seed":
118
  generator = torch.Generator(device=self.device).manual_seed(self.seed)
119
  return torch.randn(shape, generator=generator, dtype=self.dtype, device=self.device)
 
 
 
 
 
 
120
 
121
  def unload_initial_noise_modifier(self):
122
  """
 
85
  elif method == "fixed-seed":
86
  # This is for demonstration purposes to see the impact of different seeds
87
  self.seed = method_args.get("seed", 0)
88
+ elif method == "general-normal-distribution":
89
+ self.init_noise_mean = method_args.get("init_noise_mean", [0.0, 0.0, 0.0, 0.0])
90
+ self.init_noise_std = method_args.get("init_noise_std", [1.0, 1.0, 1.0, 1.0])
91
+
92
+ if isinstance(self.init_noise_mean, (list, tuple)):
93
+ if len(self.init_noise_mean) != 4:
94
+ raise ValueError("Mean must be a scalar or a list/tuple of 4 values (one per channel)")
95
+ self.init_noise_mean = torch.tensor(self.init_noise_mean, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
96
+ else:
97
+ # Scalar case - broadcast to all channels
98
+ self.init_noise_mean = torch.tensor([self.init_noise_mean] * 4, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
99
+
100
+ if isinstance(self.init_noise_std, (list, tuple)):
101
+ if len(self.init_noise_std) != 4:
102
+ raise ValueError("Std must be a scalar or a list/tuple of 4 values (one per channel)")
103
+ self.init_noise_std = torch.tensor(self.init_noise_std, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
104
+ else:
105
+ # Scalar case - broadcast to all channels
106
+ self.init_noise_std = torch.tensor([self.init_noise_std] * 4, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
107
  else:
108
  raise NotImplementedError(f"No initial noise method is implemented for the given method {method}")
109
 
 
136
  elif self.initial_noise_sampling_method == "fixed-seed":
137
  generator = torch.Generator(device=self.device).manual_seed(self.seed)
138
  return torch.randn(shape, generator=generator, dtype=self.dtype, device=self.device)
139
+ elif self.initial_noise_sampling_method == "general-normal-distribution":
140
+ # Generate signal using custom mean and std
141
+ initial_noise = self.init_noise_mean + self.init_noise_std * randn_tensor(
142
+ shape, generator=generator, device=self.device, dtype=self.dtype
143
+ )
144
+ return initial_noise
145
 
146
  def unload_initial_noise_modifier(self):
147
  """