Visualizing deeper-layer kernels in image space can be formulated as an optimization problem. The goal is to find an input image that maximizes the activation of a specific kernel in a deep layer, which we do by defining an objective function that quantifies how strongly that kernel responds to the input image.
$$ \begin{aligned}\arg\max_x\quad & w_2*(w_1*x)\\ \text{s.t.}\quad & \lVert x\rVert_2=1 \end{aligned} $$
This optimization process corresponds to the code below. Note that convolving a 5x5 input with two 3x3 kernels in sequence yields a single scalar, so the objective here is literally that one output value.
# You may copy-paste this into Colab to play with it
import torch
import torch.nn.functional as F
# Define two fixed 3x3 kernels (weights are not trainable)
# The 4 dimensions follow the standard convention in image processing:
# [B, C, H, W] = batch_size, channel_num, height, width
# Try grabbing the parameters of an MNIST-trained model here for learning.
kernel_a = torch.randn(1, 1, 3, 3, requires_grad=False)
kernel_b = torch.randn(1, 1, 3, 3, requires_grad=False)
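# (A hedged sketch of the suggestion above: if you had a trained model whose first
#  two conv layers were single-channel 3x3 Conv2d modules -- hypothetical attributes
#  `model.conv1` / `model.conv2`, not a specific MNIST architecture -- you could
#  copy their weights in like this:)
# kernel_a = model.conv1.weight.detach().clone()
# kernel_b = model.conv2.weight.detach().clone()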
# Define a 5x5 input variable that requires gradient
input_var = torch.randn(1, 1, 5, 5, requires_grad=True)
# Define an optimizer to optimize the input variable
optimizer = torch.optim.Adam([input_var], lr=0.01)
# We will optimize the input so that the convolution with kernel_b is maximized
# This means we want to maximize the value in the output feature map
# Since most optimizers are designed to minimize functions, we can minimize the negative of the output
# Optimization loop
for iteration in range(100):
    optimizer.zero_grad()  # Clear gradients from the previous step
    # Projection step of Projected Gradient Descent: normalize x to unit L2 norm
    # Note: operating on .data only adjusts the value, so the projection is not
    # added to the computational graph (input_var would otherwise depend on itself)
    input_var.data = input_var.data / torch.norm(input_var.data, p=2).detach()
    # Convolve the input with both kernels to get the feature maps
    hidden_a = F.conv2d(input_var, kernel_a)
    # Add your activation function here if you want, e.g. hidden_a = torch.relu(hidden_a)
    output_b = F.conv2d(hidden_a, kernel_b)
    # We negate the output of kernel_b because optimizers minimize, and we want to maximize it
    loss = -torch.sum(output_b)  # output_b is a single value here (5x5 -> 3x3 -> 1x1)
    # Backpropagation
    loss.backward()
    # Update the input variable
    optimizer.step()
    # Print the loss every 10 iterations
    if iteration % 10 == 0:
        print(f"Iteration {iteration}, Loss: {loss.item()}")
# After optimization, input_var is the optimized 5x5 input
optimized_input = input_var.detach()
print("Optimized 5x5 Input:", optimized_input)