Optimisation

  • Training a neural network, fitting a regression line, tuning hyperparameters: at the core of almost every ML algorithm is an optimisation problem.

  • We have some function (a loss, a cost, an objective) and we want to find the inputs that make it as small (or large) as possible.

  • Before optimising, we need to understand zeros (or roots) of functions. A zero of $f(x)$ is a value $x$ where $f(x) = 0$. Graphically, these are the x-intercepts.

  • For example, $f(x) = x^2 - 3x + 2 = (x-1)(x-2)$ has zeros at $x = 1$ and $x = 2$. Between the zeros, the function is negative ($f(1.5) = -0.25$); outside the zeros, it is positive. The zeros divide the number line into regions where the function has constant sign.
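The sign pattern described above can be checked directly in plain Python (a minimal sanity-check sketch):

```python
# f(x) = x^2 - 3x + 2 = (x-1)(x-2): zeros at x = 1 and x = 2
f = lambda x: x**2 - 3*x + 2

assert f(1) == 0 and f(2) == 0   # the zeros (x-intercepts)
assert f(0.5) > 0                # positive to the left of both zeros
assert f(1.5) < 0                # negative between the zeros
assert f(3) > 0                  # positive to the right of both zeros
print(f(1.5))                    # -0.25, as claimed above
```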

  • The multiplicity of a zero is how many times the corresponding factor appears.

  • At a simple zero (multiplicity 1), the graph crosses the x-axis. At a double zero (multiplicity 2), the graph touches the x-axis but bounces back without crossing, appearing "flat" at that point.
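The crossing-versus-touching behaviour is easy to see numerically by sampling either side of each zero (a small illustrative sketch):

```python
simple = lambda x: x - 1         # simple zero at x = 1: sign flips across it
double = lambda x: (x - 2)**2    # double zero at x = 2: sign does not flip

assert simple(0.9) < 0 < simple(1.1)         # graph crosses the axis
assert double(1.9) > 0 and double(2.1) > 0   # graph touches and bounces back
```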

  • Finding zeros matters because the zeros of the derivative $f'(x)$ are the critical points of $f(x)$, the candidates for maxima and minima.

  • At a maximum or minimum, the tangent line is flat (slope = 0), so $f'(x) = 0$.

Critical points: where the derivative equals zero, the function has a peak, a valley, or a flat inflection point

  • But not every critical point is a maximum or minimum. A point where $f'(x) = 0$ could also be an inflection point (like $x = 0$ for $f(x) = x^3$), where the function flattens momentarily but does not change direction.

  • The second derivative test resolves this. At a critical point $x = c$ where $f'(c) = 0$:

    • If $f''(c) > 0$: the curve is concave up (like a bowl), so $c$ is a local minimum.
    • If $f''(c) < 0$: the curve is concave down (like a hill), so $c$ is a local maximum.
    • If $f''(c) = 0$: the test is inconclusive; higher derivatives or other methods are needed.
  • For example, $f(x) = x^3 - 3x$. The derivative is $f'(x) = 3x^2 - 3 = 3(x-1)(x+1)$, so critical points are at $x = -1$ and $x = 1$. The second derivative is $f''(x) = 6x$. At $x = -1$: $f''(-1) = -6 < 0$ (local max). At $x = 1$: $f''(1) = 6 > 0$ (local min).
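This worked example can be verified with automatic differentiation; the sketch below uses `jax.grad` twice to get $f'$ and $f''$:

```python
import jax

f = lambda x: x**3 - 3*x
df = jax.grad(f)        # f'(x) = 3x^2 - 3
d2f = jax.grad(df)      # f''(x) = 6x

for c in (-1.0, 1.0):
    slope = float(df(c))
    curv = float(d2f(c))
    kind = "local max" if curv < 0 else "local min"
    print(f"x = {c:+.0f}: f'(c) = {slope:.1f}, f''(c) = {curv:.1f} -> {kind}")
```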

  • A function is convex if the line segment between any two points on its graph lies above (or on) the graph. Think of it as a bowl shape, curving upward everywhere. For a twice-differentiable $f$, this is equivalent to $f''(x) \geq 0$ for all $x$.

Convex functions: every local minimum is global (strictly convex functions have exactly one minimiser); non-convex functions can have many local minima

  • Convexity is powerful because convex functions have a remarkable property: every local minimum is also the global minimum. There are no deceptive local valleys to get trapped in. If you roll a ball into a convex bowl, it will always reach the bottom.

  • A function is concave (curving downward) if $-f$ is convex. Points where the function transitions between concave and convex are inflection points; there $f''(x) = 0$, although $f''(x) = 0$ alone is not sufficient (e.g. $f(x) = x^4$ has $f''(0) = 0$ but no inflection at $0$).
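For the running example $f(x) = x^3 - 3x$, the inflection point at $x = 0$ can be confirmed numerically, since $f''(x) = 6x$ changes sign there (sketch assumes `jax` is available):

```python
import jax

f = lambda x: x**3 - 3*x
d2f = jax.grad(jax.grad(f))   # f''(x) = 6x

# f'' is negative (concave) left of 0, zero at 0, positive (convex) right of 0
assert float(d2f(-0.5)) < 0
assert float(d2f(0.0)) == 0.0
assert float(d2f(0.5)) > 0
```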

  • Newton's method finds zeros of functions (and by extension, critical points of their derivatives) using tangent lines. Starting from an initial guess $x_0$, it iteratively refines:

$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$

Newton's method: follow the tangent line to find a better approximation of the root

  • The idea: at $x_n$, draw the tangent line and find where it crosses the x-axis. That crossing point becomes $x_{n+1}$. For well-behaved functions with a good starting point, Newton's method converges very quickly (quadratically, meaning the number of correct digits roughly doubles each step).

  • For example, to find $\sqrt{5}$ (a zero of $f(x) = x^2 - 5$): $f'(x) = 2x$, so $x_{n+1} = x_n - \frac{x_n^2 - 5}{2x_n}$. Starting at $x_0 = 2$: $x_1 = 2.25$, $x_2 = 2.2361\ldots$, which is already accurate to four decimal places.
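The "number of correct digits roughly doubles" claim can be seen by printing the error at each step of this same $\sqrt{5}$ iteration (plain-Python sketch):

```python
import math

f = lambda x: x**2 - 5
df = lambda x: 2*x

x, root = 2.0, math.sqrt(5)
for n in range(4):
    x = x - f(x) / df(x)                      # Newton update
    print(f"x_{n+1} = {x:.12f}, error = {abs(x - root):.2e}")
# The error is roughly squared at each step (quadratic convergence)
```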

  • Newton's method can fail if the initial guess is far from the root, if $f'(x) = 0$ near the root, or if the function has inflection points nearby. It also requires computing the derivative, which may be expensive.

  • For optimisation (finding minima instead of zeros), we apply Newton's method to $f'(x) = 0$, which gives the update:

$$x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)}$$

  • In multiple dimensions, this becomes $\mathbf{x}_{n+1} = \mathbf{x}_n - H^{-1} \nabla f(\mathbf{x}_n)$, where $H$ is the Hessian matrix. This is the second-order Taylor approximation from the previous file in action: approximate the function as a quadratic, jump to the minimum of that quadratic, repeat.
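A minimal sketch of the multidimensional update, using `jax.hessian` (the particular quadratic, including its $0.5\,x y$ coupling term, is an arbitrary choice for illustration). Because the objective is exactly quadratic, the second-order Taylor approximation is exact and one Newton step lands on the minimum:

```python
import jax
import jax.numpy as jnp

# A positive-definite quadratic: Newton's method solves it in one step
def f(x):
    return (x[0] - 3.0)**2 + 2.0*(x[1] + 1.0)**2 + 0.5*x[0]*x[1]

grad_f = jax.grad(f)
hess_f = jax.hessian(f)

x = jnp.array([0.0, 0.0])
# Solve H d = grad instead of forming H^{-1} explicitly
x = x - jnp.linalg.solve(hess_f(x), grad_f(x))
print(x, float(jnp.linalg.norm(grad_f(x))))   # gradient is numerically zero
```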

  • Lagrange multipliers solve constrained optimisation: find the optimum of $f(x, y)$ subject to a constraint $g(x, y) = c$. Instead of searching all of $\mathbb{R}^n$, we are restricted to the set where the constraint holds (a curve or surface).

  • The key insight is geometric: at the constrained optimum, the gradient of $f$ must be parallel to the gradient of $g$. If they were not parallel, we could move along the constraint in a direction that still improves $f$, so we would not be at the optimum yet.

  • We introduce a new variable $\lambda$ (the Lagrange multiplier) and define the Lagrangian:

$$\mathcal{L}(x, y, \lambda) = f(x, y) - \lambda(g(x, y) - c)$$

  • Setting all partial derivatives to zero gives a system of equations whose solutions are the constrained optima:

$$\frac{\partial \mathcal{L}}{\partial x} = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = 0$$

Lagrange multipliers: at the optimum, gradients of f and g are parallel

  • For example, maximise $f(x,y) = x^2 y$ subject to $x^2 + y^2 = 1$. The Lagrangian is $\mathcal{L} = x^2 y - \lambda(x^2 + y^2 - 1)$. Taking partials:

$$2xy - 2\lambda x = 0, \quad x^2 - 2\lambda y = 0, \quad x^2 + y^2 = 1$$

  • From the first equation (assuming $x \neq 0$): $\lambda = y$. Substituting into the second: $x^2 = 2y^2$. Combined with the constraint: $2y^2 + y^2 = 1$, so $y = \frac{1}{\sqrt{3}}$. The maximum value is $f = \frac{2}{3\sqrt{3}}$.

  • For inequality constraints ($g(x,y) \leq c$ instead of $= c$), the Karush-Kuhn-Tucker (KKT) conditions generalise Lagrange multipliers. The constraint is either active (binding, treated as equality) or inactive (the solution lies in the interior and the constraint is irrelevant).

  • In practice, we rarely optimise by hand. Here are the main algorithmic families:

    • First-order methods (use only gradient): gradient descent, stochastic gradient descent (SGD), Adam. These are cheap per step but can converge slowly, especially on ill-conditioned problems.

    • Second-order methods (use gradient and Hessian): Newton's method converges fast but computing and inverting the Hessian is expensive ($O(n^3)$ for $n$ parameters). Quasi-Newton methods (like BFGS and L-BFGS) approximate the Hessian using only gradient information, achieving faster convergence than first-order methods without the full cost of second-order methods.

    • Conjugate gradient: efficient for large sparse systems, using only matrix-vector products instead of storing the full Hessian.

    • Gauss-Newton and Levenberg-Marquardt: specialised for least-squares problems (common in regression), approximating the Hessian via the Jacobian.

    • Natural gradient descent: accounts for the geometry of the parameter space using the Fisher information matrix, which can be more effective for probabilistic models.

  • The choice of optimiser depends on the problem. For deep learning, first-order methods (especially Adam) dominate because the number of parameters is enormous (millions to billions), making Hessian computation impractical. For smaller problems with smooth objectives, second-order methods can be dramatically faster.
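The ill-conditioning point can be made concrete on a toy quadratic with curvature 1 in one direction and 100 in the other (an illustrative sketch; the learning rate 0.01 is chosen to sit just inside gradient descent's stability limit of $2/100$):

```python
import jax
import jax.numpy as jnp

# Ill-conditioned quadratic: curvature 1 along x, 100 along y
f = lambda p: 0.5 * (p[0]**2 + 100.0 * p[1]**2)
grad_f = jax.grad(f)

# Gradient descent: the steep direction forces a tiny step size,
# so progress along the shallow x-direction is slow
p_gd = jnp.array([1.0, 1.0])
for _ in range(100):
    p_gd = p_gd - 0.01 * grad_f(p_gd)
print("GD after 100 steps: ", p_gd)   # x-coordinate still far from 0

# Newton's method: one step reaches the minimum regardless of conditioning
p0 = jnp.array([1.0, 1.0])
p_newton = p0 - jnp.linalg.solve(jax.hessian(f)(p0), grad_f(p0))
print("Newton after 1 step:", p_newton)
```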

Coding Tasks (use Colab or a notebook)

  1. Implement Newton's method to find $\sqrt{7}$ (a zero of $f(x) = x^2 - 7$). Observe the rapid convergence.
import jax.numpy as jnp

f = lambda x: x**2 - 7
df = lambda x: 2*x

x = 3.0  # initial guess
for i in range(6):
    x = x - f(x) / df(x)
    print(f"step {i+1}: x = {x:.10f}  (error: {abs(x - jnp.sqrt(7.0)):.2e})")

  2. Use gradient descent to minimise $f(x, y) = (x - 3)^2 + (y + 1)^2$. The minimum is at $(3, -1)$. Experiment with different learning rates.
import jax
import jax.numpy as jnp

def f(params):
    x, y = params
    return (x - 3)**2 + (y + 1)**2

grad_f = jax.grad(f)
params = jnp.array([0.0, 0.0])
lr = 0.1

for i in range(20):
    g = grad_f(params)
    params = params - lr * g
    if i % 5 == 0 or i == 19:
        print(f"step {i:2d}: ({params[0]:.4f}, {params[1]:.4f})  loss={f(params):.6f}")

  3. Solve a constrained optimisation problem numerically. Maximise $f(x,y) = xy$ subject to $x + y = 10$ by parameterising $y = 10 - x$ and finding the optimum of the single-variable function.
import jax
import jax.numpy as jnp

# Substitute constraint: y = 10 - x, so f = x(10 - x) = 10x - x²
f = lambda x: x * (10 - x)
df = jax.grad(f)

# Gradient ascent (we want maximum, so add gradient)
x = 1.0
lr = 0.1
for i in range(100):
    x = x + lr * df(x)
print(f"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}")  # converges to x=5, y=5, f=25