Grid search is a foundational optimization technique used to find the minimum or maximum of a function by systematically testing values within a specified range.

It works by discretizing the search space into a grid of evenly spaced values, evaluating the function at each grid point, and then, in subsequent iterations, narrowing the grid around the best point found so far.
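A single pass of this procedure (discretize, evaluate, keep the best point) can be sketched in plain Python without PyTorch. This is a minimal illustration that assumes a one-dimensional search over a closed interval. The helper name `grid_search_1d` and the test functions are hypothetical and were chosen only for the example.

```python
def grid_search_1d(func, lo, hi, num_points):
    """Return (x, func(x)) for the best of num_points evenly spaced samples in [lo, hi]."""
    if num_points < 2:
        raise ValueError("need at least two grid points")
    step = (hi - lo) / (num_points - 1)
    # 1. Discretize: evenly spaced grid points, including both endpoints.
    xs = [lo + i * step for i in range(num_points)]
    # 2. Evaluate the function at every grid point.
    ys = [func(x) for x in xs]
    # 3. Keep the grid point with the smallest function value.
    best = min(range(num_points), key=ys.__getitem__)
    return xs[best], ys[best]


# Hypothetical example: (x - 1)^2 on [-5, 5] with 11 points (spacing 1.0).
best_x, best_y = grid_search_1d(lambda x: (x - 1) ** 2, -5.0, 5.0, 11)
print(best_x, best_y)  # 1.0 0.0
```

Refinement repeats the same call on a narrower interval centered on `best_x`. The result can only be as accurate as the grid spacing allows.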

This method is particularly useful for problems where derivatives are unavailable or unreliable, because it needs only function evaluations. Although it lacks the sophistication of gradient-based or probabilistic methods, its simplicity and thoroughness make it an indispensable tool and a robust baseline against which more complex algorithms can be compared.
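The sketch below makes the baseline idea concrete. It compares one fine grid pass with plain gradient descent on the non-convex objective used later in this section, (x - 3)^2 + 3 sin(5x). The starting point `x0 = 0.0`, the learning rate and the step count are arbitrary, hypothetical choices. Gradient descent needs the derivative `df`; grid search uses only function values.

```python
import math


def f(x):
    # Non-convex objective: parabola centered at x = 3 plus a sine ripple.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)


def df(x):
    # Analytic derivative of f; gradient descent needs it, grid search does not.
    return 2 * (x - 3) + 15 * math.cos(5 * x)


def grid_min(func, lo, hi, num_points):
    # Single-pass grid search: best of num_points evenly spaced samples.
    step = (hi - lo) / (num_points - 1)
    return min((lo + i * step for i in range(num_points)), key=func)


def gradient_descent(grad, x0, lr=0.01, steps=1000):
    # Plain fixed-step gradient descent starting from x0.
    x = x0
    for _ in range(steps):
        x -= lr * grad(x)
    return x


x_grid = grid_min(f, -10.0, 10.0, 2001)  # spacing 0.01
x_gd = gradient_descent(df, x0=0.0)
print(f"grid search:      x = {x_grid:.3f}, f(x) = {f(x_grid):.3f}")
print(f"gradient descent: x = {x_gd:.3f}, f(x) = {f(x_gd):.3f}")
```

From `x0 = 0.0`, gradient descent settles in the local minimum just left of zero. The grid pass lands near the global minimum around x ≈ 3.44, at the cost of 2001 function evaluations.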

Code

import torch
import matplotlib.pyplot as plt

Non-convex function: Below is the definition of the function whose minimum we want to find. Its sine component makes it oscillate as x varies, which complicates the search.

def f(x):
    # Parabola with its minimum at x = 3, plus an oscillating sine term.
    return (x - 3)**2 + 3*torch.sin(5 * x)


This formula combines two terms. The quadratic term (x - 3)^2 is a parabola with its minimum at x = 3. The sine term 3 sin(5x) adds oscillations of amplitude 3 and period 2π/5 ≈ 1.26. Together they produce a non-convex landscape with several local minima, which makes the function more challenging to optimize.
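One way to see the non-convexity numerically is to sample f densely and count the samples that are lower than both of their neighbors. The sketch below does this in plain Python, with `math.sin` standing in for `torch.sin`. The 20001-point grid is an arbitrary, hypothetical choice, and a discrete neighbor check only approximates the true local minima.

```python
import math


def f(x):
    # Same objective, with math.sin standing in for torch.sin.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)


n = 20001  # dense grid over [-10, 10] with spacing 0.001
xs = [-10 + 20 * i / (n - 1) for i in range(n)]
ys = [f(x) for x in xs]

# A sample counts as a local minimum if it is lower than both of its neighbors.
local_minima = [xs[i] for i in range(1, n - 1) if ys[i - 1] > ys[i] < ys[i + 1]]
# The lowest sample overall approximates the global minimum.
global_x = xs[min(range(n), key=ys.__getitem__)]

print(f"{len(local_minima)} local minima; lowest sample at x = {global_x:.3f}")
```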

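# Initial parameters for the grid search.
x_start = -10   # left end of the initial search interval
x_end = 10      # right end of the initial search interval
num_grids = 6   # number of grid points evaluated per iteration
iterations = 4  # number of refinement steps to visualize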


# Evaluate f on a fine grid over the full range so that every subplot shares
# the same reference curve and fixed y-axis limits (for visualization only).
x_full = torch.linspace(-10, 10, 100)
y_full = f(x_full)
y_min, y_max = y_full.min().item(), y_full.max().item()

# One row of subplots, one per iteration.
fig, axes = plt.subplots(1, iterations, figsize=(20, 5))

for i in range(iterations):
    # Evaluate f on an evenly spaced grid over the current interval.
    x = torch.linspace(x_start, x_end, num_grids)
    y = f(x)

    # Best grid point in this iteration.
    min_index = torch.argmin(y)
    min_x = x[min_index]
    min_y = y[min_index]

    ax = axes[i]
    ax.plot(x_full.numpy(), y_full.numpy(), label='f(x)', alpha=0.3)
    ax.scatter(x.numpy(), y.numpy(), color='blue', s=10, label='Grid points')
    ax.scatter(min_x.item(), min_y.item(), color='red', s=50, zorder=5, label='min')
    ax.set_title(f'Iteration {i+1}\nmin_x = {min_x.item():.4f}')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.legend()

    ax.set_xlim([-10, 10])
    ax.set_ylim([y_min - 1, y_max + 1])

    # Shrink the interval to one grid spacing on either side of the best point.
    # .item() keeps x_start and x_end as plain Python floats for torch.linspace.
    interval = abs(x_end - x_start) / (num_grids - 1)
    x_start = min_x.item() - interval
    x_end = min_x.item() + interval

plt.tight_layout()
plt.show()
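Each iteration replaces the interval with one grid spacing on either side of the best point. The width w therefore shrinks as w_next = 2w / (num_grids - 1). With num_grids = 6 that is a factor of 0.4, giving widths 20, 8, 3.2 and 1.28 over the four iterations.

The plotting-free sketch below reproduces the same update in plain Python, with `math.sin` in place of `torch.sin`, and records each interval. The function name `refine_grid_search` and the running-best bookkeeping are hypothetical additions for illustration; they are not part of the original loop.

```python
import math


def f(x):
    # Same objective as above, using math.sin so no PyTorch is needed.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)


def refine_grid_search(func, lo, hi, num_points=6, iterations=4):
    """Iterative grid search: shrink an evenly spaced grid around each iteration's best point."""
    best_x, best_y = None, math.inf
    history = []  # (lo, hi, this iteration's best x) for every iteration
    for _ in range(iterations):
        step = (hi - lo) / (num_points - 1)
        xs = [lo + i * step for i in range(num_points)]
        ys = [func(x) for x in xs]
        k = min(range(num_points), key=ys.__getitem__)
        history.append((lo, hi, xs[k]))
        # Running best: the next grid may not contain xs[k] itself.
        if ys[k] < best_y:
            best_x, best_y = xs[k], ys[k]
        # Same update as the loop above: one grid spacing on either side.
        lo, hi = xs[k] - step, xs[k] + step
    return best_x, best_y, history


best_x, best_y, history = refine_grid_search(f, -10.0, 10.0)
for lo, hi, x in history:
    print(f"interval [{lo:.3f}, {hi:.3f}], width {hi - lo:.3f}, iteration best x = {x:.3f}")
print(f"best seen: x = {best_x:.4f}, f(x) = {best_y:.4f}")
```

With an even number of grid points, the midpoint of the new interval (the previous best x) is not resampled. As a result, an iteration's best sample can be worse than the previous one. On this function, the fourth iteration's best sample (near x = 3.568) is higher than the third's (x = 3.44). Tracking the best value seen so far, as the sketch does, avoids reporting the worse point.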
