Grid search is a foundational optimization technique used to find the minimum or maximum of a function by systematically testing values within a specified range.
It operates by discretizing the search space into a grid of values and evaluating the function at each grid point. Iterative variants, like the one below, then refine the grid around the best point found and repeat.
This method is particularly useful for problems where derivatives are unavailable or unreliable.
Although it lacks the sophistication of gradient-based or probabilistic methods, grid search is simple and systematic. That makes it a robust baseline against which more complex algorithms can be compared.
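The basic, single-pass form of grid search described above fits in a few lines. This is a minimal sketch for a 1-D function on a bounded interval, written in plain Python. `grid_search` is a hypothetical name, not a library API.

```python
def grid_search(func, lo, hi, num_points):
    """Single-pass grid search: evaluate func on num_points evenly spaced
    values in [lo, hi] and return the grid point with the smallest value."""
    step = (hi - lo) / (num_points - 1)
    xs = [lo + k * step for k in range(num_points)]  # the grid
    ys = [func(x) for x in xs]                       # one evaluation per point
    best = min(range(num_points), key=lambda k: ys[k])
    return xs[best], ys[best]

# Usage: minimize a simple quadratic whose minimum is at x = 3.
x_best, y_best = grid_search(lambda x: (x - 3) ** 2, -10.0, 10.0, 21)
print(x_best, y_best)  # prints 3.0 0.0
```

The cost is exactly `num_points` function evaluations. The answer is only as accurate as the grid spacing `(hi - lo) / (num_points - 1)`. Refining the grid around the best point reduces that spacing without evaluating a uniformly fine grid.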
```python
import torch
import matplotlib.pyplot as plt
```
Non-convex function: Below is the definition of the function we want to minimize. It oscillates because of the sine component, which creates multiple local minima across different values of x.
```python
def f(x):
    return (x - 3)**2 + 3 * torch.sin(5 * x)
```
This formula combines a quadratic term, which generally creates a parabolic shape, with a trigonometric sine function that introduces oscillations, making the function's landscape varied and more challenging to optimize.
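One way to see the effect of the sine term is to sample f densely and count the interior sample points that are lower than both neighbours. Each such point marks a local minimum of the sampled curve. This sketch uses Python's `math` module in place of PyTorch so it runs on plain floats. `f_scalar` and `local_minima` are hypothetical helper names.

```python
import math

def f_scalar(x):
    # Same formula as f above, written for plain Python floats.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)

def local_minima(func, lo, hi, num_points):
    """Return grid points strictly inside [lo, hi] whose value is
    lower than both neighbouring grid points."""
    step = (hi - lo) / (num_points - 1)
    xs = [lo + k * step for k in range(num_points)]
    ys = [func(x) for x in xs]
    return [xs[k] for k in range(1, num_points - 1)
            if ys[k] < ys[k - 1] and ys[k] < ys[k + 1]]

minima = local_minima(f_scalar, -10.0, 10.0, 10001)
print(len(minima), "local minima found on the sampled curve")
```

A convex function sampled this way has at most one such point. This function has several, which is why a coarse grid can settle in the wrong basin.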
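Initial parameters: The search begins on the interval from `x_start` to `x_end`. `num_grids` sets how many evenly spaced points are used to evaluate the function in each iteration, and `iterations` controls how many refinement steps are plotted. The function is also evaluated once on a dense grid over the full range, so that every plot shares the same background curve and y-axis limits.

```python
# Initial parameters for the grid search.
x_start = -10.0
x_end = 10.0
num_grids = 6
refinement_factor = 3  # Not used by the refinement rule in the loop below
iterations = 4  # We want to visualize the first 4 steps

# Evaluate f densely over the full range to draw the background curve
# and to set fixed y-axis limits shared by all subplots.
x_full = torch.linspace(-10, 10, 1000)
y_full = f(x_full)
y_min, y_max = y_full.min().item(), y_full.max().item()
```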
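With these settings, the grid arithmetic can be written down directly. Here N stands for `num_grids`, [a_k, b_k] for the search interval at iteration k, and x_k^* for the best grid point (notation introduced here for illustration). The refinement rule in the loop keeps one grid spacing on either side of x_k^*:

```latex
h_k = \frac{b_k - a_k}{N - 1}, \qquad
[a_{k+1},\, b_{k+1}] = [x_k^{*} - h_k,\; x_k^{*} + h_k]
\quad\Longrightarrow\quad
b_{k+1} - a_{k+1} = 2h_k = \frac{2}{N - 1}\,(b_k - a_k).
```

For N = 6 and [a_0, b_0] = [-10, 10], the first grid is {-10, -6, -2, 2, 6, 10} with h_0 = 4. Each refinement shrinks the interval to 2/5 of its previous width, giving widths of 20, 8, 3.2 and 1.28 over the four plotted iterations. The last grid therefore has a spacing of 0.256.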
Grid search loop: For each iteration (up to the number of `iterations` specified), the loop generates `num_grids` evenly spaced values within the current range (`x_start` to `x_end`). It evaluates the function at these points and records the minimum value (`min_y`) and its corresponding x value (`min_x`). It then narrows the range to one grid spacing on either side of `min_x` before the next iteration.

```python
# Set up a 1x4 grid of subplots, one per iteration.
fig, axes = plt.subplots(1, iterations, figsize=(20, 5))

# Loop over the iterations, visualizing each search grid.
for i in range(iterations):
    # Create the grid of x-values and compute f(x) over this grid.
    x = torch.linspace(x_start, x_end, num_grids)
    y = f(x)

    # Find the index and value of the minimum in the current grid.
    # .item() converts to Python floats for plotting and the next linspace call.
    min_index = torch.argmin(y)
    min_x = x[min_index].item()
    min_y = y[min_index].item()

    # Plot the function and the current estimated minimum.
    ax = axes[i]
    ax.plot(x_full.numpy(), y_full.numpy(),
            label='f(x)', alpha=0.3)  # Faint curve for the whole function range
    ax.scatter(x.numpy(), y.numpy(), color='blue',
               s=10, label='Grid points')  # Grid points
    ax.scatter(min_x, min_y, color='red',
               s=50, zorder=5, label='min')  # Minimum point
    ax.set_title(f'Iteration {i+1}\nmin_x = {min_x:.4f}')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.legend()

    # Set fixed axes limits.
    ax.set_xlim([-10, 10])
    ax.set_ylim([y_min - 1, y_max + 1])  # A bit of padding around the y range

    # Refine the search range around the current minimum:
    # keep one grid spacing on either side of min_x.
    interval = abs(x_start - x_end) / (num_grids - 1)
    x_start = min_x - interval
    x_end = min_x + interval

plt.tight_layout()
plt.show()
```
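One property of the loop above is worth noting. With an even `num_grids`, the refined grid is centred on `min_x` but does not contain it, and the loop keeps only the current grid's minimum. As a result, the estimate can get worse from one iteration to the next. With the settings above, for example, the first grid's best point is x = 2 (f ≈ -0.63). The second grid spans [-2, 6] and skips x = 2, and its best point is x = 4.4 (f ≈ 1.93).

The sketch below is one way to avoid this, assuming plain Python floats rather than tensors. It keeps the best point seen in any iteration and uses an odd number of grid points, so the previous best sits at the centre of the new grid. `f_scalar` and `refine_grid_search` are hypothetical names.

```python
import math

def f_scalar(x):
    # Same objective as f above, written for plain Python floats.
    return (x - 3) ** 2 + 3 * math.sin(5 * x)

def refine_grid_search(func, lo, hi, num_points=7, iterations=4):
    """Coarse-to-fine grid search that tracks the best point across iterations.

    num_points is assumed odd, so the centre of each refined interval
    (the best x found so far) is itself a grid point."""
    best_x, best_y = None, float("inf")
    history = []  # best value after each iteration
    for _ in range(iterations):
        step = (hi - lo) / (num_points - 1)
        for k in range(num_points):
            x = lo + k * step
            y = func(x)
            if y < best_y:  # keep the best value seen in any iteration
                best_x, best_y = x, y
        history.append(best_y)
        # Zoom in: one grid step on either side of the best point so far.
        lo, hi = best_x - step, best_x + step
    return best_x, best_y, history

x_best, y_best, history = refine_grid_search(f_scalar, -10.0, 10.0)
print(f"x = {x_best:.4f}, f(x) = {y_best:.4f}")
```

Because `best_y` only ever decreases, the returned `history` is non-increasing. On a non-convex function, this still finds only the best of the sampled points, not necessarily the global minimum.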