Visual and code walkthrough to understand how models learn using gradient descent.
"If you dropped a ball on a valley-shaped hill, where would it settle? That's gradient descent in action."
In simple terms:
- Gradient Descent is an algorithm for finding the minimum of a function.
- It is used heavily in machine learning to minimize the error (loss) of a model's predictions.
This is what the algorithm does:
"Take a small step opposite to the gradient."
Gradient Descent is the engine that powers the learning process in most supervised ML models. A typical training loop breaks down like this:
1. Gather the training data.
2. Initialize the model's parameters.
3. Make predictions with the current parameters.
4. Compute the loss, i.e. how wrong the predictions are.
5. Compute the gradient of the loss with respect to each parameter.
6. Update the parameters to reduce the error.
Gradient Descent is responsible for step 6, which updates the model to reduce error, as sketched in the example below.
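As a concrete illustration of those six steps (a sketch with made-up data and hyperparameters, not a definitive recipe), here is gradient descent fitting a simple linear regression:

import numpy as np

# Step 1: toy training data, y ≈ 3x + 2 plus noise
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=50)
y = 3 * X + 2 + rng.normal(0, 0.1, size=50)

# Step 2: initialize parameters
w, b = 0.0, 0.0
alpha = 0.1  # learning rate

for epoch in range(200):
    y_hat = w * X + b                      # Step 3: predict
    loss = np.mean((y_hat - y) ** 2)       # Step 4: mean squared error
    grad_w = np.mean(2 * (y_hat - y) * X)  # Step 5: gradient w.r.t. w
    grad_b = np.mean(2 * (y_hat - y))      #         gradient w.r.t. b
    w -= alpha * grad_w                    # Step 6: update parameters
    b -= alpha * grad_b

print(w, b)  # should end up near 3 and 2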
In theory, we could:
- Evaluate the loss function for all combinations of parameters.
- Pick the combination with the lowest error.
For example, if we tried 1 million values for each of 100 parameters, the number of combinations would be:

(10⁶)¹⁰⁰ = 10⁶⁰⁰

That is far more than the estimated 10⁸⁰ atoms in the observable universe. Even the world's fastest supercomputer couldn't evaluate them all.
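To make that number concrete, here is a tiny sketch (illustrative values only) counting the grid-search combinations directly; Python integers have arbitrary precision, so the count can be computed exactly:

values_per_param = 10**6   # candidate values tried per parameter
num_params = 100           # number of parameters
combinations = values_per_param ** num_params
print(len(str(combinations)) - 1)  # 600, i.e. the grid has 10⁶⁰⁰ points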
Gradient Descent gives us a smart shortcut:
- It doesn't search randomly.
- It uses the gradient (slope) to find the best direction to reduce the error.
- It iteratively refines the parameters with minimal computation.
You go from being lost in a jungle with no map… to having a compass that always points downhill.
Visual:
Plot a simple convex function:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-10, 10, 100)
y = x**2
plt.plot(x, y)
plt.title("Convex Function: y = x²")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()
import numpy as np
import matplotlib.pyplot as plt
# Function and Gradient
def f(x): return x**2
def grad(x): return 2*x
# Gradient Descent
x_vals = [8] # Start from x=8
alpha = 0.1
for _ in range(20):
    x_new = x_vals[-1] - alpha * grad(x_vals[-1])
    x_vals.append(x_new)
# Plot descent steps
x = np.linspace(-10, 10, 100)
y = f(x)
plt.plot(x, y, label="y = x²")
plt.scatter(x_vals, [f(i) for i in x_vals], color='red')
plt.plot(x_vals, [f(i) for i in x_vals], 'r--', label="Descent Path")
plt.title("Gradient Descent on Convex Function")
plt.xlabel("x")
plt.ylabel("f(x)")
plt.legend()
plt.grid(True)
plt.show()
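Why does the path collapse toward 0 so quickly? For f(x) = x², each update is x_new = x − α·2x = (1 − 2α)·x, so with α = 0.1 every step simply multiplies x by 0.8. Starting from 8, after 20 steps we expect roughly 8 × 0.8²⁰ ≈ 0.092, which you can confirm by appending one line to the script above:

print(x_vals[-1])  # ≈ 0.0922, already very close to the minimum at x = 0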
If the descent isn't behaving this nicely (the loss plateaus, oscillates, or diverges):
- You might need a lower learning rate, as illustrated below.
- Consider switching to advanced optimisers (Adam, RMSprop).
- Add regularization to avoid overfitting.
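To see why the learning rate matters so much, here is a small sketch (reusing the same f(x) = x² example, with a deliberately oversized α chosen for illustration) in which the iterates overshoot the minimum and diverge:

def grad(x):
    return 2 * x  # derivative of f(x) = x**2

x, alpha = 8.0, 1.1  # for this function, any alpha > 1 overshoots the minimum
for step in range(5):
    x = x - alpha * grad(x)
    print(step, x)
# Each update multiplies x by (1 - 2*alpha) = -1.2, so |x| grows: the descent diverges.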
Whether you're training a linear model or a deep neural net, understanding how gradient descent adjusts parameters brings transparency and control to your modelling journey.