Commit ·
c6e5e05
1
Parent(s): d4e2a68
general normal distribution
Browse files- README.md +4 -1
- pipeline.py +25 -0
README.md
CHANGED
|
@@ -14,7 +14,10 @@ Currently implemented methods:
|
|
| 14 |
Example: `custom_pipeline.load_initial_noise_modifier(method="fixed-seed", seed=…)`
|
| 15 |
- Golden Noise for Diffusion Models: A Learning Framework (Zhou et al., https://arxiv.org/abs/2411.09502).
|
| 16 |
Example: `custom_pipeline.load_initial_noise_modifier(method="golden-noise", npnet_path=…)`
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
## Citation
|
| 20 |
|
|
|
|
| 14 |
Example: `custom_pipeline.load_initial_noise_modifier(method="fixed-seed", seed=…)`
|
| 15 |
- Golden Noise for Diffusion Models: A Learning Framework (Zhou et al., https://arxiv.org/abs/2411.09502).
|
| 16 |
Example: `custom_pipeline.load_initial_noise_modifier(method="golden-noise", npnet_path=…)`
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
Demo Notebook: [](https://colab.research.google.com/drive/1-owYN8r2TbT-Je_eTEpnIMLj1nvxPYqI#scrollTo=HQS6OQ44jz66)
|
| 21 |
|
| 22 |
## Citation
|
| 23 |
|
pipeline.py
CHANGED
|
@@ -85,6 +85,25 @@ class NoiseLoaderMixin:
|
|
| 85 |
elif method == "fixed-seed":
|
| 86 |
# This is for demonstration purposes to see the impact of different seeds
|
| 87 |
self.seed = method_args.get("seed", 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
else:
|
| 89 |
raise NotImplementedError(f"No initial noise method is implemented for the given method {method}")
|
| 90 |
|
|
@@ -117,6 +136,12 @@ class NoiseLoaderMixin:
|
|
| 117 |
elif self.initial_noise_sampling_method == "fixed-seed":
|
| 118 |
generator = torch.Generator(device=self.device).manual_seed(self.seed)
|
| 119 |
return torch.randn(shape, generator=generator, dtype=self.dtype, device=self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
def unload_initial_noise_modifier(self):
|
| 122 |
"""
|
|
|
|
| 85 |
elif method == "fixed-seed":
|
| 86 |
# This is for demonstration purposes to see the impact of different seeds
|
| 87 |
self.seed = method_args.get("seed", 0)
|
| 88 |
+
elif method == "general-normal-distribution":
|
| 89 |
+
self.init_noise_mean = method_args.get("init_noise_mean", [0.0, 0.0, 0.0, 0.0])
|
| 90 |
+
self.init_noise_std = method_args.get("init_noise_std", [1.0, 1.0, 1.0, 1.0])
|
| 91 |
+
|
| 92 |
+
if isinstance(self.init_noise_mean, (list, tuple)):
|
| 93 |
+
if len(self.init_noise_mean) != 4:
|
| 94 |
+
raise ValueError("Mean must be a scalar or a list/tuple of 4 values (one per channel)")
|
| 95 |
+
self.init_noise_mean = torch.tensor(self.init_noise_mean, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
|
| 96 |
+
else:
|
| 97 |
+
# Scalar case - broadcast to all channels
|
| 98 |
+
self.init_noise_mean = torch.tensor([self.init_noise_mean] * 4, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
|
| 99 |
+
|
| 100 |
+
if isinstance(self.init_noise_std, (list, tuple)):
|
| 101 |
+
if len(self.init_noise_std) != 4:
|
| 102 |
+
raise ValueError("Std must be a scalar or a list/tuple of 4 values (one per channel)")
|
| 103 |
+
self.init_noise_std = torch.tensor(self.init_noise_std, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
|
| 104 |
+
else:
|
| 105 |
+
# Scalar case - broadcast to all channels
|
| 106 |
+
self.init_noise_std = torch.tensor([self.init_noise_std] * 4, dtype=self.dtype, device=self.device).view(1, 4, 1, 1)
|
| 107 |
else:
|
| 108 |
raise NotImplementedError(f"No initial noise method is implemented for the given method {method}")
|
| 109 |
|
|
|
|
| 136 |
elif self.initial_noise_sampling_method == "fixed-seed":
|
| 137 |
generator = torch.Generator(device=self.device).manual_seed(self.seed)
|
| 138 |
return torch.randn(shape, generator=generator, dtype=self.dtype, device=self.device)
|
| 139 |
+
elif self.initial_noise_sampling_method == "general-normal-distribution":
|
| 140 |
+
# Generate signal using custom mean and std
|
| 141 |
+
initial_noise = self.init_noise_mean + self.init_noise_std * randn_tensor(
|
| 142 |
+
shape, generator=generator, device=self.device, dtype=self.dtype
|
| 143 |
+
)
|
| 144 |
+
return initial_noise
|
| 145 |
|
| 146 |
def unload_initial_noise_modifier(self):
|
| 147 |
"""
|