Diffusion models have gained popularity in generative modeling, especially in image generation tasks. The key idea behind diffusion models is to gradually add noise to an image (the forward process) and then learn to reverse this process to denoise it back to the original (the reverse process). This tutorial will walk you through the essential parts of a diffusion model based on the provided code.
A diffusion model operates in two phases:

- The forward process, which gradually corrupts a clean image by adding noise over many timesteps.
- The reverse process, which learns to undo that corruption step by step, turning pure noise back into an image.

Now, let's break down the key components in the code to understand each phase in detail.
In the forward process, we add noise to the original image in a controlled manner across multiple steps. How much noise is introduced at each step is defined by a schedule, called the beta schedule.
```python
import torch

def forward_diffusion(x_0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # Sample Gaussian noise with the same shape as the clean image
    noise = torch.randn_like(x_0)
    # Look up the schedule terms for each timestep in the batch and
    # reshape them to broadcast over the channel and spatial dimensions
    sqrt_alpha_t = sqrt_alphas_cumprod[t][:, None, None, None]
    sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
    # Mix the clean image with noise according to the schedule
    noisy_image = sqrt_alpha_t * x_0 + sqrt_one_minus_alpha_t * noise
    return noisy_image, noise
```
- `x_0`: The original clean image.
- `t`: The current timestep (the step at which noise is added).
- `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`: Precomputed terms that control the amount of noise added at each step.
- `noisy_image`: The image with noise added at timestep `t`.
- `noise`: The random noise that was added (used later for training; see the sketch below).

The process is controlled by the noise schedule, specifically the terms `sqrt_alpha_t` and `sqrt_one_minus_alpha_t`, which determine how much of the original image and how much noise are combined.
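To see how the returned `noise` drives training, here is a minimal sketch of one training step. It assumes a `model` (the U-Net discussed below), an `optimizer`, a `dataloader` yielding clean images, and the precomputed schedule tensors; those names are illustrative, not part of the code above. The model is trained to predict the noise that `forward_diffusion` added, using a mean squared error loss.

```python
import torch
import torch.nn.functional as F

# One illustrative training step; model, optimizer, dataloader, timesteps,
# and the schedule tensors are assumed to exist (on the same device as x_0).
x_0 = next(iter(dataloader))  # a batch of clean images, scaled to [-1, 1]
t = torch.randint(0, timesteps, (x_0.shape[0],), device=x_0.device)  # random timesteps

noisy_image, noise = forward_diffusion(x_0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
predicted_noise = model(noisy_image, t)  # the U-Net predicts the added noise

loss = F.mse_loss(predicted_noise, noise)  # compare predicted vs. true noise
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Sampling the timestep uniformly at random means the model sees every noise level during training, which is what allows the reverse process to denoise at every step later.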
Once the model is trained, the reverse process denoises step by step, essentially running the forward diffusion backwards: it starts from pure noise and gradually removes it. The denoising is carried out by the U-Net model, which predicts the noise that was added at each step.
```python
@torch.no_grad()
def sample(model, image_size, batch_size, timesteps, betas, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, device):
    model.eval()
    # Start from pure Gaussian noise
    img = torch.randn((batch_size, 3, image_size, image_size), device=device)
    for t in reversed(range(1, timesteps)):
        t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
        # Predict the noise that was added at timestep t
        predicted_noise = model(img, t_tensor)
        sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        beta_t = betas[t].view(-1, 1, 1, 1)
        alpha_t = 1.0 - beta_t  # per-step alpha, recovered from beta
        # Remove the predicted noise at timestep t; a small epsilon keeps
        # the division numerically stable
        img = (img - (beta_t / (sqrt_one_minus_alpha_t + 1e-8)) * predicted_noise) / torch.sqrt(alpha_t)
        img = torch.clamp(img, -1.0, 1.0)  # keep values in the training range
        # Add fresh Gaussian noise, except at the final step
        if t > 1:
            noise = torch.randn_like(img)
            img = img + torch.sqrt(beta_t) * noise
    return img
```
- `model`: The trained U-Net model that predicts the noise added at each timestep.
- `timesteps`: The number of diffusion steps.
- `betas`, `sqrt_alphas_cumprod`, and `sqrt_one_minus_alphas_cumprod`: Terms that control the noise level at each step.
- `img`: Starts as pure Gaussian noise (drawn with `torch.randn`) and is denoised step by step.
- For numerical stability, small epsilon values such as `1e-8` are added before divisions.

The noise schedule (i.e., how much noise is added at each step) is defined by the beta schedule. In this case, it is a linear schedule that increases the noise linearly across timesteps.
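To make the schedule concrete, here is a minimal sketch of how these tensors might be precomputed. The `beta_start` and `beta_end` values are common DDPM defaults, not values taken from the code above.

```python
import torch

timesteps = 1000
beta_start, beta_end = 1e-4, 0.02  # assumed defaults; tune for your setup

# Linear beta schedule: the noise level grows linearly across timesteps
betas = torch.linspace(beta_start, beta_end, timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product of the alphas

# The terms consumed by forward_diffusion and sample
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
```

With these tensors in hand, a call like `sample(model, image_size, batch_size, timesteps, betas, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, device)` generates a batch of images starting from pure noise.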