Diffusion models have gained popularity in generative modeling, especially in image generation tasks. The key idea behind diffusion models is to gradually add noise to an image (the forward process) and then learn to reverse this process to denoise it back to the original (the reverse process). This tutorial will walk you through the essential parts of a diffusion model based on the provided code.

Introduction

A diffusion model operates in two phases:

- Forward process: noise is gradually added to a clean image over many timesteps.
- Reverse process: a learned model removes the noise step by step, recovering an image.

Now, let's break down the key components in the code to understand this process in detail.

Forward Diffusion Process

The forward process is where we add noise to the original image in a controlled manner across multiple steps. In this case, the noise is added based on a schedule (called a beta schedule) that defines how much noise is introduced at each step.

import torch

def forward_diffusion(x_0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # Sample x_t from q(x_t | x_0) in closed form for a batch of timesteps t
    noise = torch.randn_like(x_0)
    # Broadcast the per-timestep scalars over (batch, channel, height, width)
    sqrt_alpha_t = sqrt_alphas_cumprod[t][:, None, None, None]
    sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
    noisy_image = sqrt_alpha_t * x_0 + sqrt_one_minus_alpha_t * noise
    return noisy_image, noise

The process is controlled by the noise schedule through the terms sqrt_alpha_t and sqrt_one_minus_alpha_t. These are the square roots of the cumulative product of the alphas (where alpha = 1 - beta) and of one minus that product, respectively; together they determine how much of the original image and how much noise are mixed at timestep t.
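As a quick sanity check, the closed-form forward step can be exercised with a toy schedule. The schedule values and image sizes below are illustrative, not taken from the original code:

```python
import torch

# Toy linear beta schedule (illustrative values)
timesteps = 10
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

# A batch of two 8x8 RGB "images" and a timestep for each
x_0 = torch.randn(2, 3, 8, 8)
t = torch.tensor([3, 7])

noise = torch.randn_like(x_0)
sqrt_alpha_t = sqrt_alphas_cumprod[t][:, None, None, None]
sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
noisy_image = sqrt_alpha_t * x_0 + sqrt_one_minus_alpha_t * noise

print(noisy_image.shape)  # same shape as x_0
```

Note that the two mixing coefficients satisfy sqrt_alpha_t**2 + sqrt_one_minus_alpha_t**2 = 1, so the variance of the noisy image stays controlled across timesteps.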

Reverse Process (Sampling/Generation)

Generation reverses the forward diffusion: starting from pure Gaussian noise, the image is denoised step by step. The denoising is performed by a U-Net model that has been trained to predict the noise added at each step.

@torch.no_grad()
def sample(model, image_size, batch_size, timesteps, betas, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, device):
    model.eval()
    img = torch.randn((batch_size, 3, image_size, image_size)).to(device)

    for t in reversed(range(timesteps)):
        t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
        predicted_noise = model(img, t_tensor)

        beta_t = betas[t].view(-1, 1, 1, 1)
        alpha_t = 1.0 - beta_t
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # Reverse-step mean: remove the predicted noise, then rescale by 1/sqrt(alpha_t)
        img = (img - (beta_t / sqrt_one_minus_alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_t)

        # Add Gaussian noise scaled by sqrt(beta_t), except at the final step (t = 0)
        if t > 0:
            noise = torch.randn_like(img)
            img = img + torch.sqrt(beta_t) * noise

    return torch.clamp(img, -1.0, 1.0)
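To see the loop in action without a trained network, the same reverse steps can be run inline with a stand-in noise predictor. Everything here is a smoke-test assumption: the DummyModel (a single conv layer that ignores the timestep), the tiny schedule, and the image size; a trained U-Net would take the model's place in practice.

```python
import torch
import torch.nn as nn

# Stand-in noise predictor (a trained U-Net would take its place)
class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

torch.manual_seed(0)
timesteps = 8
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_omac = torch.sqrt(1.0 - alphas_cumprod)

model = DummyModel().eval()
with torch.no_grad():
    img = torch.randn(2, 3, 16, 16)  # start from pure noise
    for t in reversed(range(timesteps)):
        eps = model(img, torch.full((2,), t, dtype=torch.long))
        beta_t = betas[t]
        img = (img - (beta_t / sqrt_omac[t]) * eps) / torch.sqrt(1.0 - beta_t)
        if t > 0:
            img = img + torch.sqrt(beta_t) * torch.randn_like(img)

print(img.shape)  # torch.Size([2, 3, 16, 16])
```

The output is meaningless with an untrained model, but the shapes and numerical behavior of the reverse loop can be verified this way before plugging in a real U-Net.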

Beta Schedule for Noise Addition

The noise schedule (i.e., how much noise is added at each step) is defined by the beta schedule. Here it is a linear schedule: beta increases linearly from a small starting value to a larger final value across the timesteps, so later steps inject more noise.
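A linear schedule and the derived quantities used above can be sketched as follows. The start and end values (1e-4 and 0.02) are common defaults, used here as an assumption since the original code's constants aren't shown:

```python
import torch

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    # Linearly spaced betas from beta_start to beta_end
    return torch.linspace(beta_start, beta_end, timesteps)

timesteps = 1000
betas = linear_beta_schedule(timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product: alpha_bar_t
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
```

These four tensors are exactly what forward_diffusion and sample expect: precomputed once per training run, then indexed by timestep.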