Grid search is a foundational optimization technique used to find the minimum or maximum of a function by systematically testing values within a specified range.

It operates by discretizing the search space into a grid of values and evaluating the function at each grid point. In the iterative variant used below, it then refines the search around the best point found so far in subsequent iterations.
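A minimal sketch of a single grid pass, assuming a one-dimensional function and an evenly spaced grid that includes both endpoints (the same spacing `torch.linspace` produces). The helper name `grid_search_1d` is hypothetical, and the sketch uses only the standard library.

```python
def grid_search_1d(f, lo, hi, n):
    """Evaluate f at n evenly spaced points in [lo, hi]; return the best (x, f(x))."""
    if n < 2:
        raise ValueError("need at least two grid points")
    step = (hi - lo) / (n - 1)
    # Discretize the interval, including both endpoints.
    xs = [lo + i * step for i in range(n)]
    # Evaluate the objective once per grid point and keep the smallest value.
    best_x, best_y = min(((x, f(x)) for x in xs), key=lambda p: p[1])
    return best_x, best_y


# Usage: minimize (x - 2)^2 on [-10, 10] with 11 points (spacing 2).
x_best, y_best = grid_search_1d(lambda x: (x - 2) ** 2, -10.0, 10.0, 11)
print(x_best, y_best)  # 2.0 0.0
```

To search for a maximum instead, pass the negated function and negate the returned value.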

This method is particularly useful for problems where derivatives are unavailable or unreliable, because it only needs to evaluate the function.

Although it lacks the sophistication of gradient-based or probabilistic methods, its simplicity and thoroughness make grid search a robust baseline against which more complex algorithms can be compared.

Code

import torch
import matplotlib.pyplot as plt

Non-convex Function: below is the function we aim to minimize. Its sine component makes it oscillate as x changes, which produces several local minima and makes the landscape harder to search.

def f(x):
    # Quadratic bowl centered at x = 3 plus a sine ripple of amplitude 3.
    return (x - 3)**2 + 3*torch.sin(5 * x)


This formula combines a quadratic term, (x - 3)**2, which on its own is a parabola with its minimum at x = 3, with a sine term, 3*sin(5x), that adds oscillations of amplitude 3. The result is a landscape with several local minima, which is more challenging to optimize than the parabola alone.
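To make the "several local minima" claim concrete, the sketch below evaluates the same function on a dense grid and counts interior points that are lower than both neighbors, i.e. local minima as seen at grid resolution. It uses `math.sin` instead of `torch.sin` so it runs without PyTorch; the helper names `f_scalar` and `count_local_minima` and the grid density (2001 points on [-10, 10]) are illustrative choices, not part of the original code.

```python
import math


def f_scalar(x):
    # Same objective as f above, written for plain Python floats.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)


def count_local_minima(func, lo, hi, n):
    """Count interior grid points whose value is strictly below both neighbors."""
    step = (hi - lo) / (n - 1)
    ys = [func(lo + i * step) for i in range(n)]
    return sum(1 for i in range(1, n - 1) if ys[i] < ys[i - 1] and ys[i] < ys[i + 1])


# The oscillating objective has many grid-level local minima...
print(count_local_minima(f_scalar, -10.0, 10.0, 2001))
# ...while the quadratic term alone has exactly one.
print(count_local_minima(lambda x: (x - 3) ** 2, -10.0, 10.0, 2001))
```

A local method started in the wrong basin can stop at any of these points, which is one reason a coarse global scan such as grid search is a useful first step here.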

# Initial parameters for the grid search.
x_start = -10
x_end = 10
num_grids = 6
refinement_factor = 3  # Not used below: each refinement keeps one grid spacing on either side of the minimum
iterations = 4  # We want to visualize the first 4 steps

# Generate initial full range of x values and compute function values to set fixed y-axis limits.
x_full = torch.linspace(-10, 10, 100)
y_full = f(x_full)
y_min, y_max = y_full.min().item(), y_full.max().item()
# Set up a 1x4 grid for plotting.
fig, axes = plt.subplots(1, iterations, figsize=(20, 5))

# Loop over the iterations, visualizing each search grid.
for i in range(iterations):
    # Create the grid of x-values and compute f(x) over this grid.
    x = torch.linspace(x_start, x_end, num_grids)
    y = f(x)
    
    # Find the index and value of the minimum in the current grid.
    min_index = torch.argmin(y)
    min_x = x[min_index]
    min_y = y[min_index]
    
    # Plot the function and the current estimated minimum.
    ax = axes[i]
    ax.plot(x_full.numpy(), y_full.numpy(), 
            label='f(x)', alpha=0.3)  # Faint curve for the whole function range
    ax.scatter(x.numpy(), y.numpy(), color='blue', 
               s=10, label='Grid points')  # Grid points
    ax.scatter(min_x.item(), min_y.item(), color='red', 
               s=50, zorder=5, label='min')  # Minimum point
    ax.set_title(f'Iteration {i+1}\nmin_x = {min_x.item():.4f}')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.legend()
    
    # Set fixed axes limits.
    ax.set_xlim([-10, 10])
    ax.set_ylim([y_min - 1, y_max + 1])  # A bit of padding around the y range
    
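    # Refine the search range around the current minimum: keep one grid spacing
    # on each side of min_x (.item() keeps the bounds as plain Python floats).
    interval = abs(x_end - x_start) / (num_grids - 1)
    x_start = min_x.item() - interval
    x_end = min_x.item() + interval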
plt.tight_layout()
plt.show()
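Each refinement keeps one grid spacing, (x_end - x_start) / (num_grids - 1), on either side of the current best point, so a window of width W shrinks to 2W / (num_grids - 1): a factor of 2.5 per iteration when num_grids = 6. The sketch below repeats the refinement without plotting, using plain Python floats and `math.sin` in place of PyTorch, and records the window width at each iteration. The helper names `f_scalar` and `refine_grid_search` are hypothetical.

```python
import math


def f_scalar(x):
    # Same objective as f above, using math.sin so no PyTorch is required.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)


def refine_grid_search(func, lo, hi, n, iterations):
    """Iteratively re-grid around the best point; return (best_x, window widths)."""
    widths = []
    best_x = lo
    for _ in range(iterations):
        widths.append(hi - lo)
        step = (hi - lo) / (n - 1)
        xs = [lo + i * step for i in range(n)]
        best_x = min(xs, key=func)
        # Keep one grid spacing on each side of the current best point.
        lo, hi = best_x - step, best_x + step
    return best_x, widths


best_x, widths = refine_grid_search(f_scalar, -10.0, 10.0, 6, 4)
print(widths)  # approximately [20, 8, 3.2, 1.28], up to float rounding
print(best_x, f_scalar(best_x))
```

One consequence of this design: with an even `num_grids`, the new grid does not contain the previous best point (the window's center falls between two grid points), so the reported minimum can get worse from one iteration to the next. An odd `num_grids` keeps the center on the grid.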
