Overview
Most textbooks bury good ideas under dense notation, skip the intuition, assume you already know half the material, and quickly go out of date in fast-moving fields like AI. This is an open, unconventional textbook covering maths, computing, and artificial intelligence from the ground up. It's written for curious practitioners who want to deeply understand the stuff, not just survive an exam/interview.
Background
Over the past years working in AI/ML, I filled notebooks with intuition-first explanations of maths, computing and AI concepts, with real-world context and no hand-waving. In 2025, a few friends used these notes to prep for interviews at DeepMind, OpenAI, Nvidia etc. They all got in and currently perform well in their roles. So I'm sharing them with everyone.
Outline
| # | Chapter | Summary | Status |
|---|---|---|---|
| 01 | Vectors | Spaces, magnitude, direction, norms, metrics, dot/cross/outer products, basis, duality | Available |
| 02 | Matrices | Properties, special types, operations, linear transformations, decompositions (LU, QR, SVD) | Available |
| 03 | Calculus | Derivatives, integrals, multivariate calculus, Taylor approximation, optimisation and gradient descent | Available |
| 04 | Statistics | Descriptive measures, sampling, central limit theorem, hypothesis testing, confidence intervals | Available |
| 05 | Probability | Counting, conditional probability, distributions, Bayesian methods, information theory | Available |
| 06 | Machine Learning | Classical ML, gradient methods, deep learning, reinforcement learning, distributed training | Available |
| 07 | Computational Linguistics | syntax, semantics, pragmatics, NLP, language models, RNNs, CNNs, attention, transformers, text diffusion, text OCR, MoE, SSMs, modern LLM architectures, NLP evaluation | Available |
| 08 | Computer Vision | image processing, object detection, segmentation, video processing, SLAM, CNNs, vision transformers, diffusion, flow matching, VR/AR | Coming |
| 09 | Audio & Speech | DSP, ASR, TTS, voice & acoustic activity detection, diarization, source separation, active noise cancellation, WaveNet, Conformer | Coming |
| 10 | Multimodal Learning | fusion strategies, contrastive learning, VLMs, image tokenizers, video-audio co-generation | Coming |
| 11 | Autonomous Systems | perception, robot learning, VLAs, self-driving cars, space robots | Coming |
| 12 | Computing & OS | discrete maths, computer architecture, operating systems, RAM, concurrency, parallelism, programming languages | Coming |
| 13 | Data Structures & Algorithms | arrays, trees, graphs, search, sorting, hashmaps | Coming |
| 14 | SIMD & GPU Programming | ARM & NEON, x86 chips, RISC chips, GPUs, TPUs, Triton, CUDA, Vulkan | Coming |
| 15 | Systems Design | systems design fundamentals, cloud computing, large-scale infra, ML systems design examples | Coming |
| 16 | Inference | quantisation, streaming LLMs, continuous batching, edge inference | Coming |
| 17 | Intersecting Fields | quantum ML, neuromorphic ML, AI for finance, AI for bio | Coming |
| 18 | Research Blog | small-scale experimental findings, one per md file with your name & affiliation (I have a lot) | Coming |
Citation
```bibtex
@book{ndubuaku2025compendium,
  title = {Maths, CS & AI Compendium},
  author = {Henry Ndubuaku},
  year = {2026},
  publisher = {GitHub},
  url = {https://github.com/HenryNdubuaku/maths-cs-ai-compendium}
}
```
Contributions
- Star & watch to get content as it drops.
- Suggest topics via GitHub issues.
- PR corrections and better intuition.
- Create SVG images in `../images/` for all diagrams.
- For display math, use `` ```math `` fenced code blocks (NOT `$$`): GitHub escapes `\\` inside `$$`, breaking matrices.
- Inline math `$...$` is fine for simple expressions, but move anything with `\\` into a `` ```math `` block.
- Use `\ast` instead of `*` for conjugate/adjoint in inline math.
Vector Spaces
- Think of a vector space as a specific kind of playground where mathematical objects live; each object is called a vector.
- For geometric intuition in machine learning (ML), we will always think of a vector as a point in Euclidean space, represented by its coordinates.
- A vector $\mathbf{a}$ (written as a bold lowercase letter) has $n$ coordinates, each representing a position along one axis. For example, with $n = 3$:
$$\mathbf{a} = [a_1, a_2, a_3]$$
- The vectors in a vector space follow a very specific, unbreakable set of rules:
  - Vector Addition (Combining): You can take any two vectors and combine them to create a new one. Think of vectors as instructions for movement. If vector A means "walk 3 steps forward" and vector B means "walk 2 steps right," adding them (A + B) creates a new, single instruction: "walk 3 steps forward and 2 steps right."
  - Scalar Multiplication (Scaling): You can take any vector and scale it using a regular number (a "scalar"). You can stretch it, shrink it, or reverse it. If vector A is "walk 3 steps forward," scaling it by 2 makes it "walk 6 steps forward." Scaling it by -1 flips it entirely to "walk 3 steps backward."
- The dimension of a vector space is the number of independent directions it contains. $\mathbb{R}^2$ is 2-dimensional (needs 2 coordinates), while $\mathbf{a}$ above lives in $\mathbb{R}^3$.
- For instance, we can represent a human as a vector $\mathbf{h}$, where $h_1$ = height in cm, $h_2$ = weight in kg, $h_3$ = age in years.
$$\mathbf{h} = [185, 75, 30]$$
- We now have a vector, living in a 3-dimensional space, that represents a human.
- We can represent multiple humans in the same space and see how close or far apart they are!
- We can add more features to create a richer representation of a human. In ML, such vectors are called feature vectors.
- The more distinct and meaningful features you include, the more descriptive the feature vector becomes. This is an important point to remember.
- Beyond 3 dimensions, vectors become impossible to inspect visually, so we need algebraic tools to reason about them. Linear algebra provides those tools.
- Linear algebra is the study of vectors, vector spaces and mappings between vectors.
- We represent almost everything in AI/ML as vectors, making linear algebra the bedrock of the field.
- Visually, vector addition places the tail of one vector at the tip of the other. The sum is the arrow drawn from the origin to the new endpoint.
- For two vectors $\mathbf{a} = (a_1, a_2)$ and $\mathbf{b} = (b_1, b_2)$: $\mathbf{a} + \mathbf{b} = (a_1 + b_1, a_2 + b_2)$
- Vectors can also be subtracted. Subtraction is addition of the negated vector: $\mathbf{a} - \mathbf{b} = \mathbf{a} + (-\mathbf{b}) = (a_1 - b_1, a_2 - b_2)$.
- Multiplying a vector by a scalar $c$ scales its length by $|c|$. The direction is kept when $c > 0$ and reversed when $c < 0$.
- For a scalar $c$ and vector $\mathbf{v} = (v_1, v_2)$: $c\mathbf{v} = (cv_1, cv_2)$
- Closure under Addition: If you add any two vectors from the vector space, the result is also a vector within the same space: If $\mathbf{u} \in V$ and $\mathbf{v} \in V$, then $\mathbf{u} + \mathbf{v} \in V$
- Closure under Scalar Multiplication: If you multiply any vector from the vector space by a scalar, the result is a vector within the same space: If $\mathbf{v} \in V$ and $c \in F$ (the field of scalars, usually $\mathbb{R}$), then $c\mathbf{v} \in V$
- Associativity of Addition: For any three vectors $\mathbf{u}$, $\mathbf{v}$, and $\mathbf{w}$: $(\mathbf{u} + \mathbf{v}) + \mathbf{w} = \mathbf{u} + (\mathbf{v} + \mathbf{w})$
- Commutativity of Addition: For any two vectors $\mathbf{u}$ and $\mathbf{v}$: $\mathbf{u} + \mathbf{v} = \mathbf{v} + \mathbf{u}$
- Figure: both paths through the parallelogram ($\mathbf{u} + \mathbf{v}$ and $\mathbf{v} + \mathbf{u}$) arrive at the same point.
- Additive Identity (Zero Vector): There exists a vector $\mathbf{0}$ such that for any vector $\mathbf{v}$: $\mathbf{v} + \mathbf{0} = \mathbf{v}$
- Additive Inverse: For every vector $\mathbf{v}$, there exists a vector $-\mathbf{v}$ such that: $\mathbf{v} + (-\mathbf{v}) = \mathbf{0}$
- Distributivity 1: For any scalar $c$ and vectors $\mathbf{u}$, $\mathbf{v}$: $c(\mathbf{u} + \mathbf{v}) = c\mathbf{u} + c\mathbf{v}$
- Figure: scaling the sum (gold) gives the same result as summing the scaled vectors.
- Distributivity 2: For any scalars $c$, $d$ and vector $\mathbf{v}$: $(c + d)\mathbf{v} = c\mathbf{v} + d\mathbf{v}$
- Associativity of Scalar Multiplication: For any scalars $c$, $d$ and vector $\mathbf{v}$: $(cd)\mathbf{v} = c(d\mathbf{v})$
- Identity Element: For any vector $\mathbf{v}$: $1\mathbf{v} = \mathbf{v}$, where $1$ is the multiplicative identity in the field of scalars.
- A subspace is just a smaller playground inside the bigger one. Imagine 3D space as a room. A flat sheet of paper passing through the centre of the room is a subspace, and so is a single straight wire through the centre.
- A key requirement is that the subspace must pass through the origin. If you shift that sheet of paper off-centre, it stops being a subspace because the zero vector is no longer on it.
- All the same rules from the vector space (addition, scaling, closure) still work inside a subspace. You can add or scale vectors within it and never "fall off" into the larger space.
- A line through the origin is a 1-dimensional subspace, a plane through the origin is a 2-dimensional subspace, and the full space is a subspace of itself.
- In ML, subspaces appear naturally. High-dimensional data often lies close to a lower-dimensional subspace. Techniques like PCA find that subspace so we can work with the data more efficiently.
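To make the "must pass through the origin" requirement concrete, here is a minimal sketch in $\mathbb{R}^2$. It compares the line through the origin $y = 2x$ with the shifted line $y = 2x + 1$, and checks whether sums, scalar multiples and the zero vector stay on each line. The two lines, the sample points and the helper names `on_line_through_origin` and `on_shifted_line` are illustrative choices made for this sketch.

```python
import jax.numpy as jnp

def on_line_through_origin(p):
    # Points on y = 2x satisfy y - 2x = 0
    return bool(jnp.isclose(p[1] - 2.0 * p[0], 0.0))

def on_shifted_line(p):
    # Points on y = 2x + 1 satisfy y - 2x - 1 = 0
    return bool(jnp.isclose(p[1] - 2.0 * p[0] - 1.0, 0.0))

# Two points on each line
u, v = jnp.array([1.0, 2.0]), jnp.array([-3.0, -6.0])
p, q = jnp.array([1.0, 3.0]), jnp.array([2.0, 5.0])

# Line through the origin: closed under addition and scaling, contains 0
print("origin line, u + v:", on_line_through_origin(u + v))
print("origin line, 5u:   ", on_line_through_origin(5.0 * u))
print("origin line, zero: ", on_line_through_origin(jnp.zeros(2)))

# Shifted line: sums and scalings fall off it, and it misses 0
print("shifted line, p + q:", on_shifted_line(p + q))
print("shifted line, 5p:   ", on_shifted_line(5.0 * p))
print("shifted line, zero: ", on_shifted_line(jnp.zeros(2)))
```

Every check on the first line prints `True` and every check on the shifted line prints `False`. This matches the sheet-of-paper picture: shifting it off the origin breaks closure.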
Coding Tasks (use Colab or notebook)
- Run the code to verify the distributivity property, then modify it and play around to test the other rules!
```python
import jax.numpy as jnp

u = jnp.array([1, 2])
v = jnp.array([3, 0])
c = 2

# Distributivity: c(u + v) should equal cu + cv
lhs = c * (u + v)
rhs = c * u + c * v
print(f"LHS: {lhs}")
print(f"RHS: {rhs}")
```
- Run code to visualise different vectors, then modify values for different coordinates to understand how each axis affects position.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Try changing these vectors!
a = jnp.array([3, 2, 4])
b = jnp.array([1, 4, 2])
c = jnp.array([4, 1, 3])

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
# Draw each vector as an arrow starting at the origin
for vec, name, color in [(a, "a", "red"), (b, "b", "blue"), (c, "c", "green")]:
    ax.quiver(0, 0, 0, *vec.tolist(), color=color, arrow_length_ratio=0.1, linewidth=2, label=name)

# Axis limit just beyond the largest component
# (use [-lim, lim] instead if you try negative components)
lim = int(jnp.abs(jnp.stack([a, b, c])).max()) + 1
ax.set_xlim([0, lim]); ax.set_ylim([0, lim]); ax.set_zlim([0, lim])
ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
ax.legend()
plt.show()
```
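The first task checks only distributivity. As a sketch of how the remaining rules could be tested in one go, the block below draws random vectors and scalars with `jax.random` and compares both sides of each rule with `jnp.allclose`. It uses a small `atol` because floating-point results can differ in the last bits, so exact `==` is avoided. The dimension, the seed and the `axioms` dictionary are arbitrary illustrative choices. Passing on random samples is evidence for the rules, not a proof.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

n = 4  # any dimension works
u = jax.random.normal(k1, (n,))
v = jax.random.normal(k2, (n,))
w = jax.random.normal(k3, (n,))
c, d = jax.random.normal(k4, (2,))  # two random scalars
zero = jnp.zeros(n)

# Each entry pairs the left- and right-hand sides of one vector-space rule
axioms = {
    "commutativity": (u + v, v + u),
    "associativity (+)": ((u + v) + w, u + (v + w)),
    "zero vector": (v + zero, v),
    "additive inverse": (v + (-v), zero),
    "distributivity 1": (c * (u + v), c * u + c * v),
    "distributivity 2": ((c + d) * v, c * v + d * v),
    "associativity (scale)": ((c * d) * v, c * (d * v)),
    "identity": (1.0 * v, v),
}

for name, (lhs, rhs) in axioms.items():
    print(f"{name:22s} holds: {bool(jnp.allclose(lhs, rhs, atol=1e-6))}")
```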
Vector Properties
- The magnitude (or length) of a vector tells you how far it reaches. Think of it as the length of the arrow. For a vector $\mathbf{a} = (a_1, a_2, a_3)$, its magnitude is:
$$|\mathbf{a}| = \sqrt{a_1^2 + a_2^2 + a_3^2}$$
- This is just the Pythagorean theorem extended to higher dimensions: it measures the straight-line distance from the origin to the point.
- The direction of a vector tells you where it points; simply visualise a straight line from the origin to the point given by its coordinates.
- When the origin is not explicitly specified, we assume it is $(0, 0, \ldots, 0)$, at least for visualisation purposes.
- Position doesn't matter; it's always about displacement: a vector $(3, 2)$ drawn from the origin and the same $(3, 2)$ drawn from another point are still equal.
- Two vectors can have the same length but point in completely different directions, or point the same way but differ in length.
- Two vectors are equal if and only if all their corresponding components match; same length, same direction, the exact same arrow.
$$\mathbf{a} = \mathbf{b} \iff a_i = b_i \text{ for all } i$$
- Two vectors are parallel if one is a scalar multiple of the other. They point along the same line, either in the same direction or exactly opposite.
$$\mathbf{a} \parallel \mathbf{b} \iff \mathbf{a} = k\mathbf{b} \text{ for some scalar } k \neq 0$$
- If $k > 0$, they point the same way. If $k < 0$, they point in opposite directions. Either way, they lie on the same line through the origin.
- Intuitively, parallel vectors carry no "new" directional information. One is just a stretched or flipped version of the other.
- Two vectors are orthogonal (perpendicular) if they meet at a right angle ($90°$). Moving along one gives you zero progress along the other.
- Think of walking north versus walking east: these are orthogonal directions, and no amount of walking north will ever move you east. We will encounter orthogonality very often.
- Orthogonality is central to ML: orthogonal directions carry non-overlapping information, which is ideal for representation.
- More generally, any two nonzero vectors have an angle $\theta$ between them, ranging from $0°$ to $180°$.
- This angle captures the entire relationship between two directions: $0°$ means parallel (same direction), $180°$ means parallel (opposite direction), and $90°$ means orthogonal. Everything in between is a blend.
- Most vector relationships in ML live somewhere in this spectrum. Later, we will see exact tools (dot product, cosine similarity) to compute this angle.
- A set of vectors is linearly dependent if at least one of them can be built from the others by scaling and adding. That vector brings no new information to the set.
- For example, if $\mathbf{c} = 2\mathbf{a} + 3\mathbf{b}$, then $\mathbf{c}$ is redundant: you already have everything $\mathbf{c}$ offers through $\mathbf{a}$ and $\mathbf{b}$.
- Parallel vectors are always linearly dependent, since one is just a scaled copy of the other. Any set containing the zero vector is also linearly dependent.
- Vectors are linearly independent if none of them can be built from the others. Each one contributes a genuinely new direction. Nonzero orthogonal vectors are always linearly independent.
- In 2D, two linearly independent vectors can reach any point in the plane. In 3D, you need three. This idea of "how many independent vectors you need" connects directly to dimension.
- A vector is sparse when most of its components are zero. The opposite, most components being nonzero, is called dense.
$$\mathbf{s} = [0, 0, 3, 0, 0, 0, 1, 0, 0, 0]$$
- Sparsity matters because it affects both storage and computation. Sparse vectors can be stored and processed much more efficiently by only tracking the nonzero entries.
- A unit vector is a vector with magnitude exactly 1. It purely represents a direction with no length information. You can turn any vector into a unit vector by dividing by its magnitude:
$$\hat{\mathbf{a}} = \frac{\mathbf{a}}{|\mathbf{a}|}$$
- This process is called normalisation. It strips away "how far" and keeps only "which way."
- The standard unit vectors point along each axis: $\hat{\mathbf{i}} = (1, 0, 0)$, $\hat{\mathbf{j}} = (0, 1, 0)$, $\hat{\mathbf{k}} = (0, 0, 1)$. Any vector can be written as a combination of these, e.g. $(3, 2, 4) = 3\hat{\mathbf{i}} + 2\hat{\mathbf{j}} + 4\hat{\mathbf{k}}$.
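Here is a minimal sketch of why sparsity saves storage. It keeps only the positions (indices) and values of the nonzero entries of $\mathbf{s}$, then rebuilds the dense vector from them. Production sparse formats build on this same index/value idea with extra bookkeeping. The names `idx`, `vals` and `dense` are illustrative.

```python
import jax.numpy as jnp

s = jnp.array([0, 0, 3, 0, 0, 0, 1, 0, 0, 0])

# Store only the positions and values of the nonzero entries
idx = jnp.nonzero(s)[0]
vals = s[idx]
print(f"indices: {idx}, values: {vals}")  # 2 numbers each instead of 10

# Rebuild the dense vector by scattering the values back into zeros
dense = jnp.zeros_like(s).at[idx].set(vals)
print(f"round trip ok: {bool(jnp.array_equal(dense, s))}")

# Sparsity = fraction of components that are zero
print(f"sparsity: {1 - idx.size / s.size:.0%}")
```

For this $\mathbf{s}$, only indices 2 and 6 are stored, and the sparsity is 80%.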
Coding Tasks (use Colab or notebook)
- Compute the magnitude of a vector and verify it matches the Pythagorean theorem, then modify to compute the unit vector.
```python
import jax.numpy as jnp

a = jnp.array([3.0, 4.0])
# Pythagorean theorem: sqrt(3^2 + 4^2) = 5
magnitude = jnp.sqrt(jnp.sum(a ** 2))
print(f"Magnitude of a: {magnitude}")
```
- Check whether two vectors are parallel by testing if one is a scalar multiple of the other.
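```python
import jax.numpy as jnp

a = jnp.array([2, 4, 6])
b = jnp.array([1, 2, 3])
# If a = k * b, every component ratio equals the same k
# (this divides by b, so it breaks if b has zero components)
ratios = a / b
print(f"Ratios: {ratios}")
print(f"Parallel: {jnp.allclose(ratios, ratios[0])}")
```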
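The ratio test above divides by $\mathbf{b}$, so it fails when $\mathbf{b}$ has a zero component. For example, with $\mathbf{a} = (2, 0)$ and $\mathbf{b} = (1, 0)$ the second ratio is $0/0$ (NaN), even though the vectors are parallel. One sturdier sketch avoids division. It stacks the two vectors as rows and uses `jnp.linalg.matrix_rank` to ask whether they span only a single line (rank 1). Rank is covered properly in the matrices section. The helper name `is_parallel` is hypothetical.

```python
import jax.numpy as jnp

def is_parallel(a, b):
    # Parallel means a = k b with k != 0, so both vectors must be nonzero
    # and, stacked as rows, span a single line (rank 1)
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    if not (bool(jnp.any(a != 0)) and bool(jnp.any(b != 0))):
        return False
    return int(jnp.linalg.matrix_rank(jnp.stack([a, b]))) == 1

print(is_parallel([2, 4, 6], [1, 2, 3]))  # same direction
print(is_parallel([2, 0], [1, 0]))        # zero component, no division needed
print(is_parallel([1, 2], [-3, -6]))      # opposite direction
print(is_parallel([1, 2], [2, 1]))        # not parallel
```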
Metrics and Norms
- We know vectors have magnitude and direction. But how do we actually measure "how big" a single vector is, or "how far apart" two vectors are? This is where norms and metrics come in.
- For scalars, we know that 10 > 5 because each value directly tells us its size. But how do we quantify the size of a vector? With its norm, which measures the size of a single vector.
- The most familiar norm is the Euclidean norm (L2), which is just the magnitude formula we already know:
$$|\mathbf{v}|_2 = \sqrt{v_1^2 + v_2^2 + \cdots + v_n^2}$$
- But there are other ways to measure size. Imagine you are in a city with a grid of streets. You cannot walk diagonally through buildings, so the "length" of your journey is the total blocks walked along each street. This is the Manhattan norm (L1):
$$|\mathbf{v}|_1 = |v_1| + |v_2| + \cdots + |v_n|$$
- Or you might only care about the single largest component, ignoring the rest. This is the Max norm (L-infinity):
$$|\mathbf{v}|_\infty = \max(|v_1|, |v_2|, \ldots, |v_n|)$$
- All three are special cases of the general Lp norm:
$$|\mathbf{v}|_p = (|v_1|^p + |v_2|^p + \cdots + |v_n|^p)^{1/p}$$
- Setting $p = 2$ gives Euclidean, $p = 1$ gives Manhattan, and as $p \to \infty$ you get the Max norm. As $p$ grows, the largest component contributes more and more, until eventually only it matters.
- Every norm must obey three rules:
  - Non-negativity: $|\mathbf{v}| \geq 0$, and $|\mathbf{v}| = 0$ only if $\mathbf{v} = \mathbf{0}$. Size is never negative, and only the zero vector has zero size.
  - Scaling: $|c\mathbf{v}| = |c| \cdot |\mathbf{v}|$. Doubling a vector doubles its size.
  - Triangle inequality: $|\mathbf{u} + \mathbf{v}| \leq |\mathbf{u}| + |\mathbf{v}|$. The shortcut is never longer than going the long way round.
- Now, a metric measures the distance between two vectors. Think of it as asking: "how far apart are these two points?"
- The simplest way to get a metric is to use a norm on the difference: $d(\mathbf{u}, \mathbf{v}) = |\mathbf{u} - \mathbf{v}|$. Subtract the two vectors, then measure the size of what remains.
- Using the Euclidean norm, this gives us the familiar Euclidean distance:
$$d(\mathbf{u}, \mathbf{v}) = \sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2 + \cdots + (u_n - v_n)^2}$$
- Using the Manhattan norm gives Manhattan distance, the total difference along each axis, like counting city blocks between two locations.
- Every metric must obey four rules:
  - Non-negativity: $d(\mathbf{u}, \mathbf{v}) \geq 0$. Distance is never negative.
  - Identity: $d(\mathbf{u}, \mathbf{v}) = 0$ if and only if $\mathbf{u} = \mathbf{v}$. Zero distance means the same point.
  - Symmetry: $d(\mathbf{u}, \mathbf{v}) = d(\mathbf{v}, \mathbf{u})$. The distance from A to B is the same as from B to A.
  - Triangle inequality: $d(\mathbf{u}, \mathbf{w}) \leq d(\mathbf{u}, \mathbf{v}) + d(\mathbf{v}, \mathbf{w})$. Going directly is never longer than taking a detour.
- So what is the relationship between the two? A norm measures one vector, while a metric measures the gap between two. Every norm naturally creates a metric (by measuring the difference), but not every metric comes from a norm.
- For example, Hamming distance counts the number of positions where two vectors differ. It is a valid metric, but it does not come from any norm. Scaling both vectors by 2 leaves their Hamming distance unchanged, whereas a distance induced by a norm would double.
- In ML, choosing the right norm or metric matters, because it decides how much large differences count relative to small ones.
- L2 distance squares each difference before summing, so a single large difference dominates the result.
- L1 distance sums the absolute differences, treating each one equally. A single large difference has less influence than it does in L2.
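Here is a small sketch of the Hamming example. It counts the positions where two vectors differ and shows why no norm can produce this distance. Scaling both vectors by 2 leaves the Hamming distance unchanged, while the Euclidean distance doubles, as the scaling rule for norms requires. The helpers `hamming` and `euclidean` and the sample vectors are illustrative.

```python
import jax.numpy as jnp

def hamming(u, v):
    # Number of positions where the two vectors differ
    return int(jnp.sum(u != v))

def euclidean(u, v):
    # Distance induced by the L2 norm: |u - v|
    return float(jnp.linalg.norm(u - v))

u = jnp.array([1.0, 0.0, 2.0, 5.0])
v = jnp.array([1.0, 3.0, 2.0, 1.0])

print(f"Hamming:   {hamming(u, v)} -> {hamming(2 * u, 2 * v)} after scaling by 2")
print(f"Euclidean: {euclidean(u, v):.1f} -> {euclidean(2 * u, 2 * v):.1f} after scaling by 2")
```

This prints `Hamming: 2 -> 2` and `Euclidean: 5.0 -> 10.0`.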
Coding Tasks (use Colab or notebook)
- Compute the L1 and L2 norms of the same vector. Try changing the values and notice which norm is more sensitive to large components vs many small ones. Then try computing the Lp norm for increasing values of p (e.g. 1, 2, 5, 10, 50, 100) and watch it converge towards the L-infinity value.
```python
import jax.numpy as jnp

v = jnp.array([3.0, -4.0, 1.0])
l1 = jnp.sum(jnp.abs(v))        # Manhattan: |3| + |-4| + |1|
l2 = jnp.sqrt(jnp.sum(v ** 2))  # Euclidean: sqrt(9 + 16 + 1)
print(f"L1: {l1}, L2: {l2:.2f}")
```
- Compute the Euclidean and Manhattan distance between two vectors. Try moving the vectors closer or further apart and observe how each distance responds differently.
```python
import jax.numpy as jnp

u = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([4.0, 0.0, 1.0])
# Norm of the difference gives the distance
euclidean = jnp.sqrt(jnp.sum((u - v) ** 2))
manhattan = jnp.sum(jnp.abs(u - v))
print(f"Euclidean: {euclidean:.2f}, Manhattan: {manhattan}")
```
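To cross-check the hand-written formulas, note that `jnp.linalg.norm` accepts an `ord` argument for vectors (`ord=1`, `ord=2`, `ord=jnp.inf`). Each norm, and each norm-induced distance $|\mathbf{u} - \mathbf{v}|$, is then a single call. This sketch reuses the values from the two tasks above, with the second task's `v` renamed `w` to avoid a clash. Comparing with `jnp.isclose` rather than `==` is a choice to tolerate floating-point rounding.

```python
import jax.numpy as jnp

v = jnp.array([3.0, -4.0, 1.0])
u, w = jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 0.0, 1.0])

# Norms of a single vector
l1 = jnp.linalg.norm(v, ord=1)
l2 = jnp.linalg.norm(v, ord=2)
linf = jnp.linalg.norm(v, ord=jnp.inf)
print(f"L1: {float(l1)}, L2: {float(l2):.2f}, Linf: {float(linf)}")

# Metrics induced by norms: d(u, w) = |u - w|
manhattan = jnp.linalg.norm(u - w, ord=1)
euclidean = jnp.linalg.norm(u - w, ord=2)
print(f"Manhattan: {float(manhattan)}, Euclidean: {float(euclidean):.2f}")

# Same answers as the manual formulas
print(bool(jnp.isclose(l1, jnp.sum(jnp.abs(v)))),
      bool(jnp.isclose(euclidean, jnp.sqrt(jnp.sum((u - w) ** 2)))))
```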
Vector Products
- We have seen how to add and scale vectors. But can we multiply two vectors together? It turns out there is more than one way to do it, and each answers a different question.
- An inner product is the general idea: a function that takes two vectors and produces a single number (a scalar). It is the abstract blueprint for "multiplying" vectors.
- Any inner product (on real vectors) must satisfy three rules:
  - Positive definiteness: $\langle \mathbf{v}, \mathbf{v} \rangle \geq 0$, and equals zero only for the zero vector. Multiplying a vector by itself always gives a non-negative result.
  - Symmetry: $\langle \mathbf{u}, \mathbf{v} \rangle = \langle \mathbf{v}, \mathbf{u} \rangle$. The order does not matter.
  - Linearity: $\langle a\mathbf{u} + b\mathbf{v}, \mathbf{w} \rangle = a\langle \mathbf{u}, \mathbf{w} \rangle + b\langle \mathbf{v}, \mathbf{w} \rangle$. It distributes over addition and scaling.
- The dot product is the most common inner product. It is the concrete version you will use almost everywhere. For two vectors $\mathbf{a} = (a_1, a_2, \ldots, a_n)$ and $\mathbf{b} = (b_1, b_2, \ldots, b_n)$:
$$\mathbf{a} \cdot \mathbf{b} = a_1 b_1 + a_2 b_2 + \cdots + a_n b_n$$
- Multiply matching components, then add everything up. That is all it is.
- But what does this number mean? The dot product has a beautiful geometric interpretation:
$$\mathbf{a} \cdot \mathbf{b} = |\mathbf{a}| \, |\mathbf{b}| \cos(\theta)$$
- This connects the dot product directly to the angle $\theta$ between the two vectors. The result tells you how much the two vectors "agree" in direction.
- If they point the same way ($\theta = 0°$), $\cos(\theta) = 1$ and the dot product is maximised (for vectors of fixed lengths).
- If they are orthogonal ($\theta = 90°$), $\cos(\theta) = 0$ and the dot product is exactly zero. This gives us a precise test for orthogonality.
- If they point in opposite directions ($\theta = 180°$), $\cos(\theta) = -1$ and the dot product is as negative as it can be. More generally, it is negative whenever $\theta > 90°$.
- A vector dotted with itself gives its magnitude squared: $\mathbf{a} \cdot \mathbf{a} = |\mathbf{a}|^2$.
- The dot product also gives us projection, the shadow one vector casts onto another. The projection of $\mathbf{a}$ onto $\mathbf{b}$ is:
$$\text{proj}_{\mathbf{b}}(\mathbf{a}) = \frac{\mathbf{a} \cdot \mathbf{b}}{|\mathbf{b}|^2} \, \mathbf{b}$$
- Think of shining a light perpendicular to the line through $\mathbf{b}$. The shadow of $\mathbf{a}$ on that line is the projection. It tells you how much of $\mathbf{a}$ lies in the direction of $\mathbf{b}$.
- Cosine similarity normalises the dot product by dividing out both magnitudes:
$$\cos(\theta) = \frac{\mathbf{a} \cdot \mathbf{b}}{|\mathbf{a}| \, |\mathbf{b}|}$$
- This gives a value between $-1$ and $1$ that measures direction alignment, ignoring how long the vectors are. It is widely used in ML to compare things like documents, embeddings, and user preferences.
- Now, the dot product takes two vectors and returns a scalar. The cross product instead takes two vectors and returns a new vector.
- The cross product $\mathbf{a} \times \mathbf{b}$ produces a vector that is perpendicular to both $\mathbf{a}$ and $\mathbf{b}$:
$$\mathbf{a} \times \mathbf{b} = (a_2 b_3 - a_3 b_2, \; a_3 b_1 - a_1 b_3, \; a_1 b_2 - a_2 b_1)$$
- The cross product only works in 3D. While the dot product works in any number of dimensions, the cross product is specific to three-dimensional space.
- Its magnitude equals the area of the parallelogram formed by the two vectors:
$$|\mathbf{a} \times \mathbf{b}| = |\mathbf{a}| \, |\mathbf{b}| \sin(\theta)$$
- Notice the pattern: the dot product uses $\cos(\theta)$ and the cross product uses $\sin(\theta)$. The dot product measures how much two vectors align, while the cross product measures how much they differ in direction.
- The direction of the result follows the right-hand rule: curl the fingers of your right hand from $\mathbf{a}$ towards $\mathbf{b}$, and your thumb points in the direction of $\mathbf{a} \times \mathbf{b}$.
- Unlike the dot product, the cross product is not commutative: $\mathbf{a} \times \mathbf{b} = -(\mathbf{b} \times \mathbf{a})$. Swapping the order flips the direction.
- If two vectors are parallel, their cross product is the zero vector (since $\sin(0°) = \sin(180°) = 0$). No area, no perpendicular direction.
- What happens when you combine three vectors using both products? This gives us triple products.
- The scalar triple product $\mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})$ first takes the cross product of two vectors, then dots the result with the third. The output is a single number whose absolute value equals the volume of the parallelepiped (a slanted 3D box) formed by the three vectors. Its sign records their orientation.
- If the scalar triple product is zero, the three vectors are coplanar: they all lie in the same flat plane and form no volume.
- The order can be cycled without changing the result: $\mathbf{a} \cdot (\mathbf{b} \times \mathbf{c}) = \mathbf{b} \cdot (\mathbf{c} \times \mathbf{a}) = \mathbf{c} \cdot (\mathbf{a} \times \mathbf{b})$.
- The vector triple product $\mathbf{a} \times (\mathbf{b} \times \mathbf{c})$ applies the cross product twice and returns a vector. It expands neatly using the identity:
$$\mathbf{a} \times (\mathbf{b} \times \mathbf{c}) = (\mathbf{a} \cdot \mathbf{c})\mathbf{b} - (\mathbf{a} \cdot \mathbf{b})\mathbf{c}$$
- The result always lies in the plane spanned by $\mathbf{b}$ and $\mathbf{c}$. Note that the cross product is not associative: $\mathbf{a} \times (\mathbf{b} \times \mathbf{c}) \neq (\mathbf{a} \times \mathbf{b}) \times \mathbf{c}$ in general.
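Here is a minimal sketch of projection and cosine similarity, coding the two formulas above directly. It also checks that the leftover part $\mathbf{a} - \text{proj}_{\mathbf{b}}(\mathbf{a})$ is orthogonal to $\mathbf{b}$, which is what makes the projection a "shadow". Finally, it shows that rescaling a vector leaves its cosine similarity unchanged. The helper names `project` and `cosine_similarity` and the sample vectors are illustrative.

```python
import jax.numpy as jnp

def project(a, b):
    # proj_b(a) = (a . b / |b|^2) b
    return (jnp.dot(a, b) / jnp.dot(b, b)) * b

def cosine_similarity(a, b):
    # a . b / (|a| |b|), a value in [-1, 1]
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

a = jnp.array([3.0, 4.0])
b = jnp.array([1.0, 0.0])

p = project(a, b)
residual = a - p
print(f"projection of a onto b: {p}")            # the x-component of a
print(f"residual . b = {jnp.dot(residual, b)}")  # 0: residual is orthogonal to b
print(f"cosine similarity: {float(cosine_similarity(a, b)):.2f}")  # 3 / 5
print(f"after scaling a by 10: {float(cosine_similarity(10 * a, b)):.2f}")
```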
Coding Tasks (use Colab or notebook)
- Compute the dot product of two vectors and use it to find the angle between them. Try making them orthogonal, parallel, or opposite and see how the angle changes.
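```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, -1.0, 2.0])
dot = jnp.dot(a, b)
# cos(theta) = a.b / (|a||b|); clip guards against rounding just outside [-1, 1]
cos_theta = jnp.clip(dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b)), -1.0, 1.0)
angle = jnp.arccos(cos_theta)
print(f"Dot product: {dot}")
print(f"Angle: {jnp.degrees(angle):.1f}°")
```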
- Compute the cross product of two 3D vectors and verify the result is perpendicular to both by checking that its dot product with each original vector is zero.
```python
import jax.numpy as jnp

a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([0.0, 1.0, 0.0])
cross = jnp.cross(a, b)
print(f"a x b = {cross}")
# isclose rather than == so floating-point rounding doesn't cause false negatives
print(f"Perpendicular to a: {jnp.isclose(jnp.dot(cross, a), 0.0)}")
print(f"Perpendicular to b: {jnp.isclose(jnp.dot(cross, b), 0.0)}")
```
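As a sketch of the triple-product facts above, the block below does four things:

- Computes the scalar triple product and checks that cycling the order leaves it unchanged.
- Confirms that a coplanar triple gives zero.
- Verifies the vector triple product identity.
- Shows that the cross product is not associative.

The vectors are arbitrary examples. As an independent check, it also uses `jnp.linalg.det` on the three vectors stacked as rows, relying on the standard fact that $\mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})$ equals that determinant.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 0.0])
b = jnp.array([0.0, 1.0, 3.0])
c = jnp.array([2.0, 0.0, 1.0])

# Scalar triple product: signed volume of the parallelepiped
vol = jnp.dot(a, jnp.cross(b, c))
print(f"a . (b x c) = {vol}")
print(f"b . (c x a) = {jnp.dot(b, jnp.cross(c, a))}")  # cyclic shift, same value
print(f"det check   = {float(jnp.linalg.det(jnp.stack([a, b, c]))):.1f}")

# Coplanar vectors give zero volume (d = a + b lies in the plane of a and b)
d = a + b
print(f"a . (b x d) = {jnp.dot(a, jnp.cross(b, d))}")

# Vector triple product identity: a x (b x c) = (a . c) b - (a . b) c
lhs = jnp.cross(a, jnp.cross(b, c))
rhs = jnp.dot(a, c) * b - jnp.dot(a, b) * c
print(f"identity holds: {bool(jnp.allclose(lhs, rhs))}")

# ...but the cross product is not associative
print(f"associative? {bool(jnp.allclose(lhs, jnp.cross(jnp.cross(a, b), c)))}")
```

For these vectors, the volume is 13 and the identity holds. The associativity check prints `False`.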
Basis and Duality
- We have seen that vectors live in spaces with a certain number of dimensions. But what defines those dimensions? This is where basis vectors come in.
- A basis is a set of vectors that can build every other vector in the space through scaling and adding (linear combination), with no redundancy. They are the building blocks of the space.
- A basis must satisfy two conditions:
  - Linearly independent: No basis vector can be built from the others. Each one contributes a genuinely new direction.
  - Spanning: Every vector in the space can be expressed as a combination of the basis vectors. Nothing is left out.
- The number of vectors in a basis equals the dimension of the space. In $\mathbb{R}^2$ you need 2, in $\mathbb{R}^3$ you need 3, and so on.
- The most natural basis is the standard basis, the unit vectors along each axis:
  - In $\mathbb{R}^2$: $\hat{\mathbf{i}} = (1, 0)$ and $\hat{\mathbf{j}} = (0, 1)$
  - In $\mathbb{R}^3$: $\hat{\mathbf{i}} = (1, 0, 0)$, $\hat{\mathbf{j}} = (0, 1, 0)$, $\hat{\mathbf{k}} = (0, 0, 1)$
- Any vector is just a weighted sum of these basis vectors. The vector $(3, 2)$ is really $3\hat{\mathbf{i}} + 2\hat{\mathbf{j}}$. The weights (3 and 2) are the coordinates of the vector in that basis.
- But the standard basis is not the only valid basis. In $\mathbb{R}^2$, the vectors $(1, 1)$ and $(-1, 1)$ also form a basis. They are linearly independent and can reach any point in the plane. The same vector will just have different coordinates in this new basis.
- A change of basis re-expresses the same vector using different building blocks. The vector has not moved; we are just describing it from a different perspective.
- Let $P$ be the change of basis matrix whose columns are the new basis vectors written in the old coordinates. Then $P$ turns new coordinates into old ones, $\mathbf{v}_{\text{old}} = P\,\mathbf{v}_{\text{new}}$, and $P^{-1}$ turns old coordinates into new ones, $\mathbf{v}_{\text{new}} = P^{-1}\mathbf{v}_{\text{old}}$.
- In ML, change of basis appears frequently. PCA, for example, finds a new basis (the principal components) where the data is easier to understand: the axes align with the directions of greatest variation.
- Now, there is a deeper idea hiding here. When we write $\mathbf{v} = (3, 2)$, the coordinates 3 and 2 are really the result of "measuring" $\mathbf{v}$ along each basis direction. The first coordinate asks "how much of $\hat{\mathbf{i}}$ is in $\mathbf{v}$?", and the second asks "how much of $\hat{\mathbf{j}}$?"
- Each of these measurements is a linear functional, a function that takes a vector and returns a single number. The collection of all such linear functionals forms the dual space $V^\ast$.
- Think of it this way: vectors are the objects, and linear functionals are the rulers that measure them. The dual space is the set of all possible rulers.
- For every basis $\{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_n\}$, there is a corresponding dual basis $\{\mathbf{e}_1^\ast, \mathbf{e}_2^\ast, \ldots, \mathbf{e}_n^\ast\}$. Each dual basis vector extracts exactly one coordinate:
  - $\mathbf{e}_1^\ast$ returns 1 when applied to $\mathbf{e}_1$ and 0 when applied to any other basis vector $\mathbf{e}_2, \ldots, \mathbf{e}_n$. Applied to any vector, it therefore returns exactly that vector's first coordinate. In general, $\mathbf{e}_i^\ast(\mathbf{e}_j) = 1$ if $i = j$ and $0$ otherwise.
- The dot product connects these two worlds. When you compute $\mathbf{u} \cdot \mathbf{v}$, you can think of one vector acting as a "ruler" measuring the other. The dot product $\mathbf{u} \cdot \mathbf{v}$ is the same as applying the linear functional defined by $\mathbf{u}$ to the vector $\mathbf{v}$.
- This means every vector secretly defines a linear functional, and every linear functional can be represented by a vector. In finite dimensions, the dual space has the same dimension as the original space and mirrors it exactly.
- Duality may seem abstract now, but it underlies many practical ideas: coordinates are dual basis evaluations, the dot product is a duality pairing, and transformations like attention in neural networks operate by having one set of vectors "query" another, which is duality in action.
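Viewed through the dot product, the dual basis of the standard basis is just the standard basis again. Here is a short sketch for the non-standard basis $(1, 1)$, $(-1, 1)$ from earlier. If $P$ has the basis vectors as its columns, the rows of $P^{-1}$ act as the dual basis, because $P^{-1}P = I$ is exactly the condition $\mathbf{e}_i^\ast(\mathbf{e}_j) = 1$ if $i = j$ and $0$ otherwise. Representing each functional as a row vector applied by a dot product is the finite-dimensional identification described above. The variable names are illustrative.

```python
import jax.numpy as jnp

# Basis vectors e1 = (1, 1) and e2 = (-1, 1) as the columns of P
P = jnp.array([[1.0, -1.0],
               [1.0,  1.0]])
dual = jnp.linalg.inv(P)  # row i acts as the dual functional e_i*

# e_i*(e_j) = 1 if i == j else 0, so dual @ P is the identity matrix
print(dual @ P)

# Each dual functional extracts one coordinate of v in the new basis
v = jnp.array([3.0, 2.0])
coords = dual @ v
print(f"coordinates of v: {coords}")
print(f"rebuilt v: {coords[0] * P[:, 0] + coords[1] * P[:, 1]}")
```

The coordinates come out as $(2.5, -0.5)$. These match what `jnp.linalg.solve(P, v)` gives in the coding task below, since both compute $P^{-1}\mathbf{v}$.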
Coding Tasks (use Colab or notebook)
- Express a vector in two different bases and verify they represent the same point. Try creating your own basis and see what coordinates the vector gets.
```python
import jax.numpy as jnp

v = jnp.array([3.0, 2.0])
# Standard basis: coordinates are just the components
print(f"Standard basis coords: {v}")

# New basis (1, 1) and (-1, 1), stored as the columns of P
P = jnp.array([[1.0, -1.0],
               [1.0,  1.0]])
# Solve P @ new_coords = v, i.e. new_coords = P^{-1} v
new_coords = jnp.linalg.solve(P, v)
print(f"New basis coords: {new_coords}")

# Verify: reconstruct from new coords
reconstructed = new_coords[0] * P[:, 0] + new_coords[1] * P[:, 1]
print(f"Reconstructed: {reconstructed}")
```
- Verify the dual basis property: each dual basis vector extracts exactly one coordinate and returns zero for the others.
```python
import jax.numpy as jnp

# Standard basis in R3
e1 = jnp.array([1.0, 0.0, 0.0])
e2 = jnp.array([0.0, 1.0, 0.0])
e3 = jnp.array([0.0, 0.0, 1.0])
v = jnp.array([5.0, 3.0, 7.0])

# Each dot product extracts one coordinate
print(f"e1 · v = {jnp.dot(e1, v)}")
print(f"e2 · v = {jnp.dot(e2, v)}")
print(f"e3 · v = {jnp.dot(e3, v)}")
```
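Here is another sketch of the idea that every vector defines a linear functional. Fix $\mathbf{u}$, define $f_{\mathbf{u}}(\mathbf{v}) = \mathbf{u} \cdot \mathbf{v}$, and check linearity numerically. Stacking several such "rulers" as the rows of a matrix applies all of them at once, which is just a matrix-vector product. That pattern of scoring one set of vectors against another with batched dot products also appears in attention. The function name `make_functional` and all sample values are hypothetical.

```python
import jax.numpy as jnp

def make_functional(u):
    # The vector u becomes a "ruler": v -> u . v
    return lambda v: jnp.dot(u, v)

u = jnp.array([2.0, -1.0, 0.5])
f = make_functional(u)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([-4.0, 0.0, 1.0])
a, b = 3.0, -2.0

# Linearity: f(a x + b y) = a f(x) + b f(y)
print(f"f(ax + by)      = {f(a * x + b * y)}")
print(f"a f(x) + b f(y) = {a * f(x) + b * f(y)}")

# Several rulers at once: each row of R measures x, giving one number per row
R = jnp.stack([u, jnp.array([0.0, 1.0, 0.0]), jnp.array([1.0, 1.0, 1.0])])
print(f"measurements: {R @ x}")
```

Both sides of the linearity check equal 19.5. The three measurements are $1.5$, $2$ (the second coordinate of $\mathbf{x}$) and $6$ (the sum of its coordinates).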
Matrix Properties
- At its core, a matrix is a rectangular grid of numbers arranged in rows and columns. If a vector is a single list of numbers, a matrix is a table of them.
- You can also think of a matrix as a stack of vectors.
- If a single person is described by the vector $[\text{age}, \text{height}, \text{weight}]$, then three people form a matrix where each row is one person:
- This matrix has 3 rows and 3 columns, so we call it a $3 \times 3$ matrix.
- Each number in the grid is called an element or entry, identified by its row and column: $A_{ij}$ is the element in row $i$, column $j$.
-
The transpose of a matrix flips it along its diagonal, turning rows into columns and columns into rows. If $A$ is $m \times n$, then $A^T$ is $n \times m$.
-
Multiplying a matrix by its transpose always gives a square matrix: $AA^T$ is $m \times m$ and $A^TA$ is $n \times n$.
-
The trace of a square matrix is the sum of its diagonal elements: $\text{tr}(A) = A_{11} + A_{22} + \cdots + A_{nn}$. The trace equals the sum of the eigenvalues (which we will see later).
-
For the matrix above, $\text{tr}(A) = 1 + 4 + 9 = 14$. Only the diagonal entries matter.
-
If two matrices represent the same linear transformation under different bases, their traces will be the same. The trace is "basis-independent."
-
The rank of a matrix is the number of linearly independent rows (or equivalently, columns). It tells you how much "useful information" the matrix carries.
-
For example, the following matrix has rank 2 because neither row is a multiple of the other:
But this matrix has rank 1 because the second row is just twice the first, so it adds no new information:
- A $5 \times 3$ matrix can have rank at most 3. If some rows are just scaled or combined versions of others, the rank drops. A matrix with maximum possible rank is called full rank.
-
A square matrix is invertible (has an inverse) if and only if it is full rank.
-
The rank is connected to the null space (the set of vectors that the matrix maps to zero) through the rank-nullity theorem: $\text{rank}(A) + \text{nullity}(A) = \text{number of columns of } A$. What the matrix keeps (rank) plus what it destroys (nullity) equals the total dimension.
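Here is a minimal JAX check of rank-nullity. It uses a $3 \times 3$ matrix of our own choosing whose third column is the sum of the first two. The sketch counts singular values above a small relative tolerance (the `1e-5` threshold is our assumption, because floating point never produces exact zeros). It then takes the matching right singular vectors as a null-space basis. Singular values are covered in the decompositions section.

```python
import jax.numpy as jnp

# Example matrix (our choice): column 3 = column 1 + column 2, so the rank should be 2
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 9.0],
               [7.0, 8.0, 15.0]])

# Singular values of a rank-deficient matrix are only ~0 in floating point,
# so we count them against a small relative tolerance
U, s, Vt = jnp.linalg.svd(A)
tol = 1e-5 * s[0]
rank = int(jnp.sum(s > tol))

# Rows of Vt paired with (near-)zero singular values span the null space
null_basis = Vt[rank:]
nullity = null_basis.shape[0]

print(f"rank = {rank}, nullity = {nullity}, columns = {A.shape[1]}")
print(f"A @ null vector = {A @ null_basis[0]}")  # close to [0, 0, 0]
```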
-
The column space of a matrix is the set of all possible outputs when you multiply the matrix by any vector. It is spanned by the columns of the matrix. If a $3 \times 3$ matrix has only 2 independent columns, its column space is a 2D plane, not all of 3D space.
-
The row space is the same idea but from the perspective of rows. The rank equals the dimension of both the column space and the row space, so they always agree.
-
Together, the column space tells you "what outputs can this matrix produce?" and the null space tells you "what inputs get mapped to zero?" These two spaces completely describe what the matrix does.
-
The determinant of a square matrix is a single number that captures how the matrix scales space. Think of a $2 \times 2$ matrix as transforming a unit square into a parallelogram. The determinant is the area of that parallelogram (with a sign).
- For example:
The transformation stretches the unit square into a parallelogram with area 6.
-
If the determinant is positive, the transformation preserves orientation (things don't get "flipped"). If negative, it flips orientation (like a mirror reflection). If zero, the matrix squashes space into a lower dimension, collapsing the parallelogram to a line or point.
-
A matrix with determinant zero is called singular. It has no inverse and has lost information permanently.
-
For matrices larger than $2 \times 2$, the determinant is computed using minors and cofactors. The minor $M_{ij}$ is the determinant of the smaller matrix you get by deleting row $i$ and column $j$.
-
The cofactor $C_{ij} = (-1)^{i+j} M_{ij}$ attaches a sign to each minor (alternating like a checkerboard: $+, -, +, \ldots$). The determinant of the full matrix is then the sum of each entry times its cofactor along any row or column. Along the first row, this gives $\det(A) = \sum_j A_{1j} \cdot C_{1j}$. This is called cofactor expansion.
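Below is a direct, recursive sketch of cofactor expansion along the first row, checked against `jnp.linalg.det`. The helper name `cofactor_det` and the example matrix are ours. The code is for understanding only: the recursion costs $O(n!)$, which is why libraries compute determinants through factorisations instead.

```python
import jax.numpy as jnp

def cofactor_det(A):
    """Determinant by cofactor expansion along the first row (recursive, O(n!))."""
    n = A.shape[0]
    if n == 1:
        return A[0, 0]
    total = 0.0
    for j in range(n):
        # Minor M_1j: delete the first row and column j
        minor = jnp.delete(jnp.delete(A, 0, axis=0), j, axis=1)
        # Cofactor sign (-1)^(i+j) with i = 0 (0-indexed), giving +, -, +, ...
        total += (-1) ** j * A[0, j] * cofactor_det(minor)
    return total

A = jnp.array([[2.0, -1.0, 0.0],
               [1.0,  3.0, 2.0],
               [0.0,  1.0, 4.0]])
print(cofactor_det(A))     # 24.0
print(jnp.linalg.det(A))   # same value, computed by a different route
```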
-
The inverse of a square matrix $A$, written $A^{-1}$, is the matrix that undoes what $A$ does: $AA^{-1} = A^{-1}A = I$ (the identity matrix). Only non-singular matrices have inverses.
-
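For a $2 \times 2$ matrix, the inverse has a direct formula:
$$\begin{bmatrix} a & b \\ c & d \end{bmatrix}^{-1} = \frac{1}{ad - bc}\begin{bmatrix} d & -b \\ -c & a \end{bmatrix}$$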
Notice the determinant in the denominator, which is why singular matrices (determinant zero) have no inverse.
-
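The condition number measures how sensitive the solution of $A\mathbf{x} = \mathbf{b}$ is to small changes in $\mathbf{b}$ (or in $A$). It is defined as $\kappa(A) = \|A\| \cdot \|A^{-1}\|$ for a chosen matrix norm. With the spectral norm, it is the ratio of the largest to the smallest singular value.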
-
A condition number close to 1 means the matrix is well-conditioned: small input changes produce small output changes. A large condition number means it is ill-conditioned: tiny errors get amplified enormously. Orthogonal and identity matrices have condition number 1, while singular matrices have infinite condition number.
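The sketch below shows what the condition number means in practice. It uses two diagonal matrices of our own choosing, milder than the $10^8$ example that follows so that float32 copes comfortably. It solves $A\mathbf{x} = \mathbf{b}$, nudges $\mathbf{b}$ along the squashed direction, and compares the relative change in $\mathbf{x}$ with the relative change in $\mathbf{b}$. With this worst-case choice of $\mathbf{b}$ and nudge, the amplification equals $\kappa(A)$.

```python
import jax.numpy as jnp

well = jnp.array([[2.0, 0.0],
                  [0.0, 1.0]])
ill = jnp.array([[1.0, 0.0],
                 [0.0, 1e-4]])   # the second direction is nearly squashed

print(f"cond(well) = {jnp.linalg.cond(well):.1f}")
print(f"cond(ill)  = {jnp.linalg.cond(ill):.1e}")

# Solve Mx = b, then nudge b slightly along the squashed direction and compare relative changes
b = jnp.array([1.0, 0.0])
db = jnp.array([0.0, 1e-3])
amplification = {}
for name, M in [("well", well), ("ill", ill)]:
    x = jnp.linalg.solve(M, b)
    x_pert = jnp.linalg.solve(M, b + db)
    rel_in = jnp.linalg.norm(db) / jnp.linalg.norm(b)
    rel_out = jnp.linalg.norm(x_pert - x) / jnp.linalg.norm(x)
    amplification[name] = float(rel_out / rel_in)
    print(f"{name}: relative change in b {rel_in:.0e} -> in x {rel_out:.0e}")
```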
-
For example, the following matrix has condition number $10^8$. One direction is scaled normally while the other is nearly squashed to zero, so small perturbations along that direction get wildly distorted:
- Just as vectors have norms (length), matrices have norms that measure their "size." The most common is the Frobenius norm, which treats the matrix as a long vector and computes its length:
$$\|A\|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} A_{ij}^2}$$
- For example:
-
The spectral norm $\|A\|_2$ is the largest singular value of $A$. It measures the maximum amount the matrix can stretch any unit vector. In ML, matrix norms are used for weight regularisation (penalising large weights) and monitoring training stability.
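Both norms take only a few lines of JAX. The example matrix is our own. Here $\|A\|_F = \sqrt{9 + 16 + 25} = \sqrt{50} \approx 7.07$. Since $A^TA$ has eigenvalues 45 and 5, $\|A\|_2 = \sqrt{45} \approx 6.71$. The spectral norm is never larger than the Frobenius norm.

```python
import jax.numpy as jnp

A = jnp.array([[3.0, 0.0],
               [4.0, 5.0]])

# Frobenius norm: treat the matrix as one long vector
fro = jnp.linalg.norm(A, 'fro')
fro_manual = jnp.sqrt(jnp.sum(A ** 2))

# Spectral norm: the largest singular value
spec = jnp.linalg.norm(A, 2)
sigma_max = jnp.linalg.svd(A, compute_uv=False)[0]

print(f"Frobenius: {fro:.4f} (manual {fro_manual:.4f})")
print(f"Spectral:  {spec:.4f} (largest singular value {sigma_max:.4f})")
```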
-
A symmetric matrix $A$ is positive definite if for every non-zero vector $\mathbf{x}$: $\mathbf{x}^T A \mathbf{x} > 0$. This quadratic form always produces a positive number.
-
For example, the following matrix is positive definite:
$$A = \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}$$
Pick any vector, say $\mathbf{x} = [1, -1]^T$: $\mathbf{x}^T A \mathbf{x} = 2 - 1 - 1 + 3 = 3 > 0$. No matter which non-zero $\mathbf{x}$ you try, you always get a positive result.
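The example can be checked numerically, assuming $A = \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}$ (the matrix implied by the hand calculation). Sampling random vectors only gives evidence, not proof. The definitive test for a symmetric matrix is that all its eigenvalues are positive (eigenvalues are covered in the decompositions section).

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])

def quadratic_form(A, x):
    return x @ A @ x

# The hand-worked example: x = [1, -1] gives 3
print(quadratic_form(A, jnp.array([1.0, -1.0])))   # 3.0

# Try many random non-zero vectors: every value should be positive
xs = jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
values = jax.vmap(lambda x: quadratic_form(A, x))(xs)
print(f"min over 1000 random x: {values.min():.4f}")

# Definitive check: a symmetric matrix is positive definite iff all eigenvalues are > 0
print(f"eigenvalues: {jnp.linalg.eigvalsh(A)}")
```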
-
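Positive definite matrices are important in optimisation. If the Hessian (the matrix of second derivatives) of a function is positive definite everywhere, the function is strictly convex and has at most one minimum. For a quadratic $f(\mathbf{x}) = \tfrac{1}{2}\mathbf{x}^T A \mathbf{x} - \mathbf{b}^T\mathbf{x}$ with positive definite $A$, that unique minimum exists and sits at the solution of $A\mathbf{x} = \mathbf{b}$.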
-
If the condition is relaxed to $\mathbf{x}^T A \mathbf{x} \geq 0$ (allowing zero), the matrix is positive semi-definite (PSD). PSD matrices come up constantly: covariance matrices, kernel matrices in SVMs, and Hessians at local minima are all PSD. The difference is that PSD allows some directions to be "flat" (zero curvature) rather than strictly curving upward.
Coding Tasks (use Colab or notebook)
- Compute the trace, rank, and determinant of a matrix. Try making one row a multiple of another and see how rank and determinant change.
import jax.numpy as jnp
A = jnp.array([[1.0, 2.0],
[3.0, 4.0]])
print(f"Trace: {jnp.trace(A)}")
print(f"Rank: {jnp.linalg.matrix_rank(A)}")
print(f"Determinant: {jnp.linalg.det(A):.2f}")
- Compute the inverse of a matrix, multiply it by the original, and verify you get the identity. Then try a singular matrix and observe what happens.
import jax.numpy as jnp
A = jnp.array([[1.0, 2.0],
[3.0, 4.0]])
A_inv = jnp.linalg.inv(A)
print(f"A * A_inv:\n{A @ A_inv}")
Matrix Types
-
Not all matrices are the same. Different structures give matrices special properties that make them faster to compute with, easier to reason about, or both. Here are the types you will encounter most.
-
A square matrix has the same number of rows and columns ($n \times n$). Most of the interesting properties (determinant, eigenvalues, inverse) only apply to square matrices.
-
The identity matrix $I$ is a square matrix with 1s on the diagonal and 0s everywhere else. It is the "do nothing" transformation: $AI = IA = A$ for any compatible matrix $A$.
-
The zero matrix $O$ has all elements equal to zero. It maps every vector to the zero vector, destroying all information.
-
A diagonal matrix is all zeros except on the main diagonal. Multiplying a vector by a diagonal matrix simply scales each component independently, making it very efficient.
- A symmetric matrix equals its own transpose: $A = A^T$, meaning $A_{ij} = A_{ji}$. Symmetric matrices have a special property: their eigenvalues are real, and their eigenvectors can always be chosen to be perpendicular to each other. Covariance matrices are always symmetric.
- A triangular matrix has all zeros on one side of the diagonal. Lower triangular has zeros above, upper triangular has zeros below. They are essential for solving systems of equations efficiently through forward or back substitution.
-
The determinant of a triangular matrix is simply the product of its diagonal elements.
-
An orthogonal matrix has the property that its transpose equals its inverse: $Q^TQ = QQ^T = I$.
-
This means you can "undo" the transformation just by transposing, which is computationally cheap. Its columns are orthonormal (unit length and mutually perpendicular).
-
A sparse matrix has most of its elements equal to zero, while a dense matrix has most elements nonzero.
-
In practice, many real-world matrices are extremely sparse.
-
A social network with a million users could be represented as a $10^6 \times 10^6$ matrix, but each person only connects to a handful of others, so nearly all entries are zero.
-
A permutation matrix is obtained by rearranging the rows of an identity matrix. Multiplying by it shuffles the elements of a vector. Every row and every column has exactly one 1 and the rest are 0s.
-
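For example, the matrix below moves element 3 to position 1, element 1 to position 2, and element 2 to position 3:
$$P = \begin{bmatrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}, \qquad P\begin{bmatrix} x_1 \\ x_2 \\ x_3 \end{bmatrix} = \begin{bmatrix} x_3 \\ x_1 \\ x_2 \end{bmatrix}$$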
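The same permutation can be built in JAX by reordering the rows of `jnp.eye(3)`. The test vector is our own. Because a permutation matrix is orthogonal, its transpose undoes the shuffle.

```python
import jax.numpy as jnp

# Rows of the identity reordered so that row 1 picks x3, row 2 picks x1, row 3 picks x2
P = jnp.eye(3)[jnp.array([2, 0, 1])]
x = jnp.array([10.0, 20.0, 30.0])

print(P)
print(P @ x)          # [30. 10. 20.]
print(P.T @ (P @ x))  # P is orthogonal, so P^T undoes the shuffle: [10. 20. 30.]
```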
- A Toeplitz matrix has the same value along every diagonal (upper-left to lower-right). Notice how each diagonal is constant:
-
This structure appears in signal processing and convolution, because sliding a fixed filter across a signal is equivalent to multiplying by a Toeplitz matrix.
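The sketch below demonstrates that equivalence. `convolution_matrix` (our own helper name) builds a Toeplitz matrix from a filter `h`. Multiplying that matrix by a signal reproduces `jnp.convolve` in its default full mode. The filter and signal values are arbitrary examples. This is convolution in the signal-processing sense. The "convolution" in deep learning libraries is usually cross-correlation, which also corresponds to a Toeplitz matrix.

```python
import jax.numpy as jnp

def convolution_matrix(h, n):
    """Toeplitz matrix T with T[i, j] = h[i - j] (zero outside the filter), so T @ x is full convolution."""
    m = len(h) + n - 1
    i = jnp.arange(m)[:, None]
    j = jnp.arange(n)[None, :]
    k = i - j                                   # constant along each diagonal
    valid = (k >= 0) & (k < len(h))
    return jnp.where(valid, h[jnp.clip(k, 0, len(h) - 1)], 0.0)

h = jnp.array([1.0, 2.0, 1.0])        # example filter (our choice)
x = jnp.array([3.0, 0.0, -1.0, 2.0])  # example signal

T = convolution_matrix(h, len(x))
print(T)
print(T @ x)
print(jnp.convolve(x, h))   # same result
```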
-
A circulant matrix is a special Toeplitz matrix where each row is a cyclic shift of the one above. When a row reaches the end, it wraps around:
-
Circulant matrices are closely connected to the discrete Fourier transform (DFT) and are central to how circular convolution works.
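The next sketch illustrates the DFT connection, using a helper `circulant` and values of our own. Multiplying by a circulant matrix is a circular convolution with its first column. It can therefore be computed by multiplying Fourier transforms element-wise (the convolution theorem). Equivalently, the DFT diagonalises every circulant matrix, which is why FFT-based methods can apply one in $O(n \log n)$ instead of $O(n^2)$.

```python
import jax.numpy as jnp

def circulant(c):
    """Circulant matrix whose first column is c; each row is the row above shifted right by one."""
    n = len(c)
    idx = (jnp.arange(n)[:, None] - jnp.arange(n)[None, :]) % n
    return c[idx]

c = jnp.array([1.0, 2.0, 0.0, -1.0])
x = jnp.array([4.0, 1.0, 3.0, 2.0])
C = circulant(c)

direct = C @ x
# Convolution theorem: multiply in the Fourier domain, then transform back
via_fft = jnp.fft.ifft(jnp.fft.fft(c) * jnp.fft.fft(x)).real

print(C)
print(direct)
print(via_fft)   # matches direct, up to rounding
```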
-
A Hermitian matrix is the complex equivalent of a symmetric matrix: $A = A^\ast$ (where $A^\ast$ is the conjugate transpose).
-
For real-valued matrices, Hermitian and symmetric are the same thing. You will encounter these in quantum computing and signal processing.
-
A unitary matrix is the complex equivalent of an orthogonal matrix: $U^\ast U = UU^\ast = I$. Just as orthogonal matrices preserve lengths in real spaces, unitary matrices preserve lengths in complex spaces.
-
An idempotent matrix satisfies $A^2 = A$. Applying the transformation twice is the same as applying it once, which makes it a projection. Once you have projected, projecting again changes nothing.
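Here is a projection example in JAX, with a matrix and vector of our own choosing. The matrix $P = A(A^TA)^{-1}A^T$ projects onto the column space of $A$, which is a plane in 3D. Projecting a second time leaves the point where it is. The leftover $\mathbf{v} - P\mathbf{v}$ is perpendicular to that plane.

```python
import jax.numpy as jnp

# Project onto the column space of A (a plane in 3D); A has independent columns
A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
P = A @ jnp.linalg.inv(A.T @ A) @ A.T   # orthogonal projection matrix

v = jnp.array([1.0, 2.0, 4.0])
once = P @ v
twice = P @ once

print(f"P @ v:     {once}")
print(f"P @ P @ v: {twice}")                       # unchanged
print(f"P @ P == P: {jnp.allclose(P @ P, P, atol=1e-5)}")
```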
-
A nilpotent matrix satisfies $A^k = O$ (the zero matrix) for some power $k$. Apply the transformation enough times and everything collapses to zero. For example:
- A Boolean matrix (or binary matrix) contains only 0s and 1s. It represents yes/no relationships. For example, in a graph with 3 nodes, the adjacency matrix records which nodes are connected:
$$\begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & 0 & 0 \end{bmatrix}$$
-
Here, node 1 connects to nodes 2 and 3, but nodes 2 and 3 are not connected to each other.
-
A Vandermonde matrix is built from consecutive powers of a set of values. Given values $x_1, x_2, x_3$:
$$V = \begin{bmatrix} 1 & x_1 & x_1^2 \\ 1 & x_2 & x_2^2 \\ 1 & x_3 & x_3^2 \end{bmatrix}$$
-
This structure appears in polynomial interpolation: finding the unique polynomial that passes through a given set of points.
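Here is a small interpolation sketch. It uses three points of our own choosing and the Vandermonde matrix from `jnp.vander` with `increasing=True`, so the columns are $1, x, x^2$. A linear solve then gives the polynomial coefficients. These points lie on $y = 1 + x + x^2$, so the coefficients should come out as roughly $[1, 1, 1]$.

```python
import jax.numpy as jnp

# Three points (our choice); find the quadratic c0 + c1*x + c2*x^2 through them
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([1.0, 3.0, 7.0])

V = jnp.vander(xs, increasing=True)   # columns: 1, x, x^2
coeffs = jnp.linalg.solve(V, ys)

print(V)
print(f"coefficients c0, c1, c2: {coeffs}")   # approximately [1, 1, 1], i.e. y = 1 + x + x^2
print(f"fit at the points: {V @ coeffs}")
```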
-
A Hessenberg matrix is "almost" triangular, with zeros below the first subdiagonal:
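- It is a useful intermediate form for computing eigenvalues efficiently. Reducing a matrix to Hessenberg form first (which does not change its eigenvalues) makes each step of iterative algorithms like the QR algorithm much cheaper: $O(n^2)$ per iteration instead of $O(n^3)$.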
Coding Tasks (use Colab or notebook)
- Create an orthogonal matrix (rotation matrix), multiply it by its transpose, and verify you get the identity. Try different angles.
import jax.numpy as jnp
theta = jnp.pi / 4
Q = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
[jnp.sin(theta), jnp.cos(theta)]])
print(f"Q @ Q.T:\n{Q @ Q.T}")
print(f"Determinant: {jnp.linalg.det(Q):.2f}")
- Create a symmetric matrix and verify that it equals its transpose. Then compute its eigenvalues and check that the eigenvectors are perpendicular.
import jax.numpy as jnp
S = jnp.array([[4.0, 2.0],
[2.0, 3.0]])
print(f"Symmetric: {jnp.allclose(S, S.T)}")
eigenvalues, eigenvectors = jnp.linalg.eigh(S)
print(f"Eigenvalues: {eigenvalues}")
print(f"Dot product of eigenvectors: {jnp.dot(eigenvectors[:, 0], eigenvectors[:, 1]):.6f}")
Matrix Operations
-
Matrices can be added and scaled just like vectors.
-
For addition, both matrices must have the same dimensions, and you add element by element:
- For scalar multiplication, you multiply every element by the scalar:
- The simplest thing you can do with a matrix is multiply it by a vector. Matrix-vector multiplication $A\mathbf{x}$ combines the columns of $A$ using the entries of $\mathbf{x}$ as weights:
-
This is the core operation in ML. A fully connected (dense) neural network layer computes $A\mathbf{x} + \mathbf{b}$: a matrix times an input vector, plus a bias. The result is usually passed through a nonlinear activation.
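The sketch below shows both ideas with made-up numbers. `A @ x` equals the weighted sum of the columns of `A`, and a dense layer adds a bias `b` to that product. The nonlinearity a real layer would apply afterwards is left out.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
x = jnp.array([10.0, -1.0])
b = jnp.array([0.5, 0.5, 0.5])

# Column view: Ax is x[0] * (column 0) + x[1] * (column 1)
column_combo = x[0] * A[:, 0] + x[1] * A[:, 1]
print(A @ x)          # values 8, 26, 44
print(column_combo)   # same values

# A dense layer (sketch): weights A, bias b; real layers usually apply a nonlinearity afterwards
layer_out = A @ x + b
print(layer_out)
```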
-
The general case is matrix multiplication. Given $A$ ($m \times n$) and $B$ ($n \times p$), the product $C = AB$ is an $m \times p$ matrix where each element is a dot product:
$$C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$
-
Each entry in the result is the dot product of a row from $A$ with a column from $B$. The inner dimensions must match ($n$), and the result takes the outer dimensions ($m \times p$).
-
Another way to see it: each column of the result is a weighted sum of the columns of $A$, where the weights come from the corresponding column of $B$.
-
If $B$ has column $[2, 3]^T$, the result column is $2 \times (\text{column 1 of } A) + 3 \times (\text{column 2 of } A)$.
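Here are the two views side by side, on small matrices of our own choosing. The first computes an explicit sum for each entry $C_{ij}$. The second builds each column of $C$ as a weighted sum of the columns of $A$. Both use Python loops for clarity only; in practice you would just write `A @ B`.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])        # 3 x 2
B = jnp.array([[2.0, 0.0, 1.0],
               [3.0, 1.0, -1.0]])  # 2 x 3

m, n = A.shape
p = B.shape[1]

# View 1: C[i, j] = sum_k A[i, k] * B[k, j]  (row of A dotted with column of B)
C_dots = jnp.array([[sum(A[i, k] * B[k, j] for k in range(n)) for j in range(p)]
                    for i in range(m)])

# View 2: column j of C is a weighted sum of A's columns, weights from column j of B
C_cols = jnp.stack([sum(B[k, j] * A[:, k] for k in range(n)) for j in range(p)], axis=1)

print(C_dots.shape)   # (3, 3): the outer dimensions
print(jnp.allclose(C_dots, A @ B), jnp.allclose(C_cols, A @ B))
# Column 0 of B is [2, 3], so column 0 of C is 2*(column 0 of A) + 3*(column 1 of A)
print(C_cols[:, 0], 2 * A[:, 0] + 3 * A[:, 1])
```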
-
A useful special case: multiplying a matrix by its transpose always gives a square matrix. $AA^T$ is $m \times m$ and $A^TA$ is $n \times n$:
-
Matrix multiplication has important rules:
-
Not commutative: $AB \neq BA$ in general. The order matters.
-
Associative: $(AB)C = A(BC)$. You can group multiplications however you like.
-
Distributive: $A(B + C) = AB + AC$.
-
Identity: $AI = IA = A$.
-
-
The Hadamard product (element-wise product) multiplies two matrices of the same size entry by entry, written $A \odot B$:
-
Unlike standard matrix multiplication, the Hadamard product is commutative ($A \odot B = B \odot A$) and requires both matrices to have the same dimensions. It is used heavily in ML for gating: multiplying element-wise by a mask of values between 0 and 1 controls how much of each entry "passes through."
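Here is a quick demonstration in JAX, where `*` is the element-wise (Hadamard) product and `@` is matrix multiplication. The gating part only sketches the idea. `jax.nn.sigmoid` squashes arbitrary scores (values ours) into a mask between 0 and 1 before the element-wise multiply.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
B = jnp.array([[10.0, 20.0],
               [30.0, 40.0]])

print(A * B)                              # element-wise: values 10, 40, 90, 160
print(jnp.array_equal(A * B, B * A))      # commutative
print(jnp.allclose(A * B, A @ B))         # not the matrix product

# Gating sketch: a sigmoid turns arbitrary scores into a mask in (0, 1)
scores = jnp.array([[-4.0, 0.0],
                    [ 4.0, 10.0]])
gate = jax.nn.sigmoid(scores)
print(gate * A)   # high-score entries pass through almost unchanged, low-score entries are suppressed
```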
-
The outer product of two vectors $\mathbf{u}$ and $\mathbf{v}$ produces a matrix: $\mathbf{u}\mathbf{v}^T$. Each entry is the product of one element from $\mathbf{u}$ and one from $\mathbf{v}$:
-
As long as neither vector is zero, the result has rank 1, because every row is a scaled version of $\mathbf{v}^T$. Any matrix can be written as a sum of rank-1 outer products, which is exactly what SVD does (covered in decompositions).
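Here is a sketch with vectors and a matrix of our own. The outer product has a single non-negligible singular value. We count singular values against a small relative tolerance, since floating point rarely gives exact zeros. Summing $\sigma_i \mathbf{u}_i \mathbf{v}_i^T$ over the SVD then rebuilds the original matrix.

```python
import jax.numpy as jnp

u = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([4.0, 5.0])

outer = jnp.outer(u, v)        # 3 x 2, entry (i, j) = u[i] * v[j]
print(outer)

# Only one singular value is (numerically) non-zero, so the rank is 1
s_o = jnp.linalg.svd(outer, compute_uv=False)
print(f"singular values: {s_o}, rank: {int(jnp.sum(s_o > 1e-5 * s_o[0]))}")

# Rebuild a matrix as a sum of rank-1 outer products taken from its SVD
M = jnp.array([[3.0, 1.0],
               [1.0, 3.0],
               [2.0, -2.0]])
U, s, Vt = jnp.linalg.svd(M, full_matrices=False)
rebuilt = sum(s[i] * jnp.outer(U[:, i], Vt[i]) for i in range(len(s)))
print(jnp.allclose(rebuilt, M, atol=1e-5))
```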
-
Matrix multiplication is computationally expensive. Multiplying two $n \times n$ matrices with the standard algorithm takes $O(n^3)$ operations. For a $1000 \times 1000$ matrix, that is a billion multiplications.
-
When matrices are sparse (mostly zeros), naive multiplication wastes time multiplying by zero. The Compressed Sparse Row (CSR) format stores only the nonzero elements along with their positions:
- Values: the nonzero entries in row order
- Column indices: which column each value belongs to
- Row offsets: where each row starts in the values list, plus a final entry marking the end of the last row (the total number of nonzeros)
-
For example, the matrix:
-
Is stored as: values = [5, 2, 3, -1], columns = [0, 3, 2, 3], row offsets = [0, 2, 3, 4]. This skips all the zeros and makes sparse operations much faster.
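Below is a minimal pure-Python sketch of CSR, with helper names (`to_csr`, `csr_matvec`) of our own. The dense matrix is our reconstruction. The three rows follow from the row offsets, while four columns is an assumption (the smallest width that fits column index 3). For real workloads, use a sparse library such as `scipy.sparse` rather than Python loops.

```python
import jax.numpy as jnp

def to_csr(dense):
    """Convert a dense matrix to CSR lists (values, column indices, row offsets)."""
    values, columns, row_offsets = [], [], [0]
    n_rows, n_cols = dense.shape
    for r in range(n_rows):
        for c in range(n_cols):
            entry = float(dense[r, c])
            if entry != 0.0:
                values.append(entry)
                columns.append(c)
        row_offsets.append(len(values))   # where the next row starts
    return values, columns, row_offsets

def csr_matvec(values, columns, row_offsets, x):
    """y = A @ x, touching only the stored nonzeros."""
    y = []
    for r in range(len(row_offsets) - 1):
        start, end = row_offsets[r], row_offsets[r + 1]
        y.append(sum(values[k] * x[columns[k]] for k in range(start, end)))
    return jnp.array(y)

# A 3 x 4 matrix consistent with the arrays in the text (4 columns is our assumption)
A = jnp.array([[5.0, 0.0, 0.0,  2.0],
               [0.0, 0.0, 3.0,  0.0],
               [0.0, 0.0, 0.0, -1.0]])
values, columns, row_offsets = to_csr(A)
print(values, columns, row_offsets)   # [5.0, 2.0, 3.0, -1.0] [0, 3, 2, 3] [0, 2, 3, 4]

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(csr_matvec(values, columns, row_offsets, x), A @ x)
```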
-
A core use of matrices is solving systems of linear equations. The system $A\mathbf{x} = \mathbf{b}$ asks: "what vector $\mathbf{x}$, when transformed by $A$, produces $\mathbf{b}$?"
-
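For example, say you are buying fruit. Apples cost $x_1$ dollars each and bananas cost $x_2$ dollars each. You know that 2 apples and 1 banana cost 5 dollars, and 1 apple and 3 bananas cost 10 dollars. In matrix form:
$$\begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix}$$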
- Multiplying the matrix by the vector row by row (each row dotted with $[x_1, x_2]^T$) gives two equations:
$$2x_1 + 1x_2 = 5 \qquad \text{(row 1)} \qquad \qquad x_1 + 3x_2 = 10 \qquad \text{(row 2)}$$
-
From row 1, $x_2 = 5 - 2x_1$. Substituting into row 2: $x_1 + 3(5 - 2x_1) = 10$, which gives $x_1 = 1$, then $x_2 = 3$. Apples cost 1 dollar and bananas cost 3 dollars.
-
Verify: $2(1) + 1(3) = 5$ and $1(1) + 3(3) = 10$, so it checks out.
-
If $A$ has an inverse, the solution is simply $\mathbf{x} = A^{-1}\mathbf{b}$. But computing the inverse explicitly costs more and is less numerically accurate than solving the system directly. In practice, we use decompositions instead.
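Here is the fruit system solved in JAX. It compares `jnp.linalg.solve`, which factorises $A$ and never forms the inverse, with the explicit-inverse route. On a tiny, well-conditioned system like this one both agree. The differences in cost and accuracy show up on large or ill-conditioned systems.

```python
import jax.numpy as jnp

# The fruit example: 2 apples + 1 banana = 5, 1 apple + 3 bananas = 10
A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([5.0, 10.0])

x_solve = jnp.linalg.solve(A, b)   # factorises A internally, no explicit inverse
x_inv = jnp.linalg.inv(A) @ b      # works here, but does extra work and is less accurate in general

print(f"solve: {x_solve}")   # apples 1, bananas 3
print(f"inv:   {x_inv}")
```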
-
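Not every matrix is square, and not every square matrix is invertible. The pseudo-inverse $A^+$ generalises the inverse to any matrix. It always exists (it can be computed from the SVD, covered in decompositions) and gives the "best possible" answer: $A^+\mathbf{b}$ is the least-squares solution minimising $\|A\mathbf{x} - \mathbf{b}\|$ (the one with the smallest norm if there are several). When $A$ has linearly independent columns (full column rank), it has the closed form:
$$A^+ = (A^TA)^{-1}A^T$$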
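This sketch compares the closed form with `jnp.linalg.pinv` (SVD-based) on a tall matrix of our own with independent columns. The pseudo-inverse is a left inverse here ($A^+A = I$). The product $A^+\mathbf{b}$ matches the least-squares solution from `jnp.linalg.lstsq`, which for this data is $\mathbf{x} = [2/3, 1/2]^T$.

```python
import jax.numpy as jnp

# Tall matrix (3 equations, 2 unknowns) with independent columns
A = jnp.array([[1.0, 1.0],
               [1.0, 2.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0, 2.0])

A_pinv_formula = jnp.linalg.inv(A.T @ A) @ A.T   # valid only for full column rank
A_pinv = jnp.linalg.pinv(A)                      # SVD-based, works for any matrix

print(jnp.allclose(A_pinv_formula, A_pinv, atol=1e-4))
print(A_pinv @ A)                  # 2 x 2 identity: a left inverse
x = A_pinv @ b                     # least-squares solution: minimises ||Ax - b||
print(f"x = {x}, residual = {jnp.linalg.norm(A @ x - b):.4f}")
```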
-
When $A$ is lower triangular, solving $L\mathbf{x} = \mathbf{b}$ is easy by forward substitution: solve for $x_1$ first, then use it to find $x_2$, and so on down.
-
When $A$ is upper triangular, solving $U\mathbf{x} = \mathbf{b}$ works by back substitution: solve for the last variable first, then work upward.
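Both substitutions fit in short JAX functions (names ours), checked against `jax.scipy.linalg.solve_triangular`. The example uses one lower-triangular matrix of our own choosing and reuses its transpose as the upper-triangular case. Each row costs at most $O(n)$ work, so a whole triangular solve is $O(n^2)$, far cheaper than a general $O(n^3)$ solve.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def forward_substitution(L, b):
    """Solve Lx = b for lower-triangular L, top row first."""
    n = len(b)
    x = jnp.zeros(n)
    for i in range(n):
        # everything left of the diagonal is already known
        x = x.at[i].set((b[i] - L[i, :i] @ x[:i]) / L[i, i])
    return x

def back_substitution(U, b):
    """Solve Ux = b for upper-triangular U, bottom row first."""
    n = len(b)
    x = jnp.zeros(n)
    for i in range(n - 1, -1, -1):
        # everything right of the diagonal is already known
        x = x.at[i].set((b[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i])
    return x

L = jnp.array([[2.0, 0.0, 0.0],
               [1.0, 3.0, 0.0],
               [4.0, -1.0, 5.0]])
b = jnp.array([2.0, 7.0, 13.0])

x_lower = forward_substitution(L, b)
x_upper = back_substitution(L.T, b)
print(x_lower, solve_triangular(L, b, lower=True))
print(x_upper, solve_triangular(L.T, b, lower=False))
```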
-
This is why decomposing a matrix into triangular factors (as we will see in decompositions) is so useful. It turns a hard problem into two easy ones.
Coding Tasks (use Colab or notebook)
- Multiply two matrices and verify the dimensions. Then swap the order and observe that the result changes (or that it fails if dimensions don't match).
import jax.numpy as jnp
A = jnp.array([[1.0, 2.0],
[3.0, 4.0]])
B = jnp.array([[5.0, 6.0],
[7.0, 8.0]])
print(f"A @ B:\n{A @ B}")
print(f"B @ A:\n{B @ A}")
print(f"Equal: {jnp.allclose(A @ B, B @ A)}")
- Solve a system of linear equations $A\mathbf{x} = \mathbf{b}$ and verify the solution by multiplying back. Try changing $\mathbf{b}$ to see how the solution shifts.
import jax.numpy as jnp
A = jnp.array([[2.0, 1.0],
[5.0, 3.0]])
b = jnp.array([4.0, 7.0])
x = jnp.linalg.solve(A, b)
print(f"Solution x: {x}")
print(f"A @ x: {A @ x}")
Linear Transformations
-
A linear transformation (or linear map) is a function that takes a vector and produces another vector, while preserving addition and scaling. If $T$ is linear, then:
- $T(\mathbf{u} + \mathbf{v}) = T(\mathbf{u}) + T(\mathbf{v})$
- $T(c\mathbf{u}) = cT(\mathbf{u})$
-
Every linear transformation can be represented as multiplication by a matrix. The matrix is the transformation. When you multiply a vector by a matrix, you are applying a linear transformation to it.
-
Think of a $2 \times 2$ matrix as a machine that takes in 2D vectors and outputs new 2D vectors. The columns of the matrix tell you where the standard basis vectors $\hat{\mathbf{i}}$ and $\hat{\mathbf{j}}$ end up after the transformation. Everything else follows from linearity.
- For example, if $$A = \begin{bmatrix} 2 & 1 \\ 1 & 2 \end{bmatrix}$$
then $\hat{\mathbf{i}} = [1, 0]^T$ lands at $[2, 1]^T$ (column 1) and $\hat{\mathbf{j}} = [0, 1]^T$ lands at $[1, 2]^T$ (column 2). Every other vector is a combination of these two, so its output follows automatically.
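This can be checked in JAX with the matrix whose columns are $[2, 1]^T$ and $[1, 2]^T$. Multiplying it by $\hat{\mathbf{i}}$ and $\hat{\mathbf{j}}$ returns its columns. The image of any other vector (here $[3, -1]^T$, our choice) is the same combination of those columns.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])
i_hat = jnp.array([1.0, 0.0])
j_hat = jnp.array([0.0, 1.0])

print(A @ i_hat)   # column 1: [2. 1.]
print(A @ j_hat)   # column 2: [1. 2.]

# v = 3*i_hat - 1*j_hat, so its image is the same combination of the columns
v = jnp.array([3.0, -1.0])
print(A @ v, 3 * (A @ i_hat) - 1 * (A @ j_hat))   # both [5. 1.]
```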
-
Multiplying two matrices can be thought of as applying one transformation after another. If $B$ is applied to a vector first and $A$ is then applied to the result, the product $AB$ does both in sequence: $(AB)\mathbf{x} = A(B\mathbf{x})$. In a game engine, rotating a character and then moving them forward gives a different result from moving them first and then rotating, which is why matrix multiplication is not commutative.
-
Rotation turns vectors by an angle $\theta$ without changing their length. The vector stays the same size; it just points in a new direction.
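- In 2D, the rotation matrix is:
$$R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$
- For $\theta = 90°$:
$$R(90°) = \begin{bmatrix} 0 & -1 \\ 1 & 0 \end{bmatrix}$$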
so $[1, 0]^T$ becomes $[0, 1]^T$. The vector pointing right now points up. Rotation matrices are orthogonal and always have determinant 1. When you rotate a photo on your phone, this is the exact matrix being applied to every pixel coordinate.
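- In 3D, there are separate rotation matrices for each axis. A robotic arm rotates each joint around a specific axis, and each joint is one rotation matrix. Rotation around the z-axis looks like the 2D case embedded in 3D:
$$R_z(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta & 0 \\ \sin\theta & \cos\theta & 0 \\ 0 & 0 & 1 \end{bmatrix}$$
- Scaling stretches or shrinks vectors along each axis independently:
$$S(s_x, s_y) = \begin{bmatrix} s_x & 0 \\ 0 & s_y \end{bmatrix}$$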
-
$S(2, 1.5)$ doubles the x-component and multiplies the y-component by 1.5. Scaling by $-1$ along an axis flips that component. A diagonal matrix is always a scaling transformation. When you resize an image to 50%, you are applying $S(0.5, 0.5)$ to every pixel coordinate.
-
Reflection flips vectors across an axis or line, like a mirror. Reflecting across the x-axis keeps the x-component and negates the y-component:
$$\begin{bmatrix} 1 & 0 \\ 0 & -1 \end{bmatrix}$$
- For example, $[3, 2]^T$ becomes $[3, -2]^T$. When your phone flips a selfie horizontally so text reads correctly, it is applying a reflection matrix (across the vertical axis, negating the x-component). Reflecting across the line $y = x$ swaps the two components:
$$\begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix}$$
-
Reflection matrices have determinant $-1$, confirming they flip orientation.
-
Rotations and reflections are both rigid transformations: they preserve distances and angles. The matrices that represent them are orthogonal matrices, which is why orthogonal matrices always have determinant $+1$ (rotation) or $-1$ (reflection).
-
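Shearing skews vectors along one axis proportionally to the other. A horizontal shear by factor $k$:
$$\begin{bmatrix} 1 & k \\ 0 & 1 \end{bmatrix}\begin{bmatrix} x \\ y \end{bmatrix} = \begin{bmatrix} x + ky \\ y \end{bmatrix}$$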
-
Each point slides horizontally by $k$ times its height. With $k = 0.5$, a point at height 2 shifts right by 1. The bottom row stays put while the top row slides. This is how slanted (oblique) text can be generated: upright letters are sheared so they slant to the right.
-
All of the above (rotation, scaling, reflection, shearing) are linear transformations. They keep the origin fixed and preserve straight lines. But what about translation (shifting everything by a fixed amount)?
-
Translation is not a linear transformation because it moves the origin. If you shift every point right by 3, the zero vector moves to $[3, 0]^T$, breaking linearity. To handle it, we use an affine transformation, which combines a linear transformation with a translation:
$$\mathbf{y} = A\mathbf{x} + \mathbf{t}$$
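- To represent this as a single matrix multiplication, we use homogeneous coordinates: add an extra 1 to every vector and use an $(n+1) \times (n+1)$ matrix:
$$\begin{bmatrix} \mathbf{y} \\ 1 \end{bmatrix} = \begin{bmatrix} A & \mathbf{t} \\ \mathbf{0}^T & 1 \end{bmatrix}\begin{bmatrix} \mathbf{x} \\ 1 \end{bmatrix}$$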
-
Affine transformations preserve straight lines and parallelism, but not necessarily angles or lengths. Every object in a video game is positioned using affine transformations: rotate it, scale it, then place it at the right location, all encoded in a single matrix.
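Here is a sketch of homogeneous coordinates in 2D, using a 90-degree rotation followed by a shift of 3 to the right (both our choices). Packing $A$ and $\mathbf{t}$ into one $3 \times 3$ matrix reproduces $A\mathbf{x} + \mathbf{t}$. The origin $[0, 0, 1]^T$ visibly moves, confirming the map is not linear.

```python
import jax.numpy as jnp

theta = jnp.pi / 2
A = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
               [jnp.sin(theta),  jnp.cos(theta)]])   # rotate 90 degrees
t = jnp.array([3.0, 0.0])                             # then shift right by 3

# Pack A and t into one 3 x 3 matrix acting on homogeneous coordinates [x, y, 1]
M = jnp.eye(3)
M = M.at[:2, :2].set(A)
M = M.at[:2, 2].set(t)

x = jnp.array([1.0, 0.0])
x_h = jnp.append(x, 1.0)          # [1, 0, 1]

print(A @ x + t)                  # affine map directly: approximately [3, 1]
print((M @ x_h)[:2])              # same point via one matrix multiply
print(M @ jnp.array([0.0, 0.0, 1.0]))   # the origin moves to [3, 0, 1]
```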
-
A degenerate transformation (singular matrix) collapses space into a lower dimension.
-
For example, the matrix
maps every 2D vector onto a single line, because both columns point in the same direction. The determinant is zero, information is lost, and the transformation cannot be undone.
-
Converting a colour image (3 values per pixel: red, green, blue) to grayscale (1 value per pixel) is a degenerate transformation: the colour information is permanently gone.
-
In ML, linear transformations are the core of neural networks. Data is represented as a matrix: a stack of vectors, each holding the features of one object (a person, a plane, a piece of text, an image... anything!).
-
Each layer applies a matrix multiplication (a linear transformation), usually followed by a simple nonlinear function. The details, including how to structure the data and why neural networks work, are covered in later chapters.
-
The most widely used architecture today, the Transformer, spends most of its computation passing data through exactly these linear transformations, interleaved with attention, normalisation, and nonlinear activations.
-
Gemini, ChatGPT, Claude, Qwen, DeepSeek, and most of the best-performing AI systems in the world today are Transformers!
Coding Tasks (use Colab or notebook)
- Apply a rotation matrix to a vector and plot both the original and rotated vector. Try different angles.
import jax.numpy as jnp
import matplotlib.pyplot as plt
theta = jnp.pi / 3
R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
[jnp.sin(theta), jnp.cos(theta)]])
v = jnp.array([1.0, 0.0])
v_rot = R @ v
plt.figure(figsize=(5, 5))
plt.quiver(0, 0, v[0], v[1], angles='xy', scale_units='xy', scale=1, color='red', label='original')
plt.quiver(0, 0, v_rot[0], v_rot[1], angles='xy', scale_units='xy', scale=1, color='blue', label='rotated')
plt.xlim(-1.5, 1.5); plt.ylim(-1.5, 1.5)
plt.grid(True); plt.legend(); plt.gca().set_aspect('equal')
plt.show()
- Apply a shearing transformation to a set of points forming a square and visualise the deformed shape.
import jax.numpy as jnp
import matplotlib.pyplot as plt
square = jnp.array([[0,0],[1,0],[1,1],[0,1],[0,0]]).T
k = 0.5
shear = jnp.array([[1, k],
[0, 1]])
sheared = shear @ square
plt.figure(figsize=(6, 4))
plt.plot(square[0], square[1], 'r-o', label='original')
plt.plot(sheared[0], sheared[1], 'b-o', label='sheared')
plt.grid(True); plt.legend(); plt.gca().set_aspect('equal')
plt.show()
Matrix Decompositions
-
A matrix decomposition (or factorisation) breaks a matrix into simpler pieces that are easier to work with. Think of it like factoring a number: $12 = 3 \times 4$ is easier to reason about than 12 alone.
-
We decompose matrices to solve systems of equations faster, compute inverses stably, find eigenvalues, compress data, and understand the geometry of transformations.
-
The most fundamental technique is Gaussian elimination (row reduction). The idea is simple: given a system $A\mathbf{x} = \mathbf{b}$, use three allowed operations to simplify $A$ until the answer is obvious.
-
The operations are: swap two rows, multiply a row by a nonzero scalar, or add a multiple of one row to another.
-
For example, to eliminate the first column below the pivot, subtract multiples of row 1 from the rows below:
- The goal is row echelon form (REF): zeros below each pivot (the first nonzero entry in each row), with each pivot to the right of the one above it. The matrix becomes a staircase shape.
-
Going further to reduced row echelon form (RREF) makes every pivot equal to 1 and the only nonzero entry in its column. Every matrix has a unique RREF.
-
Once in triangular form, we solve by back substitution: the bottom row gives the last variable directly, then work upward.
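Here is a minimal sketch of the whole pipeline. It row-reduces the augmented matrix $[A \mid \mathbf{b}]$ to row echelon form using row swaps and "add a multiple of one row to another", then back substitutes. It assumes $A$ is square and non-singular. Swapping in the row with the largest pivot candidate (partial pivoting) is our addition for numerical safety. The helper name `gaussian_solve` and the example system are ours.

```python
import jax.numpy as jnp

def gaussian_solve(A, b):
    """Solve Ax = b: row-reduce [A | b] to upper-triangular form, then back substitute."""
    A = jnp.asarray(A, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    n = A.shape[0]
    M = jnp.concatenate([A, b[:, None]], axis=1)   # augmented matrix [A | b]

    for col in range(n):
        # Row swap (partial pivoting): bring the largest entry in this column up to the pivot
        pivot = col + int(jnp.argmax(jnp.abs(M[col:, col])))
        rows = jnp.array([col, pivot])
        M = M.at[rows].set(M[rows[::-1]])
        # Subtract a multiple of the pivot row from each row below, zeroing the column under the pivot
        for row in range(col + 1, n):
            factor = M[row, col] / M[col, col]
            M = M.at[row].set(M[row] - factor * M[col])

    # M is now in row echelon form: back substitution, bottom row first
    x = jnp.zeros(n)
    for i in range(n - 1, -1, -1):
        x = x.at[i].set((M[i, n] - M[i, i + 1:n] @ x[i + 1:]) / M[i, i])
    return x

A3 = jnp.array([[0.0, 2.0, 1.0],
                [1.0, -1.0, 2.0],
                [2.0, 1.0, 1.0]])   # zero in the top-left corner forces a row swap
b3 = jnp.array([3.0, 2.0, 4.0])
print(gaussian_solve(A3, b3))
print(jnp.linalg.solve(A3, b3))
```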
-
This is the foundation that many other decompositions build on. LU and Cholesky, for example, reduce a matrix to triangular factors so that we can solve for the variables by forward and back substitution.
-
LU decomposition formalises Gaussian elimination by factoring a square matrix into $A = LU$ (or $A = PLU$ with row swaps), where $L$ is lower triangular and $U$ is upper triangular.
-
To solve $A\mathbf{x} = \mathbf{b}$: first solve $L\mathbf{y} = \mathbf{b}$ by forward substitution (top to bottom), then solve $U\mathbf{x} = \mathbf{y}$ by back substitution (bottom to top). Two easy triangular solves instead of one hard general solve.
-
The advantage over raw Gaussian elimination is reuse. Once you have $L$ and $U$, you can solve for many different $\mathbf{b}$ vectors without redoing the factorisation.
-
If you need to solve the same system with 1000 different right-hand sides (common in simulations), you factorise once and reuse.
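The sketch below shows factorise-once, solve-many with JAX's SciPy-style routines. `jax.scipy.linalg.lu` returns explicit factors with $A = PLU$, while `lu_factor`/`lu_solve` keep a compact form meant for reuse. The matrix and the 1000 right-hand sides (stacked as columns) are made up.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu, lu_factor, lu_solve

A = jnp.array([[4.0, 3.0, 2.0],
               [2.0, 1.0, 3.0],
               [3.0, 2.0, 1.0]])

# Explicit factors: A = P @ L @ U (P records the row swaps)
P, L, U = lu(A)
print(jnp.allclose(P @ L @ U, A, atol=1e-5))

# Factorise once...
lu_piv = lu_factor(A)
# ...then reuse the factors for many right-hand sides (1000 at once, one per column)
B = jnp.arange(3 * 1000, dtype=jnp.float32).reshape(3, 1000) % 7
X = lu_solve(lu_piv, B)
print(jnp.allclose(A @ X, B, atol=1e-3))
```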
-
When a matrix is symmetric and positive definite (like a covariance matrix), we can do even better.
-
Cholesky decomposition factors it as $A = LL^T$, where $L$ is lower triangular. For example:
-
This is roughly twice as fast as LU and is guaranteed to be numerically stable. Think of it as a "square root" of the matrix.
-
If the decomposition fails (a negative value under a square root), the matrix is not positive definite. Cholesky thus doubles as a test for positive definiteness.
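Here is a sketch using `jnp.linalg.cholesky` on the positive definite matrix $\begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}$ and on an indefinite one of our own (eigenvalues 3 and $-1$). JAX does not raise an error when the factorisation fails; it returns NaN entries instead. The positive-definiteness test here therefore checks for NaNs.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])       # symmetric positive definite
L = jnp.linalg.cholesky(A)
print(L)
print(jnp.allclose(L @ L.T, A, atol=1e-5))

# Not positive definite (eigenvalues 3 and -1): JAX returns NaNs instead of raising
B = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])
L_bad = jnp.linalg.cholesky(B)
print(L_bad)
print(f"positive definite? {not bool(jnp.isnan(L_bad).any())}")
```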
-
The eigenvectors of a square matrix $A$ are the special directions that the transformation only stretches or shrinks, without rotating. The eigenvalue is the scaling factor:
$$A\mathbf{x} = \lambda\mathbf{x}$$
-
Most vectors change direction when multiplied by a matrix. But eigenvectors are special: the output points in the same direction as the input, just scaled by $\lambda$. If $\lambda = 2$, the eigenvector doubles in length. If $\lambda = -1$, it flips direction. If $\lambda = 0$, it gets squashed to zero.
-
For example, with:
the vector $[1, 0]^T$ is an eigenvector with $\lambda = 3$ because $A[1, 0]^T = [3, 0]^T = 3[1, 0]^T$.
-
To find eigenvalues, solve the characteristic polynomial $\det(A - \lambda I) = 0$. The roots are the eigenvalues. Then substitute each $\lambda$ back into $(A - \lambda I)\mathbf{x} = \mathbf{0}$ to find the corresponding eigenvectors.
-
Key properties:
- The trace of $A$ equals the sum of its eigenvalues.
- The determinant of $A$ equals the product of its eigenvalues.
- Symmetric matrices have perpendicular eigenvectors and real eigenvalues.
- Positive definite matrices have all positive eigenvalues.
- Covariance matrices (which we will encounter in statistics) are always positive semi-definite.
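These properties can be checked on the symmetric matrix from the coding tasks, $\begin{bmatrix} 4 & 2 \\ 2 & 3 \end{bmatrix}$. For a $2 \times 2$ matrix, the characteristic polynomial is $\lambda^2 - \text{tr}(A)\lambda + \det(A)$, here $\lambda^2 - 7\lambda + 8$. Its roots $(7 \pm \sqrt{17})/2$ can be compared directly against `jnp.linalg.eigh`.

```python
import jax.numpy as jnp

A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])

# For a 2x2 matrix, det(A - lambda I) = lambda^2 - tr(A) lambda + det(A)
tr, det = jnp.trace(A), jnp.linalg.det(A)
disc = jnp.sqrt(tr ** 2 - 4 * det)
roots = jnp.array([(tr - disc) / 2, (tr + disc) / 2])   # ascending

eigvals, eigvecs = jnp.linalg.eigh(A)   # for symmetric matrices, ascending order
print(f"characteristic-polynomial roots: {roots}")
print(f"eigh eigenvalues:               {eigvals}")

# Key properties
print(f"trace {tr:.4f} vs sum of eigenvalues {eigvals.sum():.4f}")
print(f"det   {det:.4f} vs product of eigenvalues {eigvals.prod():.4f}")
print(f"eigenvectors perpendicular: {jnp.abs(eigvecs[:, 0] @ eigvecs[:, 1]) < 1e-6}")
# Check A v = lambda v for each pair (columns of eigvecs scaled by their eigenvalues)
print(jnp.allclose(A @ eigvecs, eigvecs * eigvals, atol=1e-5))
```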
-
Computing eigenvalues via the characteristic polynomial is impractical for large matrices. Instead, iterative methods are used:
-
Power iteration: repeatedly multiply by $A$ and normalise. Converges to the dominant eigenvector (the one whose eigenvalue is largest in absolute value). Simple but only finds one eigenpair.
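Power iteration takes only a few lines (the function name and iteration count are our choices). The eigenvalue is estimated with the Rayleigh quotient $\mathbf{x}^T A \mathbf{x}$ of the final unit vector. Convergence speed depends on the ratio of the two largest eigenvalue magnitudes. For this matrix that ratio is about $1.44 / 5.56 \approx 0.26$, so 100 iterations is far more than enough.

```python
import jax
import jax.numpy as jnp

def power_iteration(A, num_iters=100, seed=0):
    """Estimate the dominant eigenpair (largest |eigenvalue|) of A."""
    x = jax.random.normal(jax.random.PRNGKey(seed), (A.shape[0],))  # random start
    for _ in range(num_iters):
        x = A @ x
        x = x / jnp.linalg.norm(x)        # normalise so the vector doesn't blow up
    eigenvalue = x @ A @ x                # Rayleigh quotient (x has unit length)
    return eigenvalue, x

A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])
lam, v = power_iteration(A)
print(f"power iteration: lambda = {lam:.4f}, v = {v}")
print(f"eigh:            lambda = {jnp.linalg.eigvalsh(A)[-1]:.4f}")
```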
-
QR algorithm: the workhorse method. Repeatedly decompose and recombine using QR factorisation until the matrix converges to triangular form, revealing all eigenvalues on the diagonal.
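Here is a bare-bones, unshifted version of the QR algorithm on a symmetric tridiagonal matrix of our own. Each step $A_{k+1} = R_kQ_k = Q_k^TA_kQ_k$ is a similarity transform. The eigenvalues therefore never change while the matrix drifts toward diagonal form. This matrix has eigenvalues $3 - \sqrt{3}$, $3$, and $3 + \sqrt{3}$. Production implementations add shifts and a Hessenberg reduction first, which this sketch omits.

```python
import jax.numpy as jnp

def qr_algorithm(A, num_iters=200):
    """Unshifted QR algorithm: factor A_k = Q R, then set A_{k+1} = R Q (a similarity transform)."""
    Ak = A
    for _ in range(num_iters):
        Q, R = jnp.linalg.qr(Ak)
        Ak = R @ Q
    return Ak

A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
Ak = qr_algorithm(A)
print(jnp.round(Ak, 4))              # (nearly) diagonal for this symmetric matrix
print(jnp.sort(jnp.diag(Ak)))        # eigenvalues read off the diagonal
print(jnp.linalg.eigvalsh(A))        # reference
```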
-
Inverse iteration: finds the eigenvector whose eigenvalue is closest to a given target value. Useful when you know roughly which eigenvalue you want.
-
For large sparse matrices, Arnoldi and Lanczos iterations exploit the sparsity for efficiency.
-
-
If a square matrix has a full set of linearly independent eigenvectors, it can be diagonalised: $A = PDP^{-1}$, where $D$ is a diagonal matrix of eigenvalues and the columns of $P$ are the eigenvectors.
-
Why is this useful? Diagonal matrices are trivial to work with. Need $A^{100}$? Instead of multiplying $A$ by itself 100 times, compute $PD^{100}P^{-1}$, and raising a diagonal matrix to a power just raises each entry independently. This turns an expensive operation into a cheap one.
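This sketch uses the symmetric matrix $\begin{bmatrix} 4 & 2 \\ 2 & 3 \end{bmatrix}$, where `eigh` returns an orthogonal $P$, so $P^{-1} = P^T$. For a general diagonalisable matrix you would use $P^{-1}$ instead. We use a modest power, $k = 5$, to stay well within float32 range, and compare against `jnp.linalg.matrix_power`.

```python
import jax.numpy as jnp

A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])

# Symmetric, so eigh gives an orthogonal P and P^{-1} = P^T
eigvals, P = jnp.linalg.eigh(A)
D = jnp.diag(eigvals)
print(jnp.allclose(P @ D @ P.T, A, atol=1e-5))     # A = P D P^{-1}

k = 5
A_k_diag = P @ jnp.diag(eigvals ** k) @ P.T        # only the diagonal entries get raised to k
A_k_direct = jnp.linalg.matrix_power(A, k)         # repeated multiplication
print(A_k_diag)
print(A_k_direct)
```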
-
An eigenbasis is a basis made entirely of eigenvectors. In this basis, the matrix becomes diagonal and the transformation is just independent scaling along each eigenvector direction. This is like finding the natural coordinate system for the transformation.
-
QR decomposition factors any matrix $A$ into $A = QR$, where $Q$ is orthogonal (its columns are orthonormal) and $R$ is upper triangular. Think of it as separating the "direction" information ($Q$) from the "scaling and mixing" information ($R$).
-
The Gram-Schmidt process builds $Q$ column by column. Take the first column of $A$ and normalise it. Take the second column, subtract its projection onto the first (to make it perpendicular), and normalise. Repeat for each column. The result is an orthonormal set of vectors.
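Here is classical Gram-Schmidt as a short JAX sketch (function name ours). It also records the projection coefficients in $R$, so that $A = QR$. The example matrix is arbitrary. Classical Gram-Schmidt can lose orthogonality on nearly dependent columns. This is one reason LAPACK-style routines, like those behind `jnp.linalg.qr`, use Householder reflections instead.

```python
import jax.numpy as jnp

def gram_schmidt_qr(A):
    """Classical Gram-Schmidt: orthonormalise A's columns into Q and record the coefficients in R."""
    m, n = A.shape
    Q = jnp.zeros((m, n))
    R = jnp.zeros((n, n))
    for j in range(n):
        v = A[:, j]
        for i in range(j):
            R = R.at[i, j].set(Q[:, i] @ A[:, j])   # size of the projection onto q_i
            v = v - R[i, j] * Q[:, i]               # remove it, leaving the perpendicular part
        R = R.at[j, j].set(jnp.linalg.norm(v))
        Q = Q.at[:, j].set(v / R[j, j])             # normalise
    return Q, R

A = jnp.array([[1.0, 1.0, 0.0],
               [1.0, 0.0, 1.0],
               [0.0, 1.0, 1.0]])
Q, R = gram_schmidt_qr(A)
print(jnp.round(Q.T @ Q, 5))   # identity: columns are orthonormal
print(jnp.round(R, 4))         # upper triangular
print(jnp.allclose(Q @ R, A, atol=1e-5))
```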
-
QR decomposition is the engine behind the QR algorithm for eigenvalues. It is also used directly for solving least-squares problems: when $A\mathbf{x} = \mathbf{b}$ has no exact solution (more equations than unknowns), QR finds the best approximate answer.
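Here is a least-squares sketch that fits a line $y = c_0 + c_1 t$ through four points of our own. With the reduced QR factorisation $A = QR$, minimising $\|A\mathbf{c} - \mathbf{y}\|$ reduces to the triangular system $R\mathbf{c} = Q^T\mathbf{y}$, solved by back substitution. The result is checked against `jnp.linalg.lstsq`. For this data it is $c_0 = 0.97$, $c_1 = 2.02$.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Fit a line y = c0 + c1 * t through 4 points (more equations than unknowns)
t = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.9, 5.1, 7.0])
A = jnp.stack([jnp.ones_like(t), t], axis=1)          # 4 x 2

Q, R = jnp.linalg.qr(A)                               # reduced QR: Q is 4 x 2, R is 2 x 2
coeffs = solve_triangular(R, Q.T @ y, lower=False)    # back substitution on R c = Q^T y

print(f"QR:    {coeffs}")
print(f"lstsq: {jnp.linalg.lstsq(A, y)[0]}")
```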
-
SVD (Singular Value Decomposition) is the most general and arguably the most important decomposition. Every matrix (any shape, any rank) has an SVD: $A = U\Sigma V^T$
- $V^T$ ($n \times n$, orthogonal): rotates the input
- $\Sigma$ ($m \times n$, diagonal): scales along orthogonal axes (the singular values, non-negative, in descending order)
- $U$ ($m \times m$, orthogonal): rotates the output
-
Geometrically, SVD says that every linear transformation, no matter how complicated, is just a rotation, followed by a stretch along the axes, followed by another rotation. A circle becomes an ellipse.
-
The singular values ($\sigma_1 \geq \sigma_2 \geq \ldots$) reveal the "importance" of each direction. Large singular values correspond to directions that matter most. The rank of $A$ equals the number of nonzero singular values.
-
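Low-rank approximation: by keeping only the $k$ largest singular values and zeroing the rest, you get the best possible rank-$k$ approximation of $A$ (measured in the Frobenius or spectral norm). This is how SVD-based image compression works: a $1000 \times 1000$ image might need only $k = 50$ singular values to look nearly identical. Storing the truncated factors takes $k(m + n + 1) = 50 \times 2001 = 100{,}050$ numbers instead of $10^6$, roughly a 10x compression.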
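The sketch below applies truncation to a synthetic $100 \times 100$ "image": a smooth pattern plus a little noise, all our own. The helper `low_rank` is ours. In the Frobenius norm, the error of the best rank-$k$ approximation is exactly $\sqrt{\sigma_{k+1}^2 + \sigma_{k+2}^2 + \cdots}$. The printout compares that prediction against the measured error, alongside the storage count $k(m + n + 1)$.

```python
import jax
import jax.numpy as jnp

def low_rank(A, k):
    """Best rank-k approximation: keep the k largest singular values, zero the rest."""
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vt[:k]

# Stand-in "image": a smooth pattern plus a little noise (synthetic data)
x = jnp.linspace(0, 1, 100)
image = jnp.outer(jnp.sin(3 * x), jnp.cos(2 * x)) + 0.5 * jnp.outer(x, x ** 2)
image = image + 0.01 * jax.random.normal(jax.random.PRNGKey(0), image.shape)

s = jnp.linalg.svd(image, compute_uv=False)
for k in [1, 2, 5]:
    approx = low_rank(image, k)
    err = jnp.linalg.norm(image - approx, 'fro')
    predicted = jnp.sqrt(jnp.sum(s[k:] ** 2))       # from the discarded singular values
    storage = k * (image.shape[0] + image.shape[1] + 1)
    print(f"k={k}: error {err:.4f} (predicted {predicted:.4f}), stores {storage} numbers vs {image.size}")
```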
-
SVD also provides the pseudo-inverse: $A^+ = V\Sigma^+U^T$, where $\Sigma^+$ inverts the nonzero singular values.
-
While eigendecomposition only works for square matrices, SVD works for any matrix. This is its key advantage.
-
PCA (Principal Component Analysis) uses eigendecomposition (or SVD) for dimensionality reduction.
-
Imagine a dataset with 100 features per sample: each sample is a 100-dimensional vector, and the samples are stacked into a matrix. Many of those features are correlated and redundant.
-
PCA finds the directions along which the data actually varies, letting you keep only what matters.
-
The first principal component (PC1) is the direction of greatest variance.
-
The second (PC2) captures the most variance of what remains, and is perpendicular to the first.
-
If most of the variance lives along just a few directions, you can project the data down to those dimensions and discard the rest with minimal loss.
-
The steps:
- Standardise the data (subtract mean, divide by standard deviation) so all features contribute equally
- Compute the covariance matrix
- Find its eigenvalues and eigenvectors
- Select the $k$ eigenvectors with the largest eigenvalues (these are the principal components)
- Project the data onto these components
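The five steps above can be sketched in JAX as follows. The dataset is made up for illustration (three features, one of which nearly duplicates another), and $k = 2$ is an arbitrary choice. The last lines check that the SVD of the standardised data yields the same components, up to sign.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 500 samples, 3 features; feature 2 is feature 0 plus a little noise
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
base = jax.random.normal(key0, (500, 2))
noise = 0.05 * jax.random.normal(key1, (500,))
X = jnp.column_stack([base[:, 0], base[:, 1], base[:, 0] + noise])

# 1. Standardise each feature (column)
Z = (X - X.mean(axis=0)) / X.std(axis=0)
# 2. Covariance matrix of the features (3 x 3)
C = jnp.cov(Z, rowvar=False)
# 3. Eigenvalues/eigenvectors; eigh returns eigenvalues in ascending order, so reverse
eigvals, eigvecs = jnp.linalg.eigh(C)
order = jnp.argsort(eigvals)[::-1]
eigvals, eigvecs = eigvals[order], eigvecs[:, order]
# 4. Keep the k eigenvectors with the largest eigenvalues
k = 2
components = eigvecs[:, :k]
# 5. Project the data onto them
Z_reduced = Z @ components

explained = eigvals[:k].sum() / eigvals.sum()
print(Z_reduced.shape, f"fraction of variance kept: {explained:.4f}")

# Same components from the SVD of the standardised data (signs may differ)
_, S, Vt = jnp.linalg.svd(Z, full_matrices=False)
same = jnp.allclose(jnp.abs(Vt[:k].T), jnp.abs(components), atol=1e-3)
print("SVD agrees:", bool(same))
```

Because feature 2 almost duplicates feature 0, two components should keep nearly all of the variance.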
-
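Standardisation is critical: without it, features with large numeric ranges dominate purely because of their units. The same length recorded in centimetres has $10^{10}$ times the variance it would have in kilometres, regardless of its actual importance.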
-
In practice, PCA is used for visualisation (projecting high-dimensional data to 2D or 3D), noise reduction (discarding low-variance directions that are mostly noise), and speeding up ML models by reducing the number of input features.
-
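Kernel PCA extends PCA to nonlinear relationships. Conceptually, it maps the data into a higher-dimensional feature space where the structure may become linear, and applies standard PCA there. The kernel trick means this mapping is never computed explicitly: a kernel function $k(\mathbf{x}_i, \mathbf{x}_j)$ gives the inner products in feature space directly, and PCA is carried out on the resulting $n \times n$ kernel matrix of the $n$ samples.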
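A minimal sketch of kernel PCA on the training points, assuming an RBF kernel with a hypothetical bandwidth of 1 and made-up data. The kernel matrix is centred in feature space with $K_c = K - \mathbf{1}K - K\mathbf{1} + \mathbf{1}K\mathbf{1}$ (where $\mathbf{1}$ is the $n \times n$ matrix of $1/n$), and each sample's coordinate along component $i$ is $\sqrt{\lambda_i}$ times the $i$-th eigenvector of $K_c$. As a sanity check, a linear kernel should reproduce ordinary PCA scores up to sign.

```python
import jax
import jax.numpy as jnp

def kernel_pca(X, kernel, k):
    """Coordinates of the n rows of X along the top-k kernel principal components."""
    n = X.shape[0]
    # Kernel matrix of pairwise similarities: K[i, j] = kernel(x_i, x_j)
    K = jax.vmap(lambda a: jax.vmap(lambda b: kernel(a, b))(X))(X)
    # Centre in feature space: Kc = K - 1K - K1 + 1K1, where 1 is the n x n matrix of 1/n
    one = jnp.full((n, n), 1.0 / n)
    Kc = K - one @ K - K @ one + one @ K @ one
    # eigh sorts ascending; reverse to take the k largest
    eigvals, eigvecs = jnp.linalg.eigh(Kc)
    eigvals, eigvecs = eigvals[::-1][:k], eigvecs[:, ::-1][:, :k]
    # Projection of each sample onto component i is sqrt(lambda_i) * (i-th eigenvector)
    return eigvecs * jnp.sqrt(jnp.maximum(eigvals, 0.0))

rbf = lambda a, b: jnp.exp(-jnp.sum((a - b) ** 2) / 2.0)  # hypothetical bandwidth of 1
linear = lambda a, b: jnp.dot(a, b)

# Hypothetical data with clearly different spreads along the three axes
X = jax.random.normal(jax.random.PRNGKey(0), (100, 3)) * jnp.array([3.0, 1.0, 0.3])
Z_rbf = kernel_pca(X, rbf, 2)

# Sanity check: with a linear kernel, kernel PCA should reproduce ordinary PCA (up to sign)
Z_lin = kernel_pca(X, linear, 2)
Xc = X - X.mean(axis=0)
_, S, Vt = jnp.linalg.svd(Xc, full_matrices=False)
pca_scores = Xc @ Vt[:2].T
print(Z_rbf.shape, bool(jnp.allclose(jnp.abs(Z_lin), jnp.abs(pca_scores), atol=1e-2)))
```

Projecting new, unseen points would additionally need their kernel values against the training set; that step is left out here.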
-
Schur decomposition factors a square matrix as $A = QTQ^\ast$, where $Q$ is unitary and $T$ is upper triangular, with the eigenvalues of $A$ on its diagonal. Every square matrix has a Schur decomposition, even if it cannot be diagonalised.
-
Non-negative Matrix Factorisation (NMF) decomposes a matrix into two non-negative matrices: $A \approx WH$, where all entries in $W$ and $H$ are $\geq 0$. Unlike SVD, which can produce negative entries, NMF only adds, never subtracts. This makes the parts interpretable: in topic modelling, with $A$ a (documents $\times$ words) matrix, $W$ (documents $\times$ topics) gives the topic weights per document and $H$ (topics $\times$ words) gives the word weights per topic, all non-negative, matching how we think about "how much of each topic" a document contains.
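One classic way to compute NMF is the multiplicative update rules of Lee and Seung, which reduce the squared reconstruction error $\|A - WH\|_F^2$ and keep every entry non-negative, because each update multiplies by a ratio of non-negative quantities. The sketch below assumes a tiny made-up documents $\times$ words count matrix, two topics, a fixed number of iterations, and a small `eps` to avoid division by zero; library implementations add convergence checks and smarter initialisation.

```python
import jax
import jax.numpy as jnp

def nmf(A, k, steps=500, eps=1e-9, seed=0):
    """Factor non-negative A (m x n) as W (m x k) @ H (k x n) with multiplicative updates."""
    kw, kh = jax.random.split(jax.random.PRNGKey(seed))
    m, n = A.shape
    W = jax.random.uniform(kw, (m, k))  # non-negative random start
    H = jax.random.uniform(kh, (k, n))
    for _ in range(steps):
        # Each factor is multiplied by a ratio of non-negative terms, so entries stay >= 0
        H = H * (W.T @ A) / (W.T @ W @ H + eps)
        W = W * (A @ H.T) / (W @ H @ H.T + eps)
    return W, H

# Hypothetical document-word counts: rows are documents, columns are words
A = jnp.array([[3., 2., 0., 0.],
               [4., 3., 0., 1.],
               [0., 0., 5., 4.],
               [0., 1., 3., 3.]])
W, H = nmf(A, k=2)
rel_error = jnp.linalg.norm(A - W @ H) / jnp.linalg.norm(A)
print("all non-negative:", bool((W >= 0).all() and (H >= 0).all()))
print(f"relative error: {rel_error:.3f}")
```

Each row of `W` then says how much of each of the two "topics" a document contains, and each row of `H` says which words make up a topic.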
-
The spectral theorem states that symmetric (or Hermitian) matrices can always be diagonalised with an orthogonal (or unitary) matrix. Their eigenvalues are always real, and their eigenvectors can always be chosen to be orthonormal. Covariance matrices are symmetric, so this is the theoretical foundation behind PCA.
Coding Tasks (use Colab or a notebook)
- Compute the eigenvalues and eigenvectors of a symmetric matrix. Verify that the eigenvectors are perpendicular and reconstruct the matrix from its eigendecomposition.
import jax.numpy as jnp
A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])
eigenvalues, eigenvectors = jnp.linalg.eigh(A)
print(f"Eigenvalues: {eigenvalues}")
print(f"Eigenvectors orthogonal: {jnp.dot(eigenvectors[:,0], eigenvectors[:,1]):.6f}")
# Reconstruct: A = P D P^T
D = jnp.diag(eigenvalues)
A_reconstructed = eigenvectors @ D @ eigenvectors.T
print(f"Reconstruction matches: {jnp.allclose(A, A_reconstructed)}")
- Implement power iteration to find the largest eigenvalue (in absolute value), and inverse iteration to find the smallest. Compare with `jnp.linalg.eigh`. Then try implementing the QR algorithm yourself.
import jax.numpy as jnp
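A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])
# Power iteration: repeatedly apply A; v aligns with the eigenvector of the
# eigenvalue with the LARGEST absolute value
v = jnp.array([1.0, 0.0])
for _ in range(20):
    v = A @ v
    v = v / jnp.linalg.norm(v)
# Rayleigh quotient v^T A v (v has unit length) estimates that eigenvalue
print(f"Largest eigenvalue: {v @ A @ v:.4f}")
# Inverse iteration: apply A^{-1} (via a linear solve) instead of A; this finds
# the eigenvalue with the SMALLEST absolute value
v = jnp.array([1.0, 0.0])
for _ in range(20):
    v = jnp.linalg.solve(A, v)
    v = v / jnp.linalg.norm(v)
# v^T A^{-1} v estimates 1/lambda_min, so take the reciprocal
print(f"Smallest eigenvalue: {1.0 / (v @ jnp.linalg.solve(A, v)):.4f}")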
print(f"jnp.linalg.eigh: {jnp.linalg.eigh(A)[0]}")
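One possible answer to the last part, as a sketch: the unshifted QR algorithm factors $A_k = Q_k R_k$ and forms $A_{k+1} = R_k Q_k = Q_k^T A_k Q_k$. Each step is a similarity transform, so the eigenvalues never change, while for a symmetric matrix whose eigenvalues have distinct magnitudes the off-diagonal entries shrink towards zero. Practical implementations add shifts and first reduce the matrix to tridiagonal (or Hessenberg) form; the 50 iterations here are an arbitrary choice for this small example.

```python
import jax.numpy as jnp

A = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])

# Unshifted QR algorithm: factor A_k = Q R, then form A_{k+1} = R Q
Ak = A
for _ in range(50):
    Q, R = jnp.linalg.qr(Ak)
    Ak = R @ Q  # similar to Ak, so the eigenvalues are preserved
# The diagonal converges to the eigenvalues; sort to match eigh's ascending order
print(f"QR algorithm:    {jnp.sort(jnp.diag(Ak))}")
print(f"jnp.linalg.eigh: {jnp.linalg.eigh(A)[0]}")
```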
- Compute the SVD of a matrix, then reconstruct it using only the top-k singular values and observe how the approximation quality changes with k.
import jax.numpy as jnp
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0]])
U, S, Vt = jnp.linalg.svd(A)
# This matrix has rank 2, so the k=2 error should already be close to zero
for k in [1, 2, 3]:
    approx = U[:, :k] @ jnp.diag(S[:k]) @ Vt[:k, :]
    error = jnp.linalg.norm(A - approx)
    print(f"k={k}, reconstruction error: {error:.4f}")
Differential Calculus
-
In the previous chapters, we learned how to represent data as vectors and transform it with matrices. But many real-world phenomena are not static. A car accelerates, a stock price fluctuates, a neural network's loss changes as weights update. Calculus is the mathematics of change.
-
Calculus asks two questions: how fast is something changing right now? (differential calculus) and how much has it accumulated over time? (integral calculus). This section tackles the "how fast" question.
-
Imagine you are driving and glance at your speedometer. It reads 60 km/h. That number is not the average speed of your entire trip; it is your speed at this exact instant. Differential calculus gives us the tools to compute such instantaneous rates of change.
-
But first, let us revisit the equation of a straight line: $y = mx + b$.
-
This is the simplest relationship between two quantities.
- $b$ is the y-intercept, where the line crosses the y-axis (the starting value when $x = 0$).
- $m$ is the slope, the rate of change: for every 1 unit increase in $x$, $y$ changes by $m$.
-
If $m = 3$, the line rises steeply; if $m = 0$, the line is flat; if $m = -2$, the line falls.
-
The slope is computed as $m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}$, the ratio of "how much did $y$ change" to "how much did $x$ change."
-
Once you know $m$ and $b$, you can compute $y$ for any $x$.
-
For example, if $m = 2$ and $b = 3$, then at $x = 5$: $y = 2(5) + 3 = 13$.
-
The two parameters fully determine the line, and predicting any output is just plugging in.
-
This idea generalises beyond lines. Any function is a rule that maps inputs to outputs, and once you know its formula (its parameters and shape), you can compute the output for any input and plot the result.
-
$y = x^2$ gives a parabola, $y = \sin(x)$ gives a wave, $y = e^x$ gives exponential growth. Each formula defines a specific curve, and being comfortable reading a function as a shape is essential for everything that follows.
-
For a straight line, the slope is the same everywhere. But most interesting functions are curved, so the slope varies from point to point. Calculus gives us a way to find the slope at any single point on a curve.
-
We also need the concept of a limit. A limit describes what value a function approaches as its input gets closer and closer to some target, without necessarily reaching it.
$$\lim_{x \to a} f(x) = L$$
-
This reads: "as $x$ approaches $a$, $f(x)$ approaches $L$." The function does not need to actually equal $L$ at $x = a$. It just needs to get arbitrarily close.
-
For example, take $f(x) = \frac{x^2 - 1}{x - 1}$. If you plug in $x = 1$ directly, you get $\frac{0}{0}$, which is undefined.
-
But try values close to 1: $f(0.9) = 1.9$, $f(0.99) = 1.99$, $f(1.01) = 2.01$. The outputs are clearly heading towards 2.
-
Algebraically, we can see why: factor the numerator as $(x-1)(x+1)$, cancel the $(x-1)$ terms, and we get $f(x) = x + 1$ for all $x \neq 1$. So as $x \to 1$, $f(x) \to 2$.
-
The function has a hole at $x = 1$, but the limit still exists.
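A sketch that evaluates $f(x) = \frac{x^2 - 1}{x - 1}$ at points approaching 1 from both sides, in plain Python floats (float64); the step sizes are arbitrary. Evaluating exactly at $x = 1$ raises `ZeroDivisionError`, which is Python's version of the undefined $\frac{0}{0}$.

```python
# f(x) = (x^2 - 1) / (x - 1), evaluated in ordinary Python floats
f = lambda x: (x**2 - 1) / (x - 1)

for h in [0.1, 0.01, 0.001]:
    print(f"f({1 - h:.3f}) = {f(1 - h):.4f}    f({1 + h:.3f}) = {f(1 + h):.4f}")

try:
    f(1.0)
except ZeroDivisionError:
    print("f(1) is undefined: 0/0")
```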
-
Limits are the foundation that everything else in calculus rests on.
-
The derivative of a function $f(x)$ at a point $x = a$ measures the instantaneous rate of change. Geometrically, it is the slope of the tangent line to the curve at that point.
- To compute this slope, we start with two points on the curve and compute the slope of the line through them (a secant line). Then we slide the second point closer and closer to the first, and see what slope the secant line approaches. This is the difference quotient:
$$f'(a) = \lim_{h \to 0} \frac{f(a + h) - f(a)}{h}$$
-
The numerator $f(a+h) - f(a)$ is the change in output. The denominator $h$ is the change in input. Their ratio is the average rate of change over a tiny interval. As $h \to 0$, this average becomes the instantaneous rate.
-
For example, let $f(x) = x^2$. At $x = 3$:
$$f'(3) = \lim_{h \to 0} \frac{(3+h)^2 - 9}{h} = \lim_{h \to 0} \frac{9 + 6h + h^2 - 9}{h} = \lim_{h \to 0} (6 + h) = 6$$
-
So at $x = 3$, the function $x^2$ is increasing at a rate of 6 units of output per unit of input.
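A sketch of the limit in action for $f(x) = x^2$ at $a = 3$: the secant slopes approach 6 as $h$ shrinks (algebraically each one equals $6 + h$), and `jax.grad` returns the limiting value directly. The step sizes are arbitrary.

```python
import jax

f = lambda x: x**2
a = 3.0
for h in [1.0, 0.1, 0.01, 0.001]:
    # Secant slope between (a, f(a)) and (a + h, f(a + h)); algebraically this is 6 + h
    slope = (f(a + h) - f(a)) / h
    print(f"h = {h:<6} secant slope = {slope:.6f}")

# The limit as h -> 0, computed by automatic differentiation
print(f"jax.grad: {jax.grad(f)(a):.6f}")
```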
-
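A function is differentiable at a point if this limit exists. For that to happen, the function must be defined in a neighbourhood around the point and continuous there (no jumps). Continuity alone is not enough, though: the function must also be smooth at the point (no sharp corners). For example, $|x|$ is continuous at $0$, but its slope jumps from $-1$ to $+1$ there, so it is not differentiable at $0$.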
-
If you can draw the curve without lifting your pen and without any kinks, it is probably differentiable there.
-
Computing derivatives from the limit definition every time would be tedious. Fortunately, a handful of rules let us differentiate almost any function quickly.
-
Constant rule: the derivative of a constant is zero. If $f(x) = 5$, then $f'(x) = 0$. A flat line has zero slope.
-
Power rule: the workhorse of differentiation. Bring the exponent down and reduce it by one:
$$\frac{d}{dx} x^n = n x^{n-1}$$
-
For example: $\frac{d}{dx} x^3 = 3x^2$. The cubic becomes a quadratic. This works for any real exponent, including negatives and fractions: $\frac{d}{dx} x^{-1} = -x^{-2}$ and $\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}$.
-
Sum/Difference rule: differentiate term by term.
$$\frac{d}{dx}[f(x) \pm g(x)] = f'(x) \pm g'(x)$$
- Product rule: when two functions are multiplied, the derivative is not simply the product of the derivatives. Instead:
$$\frac{d}{dx}[f(x) \cdot g(x)] = f'(x)g(x) + f(x)g'(x)$$
-
Think of it as: "the rate of change of the first times the second, plus the first times the rate of change of the second." For example, $\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x$.
-
Quotient rule: for a ratio of functions:
$$\frac{d}{dx}\left[\frac{f(x)}{g(x)}\right] = \frac{f'(x)g(x) - f(x)g'(x)}{[g(x)]^2}$$
-
A useful mnemonic: "low d-high minus high d-low, over the square of what's below."
-
Chain rule: the most important rule for ML. When functions are composed (one inside another), the derivative is the product of the derivatives along the chain:
$$\frac{d}{dx} f(g(x)) = f'(g(x)) \cdot g'(x)$$
- Think of it as peeling an onion. Differentiate the outer function (keeping the inner function untouched), then multiply by the derivative of the inner function.
-
For example, $\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4$. The outer function is $(\cdot)^5$ and the inner is $3x+1$.
-
The chain rule is the mathematical foundation of backpropagation in neural networks. A deep network is a long chain of composed functions. To compute how the loss changes with respect to each weight, we apply the chain rule repeatedly from the output layer back to the input, multiplying local derivatives at each step.
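To make this concrete, here is a sketch of a tiny made-up "network" with a single weight $w$: a linear step $z = wx$, a sigmoid activation $a = \sigma(z)$, and a squared-error loss $L = (a - y)^2$. Multiplying the local derivatives $\frac{dL}{da} \cdot \frac{da}{dz} \cdot \frac{dz}{dw}$ by hand should match `jax.grad`. The input, target, and weight values are arbitrary.

```python
import jax

x, y = 2.0, 1.0  # hypothetical single training example (input, target)
w = 0.3          # hypothetical weight

def loss(w):
    z = w * x              # step 1: linear layer
    a = jax.nn.sigmoid(z)  # step 2: activation
    return (a - y) ** 2    # step 3: squared-error loss

# Backward pass by hand: local derivatives, multiplied from the loss back to w
z = w * x
a = jax.nn.sigmoid(z)
dL_da = 2 * (a - y)
da_dz = a * (1 - a)        # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
dz_dw = x
manual = dL_da * da_dz * dz_dw

print(f"chain rule by hand: {manual:.6f}")
print(f"jax.grad:           {jax.grad(loss)(w):.6f}")
```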
-
Here are the most common derivatives you will encounter. Each one can be derived from the limit definition, but knowing them by heart saves time:
| Function | Derivative | Notes |
|---|---|---|
| $e^x$ | $e^x$ | Its own derivative (as is any multiple $Ce^x$) |
| $a^x$ | $a^x \ln a$ | Generalises the exponential |
| $\ln x$ | $\frac{1}{x}$ | The natural logarithm |
| $\log_a x$ | $\frac{1}{x \ln a}$ | General logarithm |
| $\sin x$ | $\cos x$ | |
| $\cos x$ | $-\sin x$ | Note the negative sign |
| $\tan x$ | $\sec^2 x$ | |
-
The exponential function $e^x$ is remarkable: up to a constant multiple, it is the only function that equals its own derivative. That makes it unusually convenient to differentiate, which is one reason $e$ appears everywhere in ML, from softmax activations to probability distributions.
-
L'Hopital's Rule handles limits that produce indeterminate forms like $\frac{0}{0}$ or $\frac{\infty}{\infty}$. When direct substitution gives one of these forms, you can take the derivative of the numerator and denominator separately and try the limit again:
$$\lim_{x \to a} \frac{f(x)}{g(x)} = \lim_{x \to a} \frac{f'(x)}{g'(x)}$$
-
Conditions: both $f$ and $g$ must be differentiable near $a$, and $g'(x) \neq 0$ near $a$ (except possibly at $a$ itself). The original limit must give an indeterminate form, and the limit of $\frac{f'(x)}{g'(x)}$ must exist (or be $\pm\infty$).
-
For example: $\lim_{x \to 0} \frac{\sin x}{x}$. Direct substitution gives $\frac{0}{0}$. Applying L'Hopital's Rule: $\lim_{x \to 0} \frac{\cos x}{1} = 1$. This limit is fundamental: it appears in signal processing and Fourier analysis.
-
You can apply the rule repeatedly if the result is still indeterminate. For instance, $\lim_{x \to 0} \frac{1 - \cos x}{x^2}$ gives $\frac{0}{0}$. First application: $\lim_{x \to 0} \frac{\sin x}{2x}$, still $\frac{0}{0}$. Second application: $\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}$.
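A quick numerical check of this second example, as a sketch in plain Python floats (float64): the values of $\frac{1 - \cos x}{x^2}$ should approach $\frac{1}{2}$ as $x$ shrinks. The sample points are arbitrary. Pushing $x$ much closer to zero eventually runs into floating-point cancellation in $1 - \cos x$, so a check like this supports the limit rather than proves it.

```python
import math

# 0/0 at x = 0; L'Hopital's rule (applied twice) says the limit is 1/2
g = lambda x: (1 - math.cos(x)) / x**2

for x in [0.5, 0.1, 0.01, 0.001]:
    print(f"x={x}: {g(x):.6f}")
```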
-
If two functions are differentiable, their sum, difference, product, composition, and quotient (where the denominator is non-zero) are also differentiable. This is why we can confidently differentiate complex expressions built from simple pieces.
Coding Tasks (use Colab or a notebook)
- Visualise common functions. Plot $x^2$, $\sin(x)$, and $e^x$ side by side to build intuition for how different formulas produce different shapes. Try changing parameters (e.g. $2x^2$, $\sin(2x)$) and observe how the curves change.
import jax.numpy as jnp
import matplotlib.pyplot as plt
x = jnp.linspace(-3, 3, 300)
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
axes[0].plot(x, x**2, color="#e74c3c")
axes[0].set_title("x² (parabola)")
axes[1].plot(x, jnp.sin(x), color="#3498db")
axes[1].set_title("sin(x) (wave)")
axes[2].plot(x, jnp.exp(x), color="#27ae60")
axes[2].set_title("eˣ (exponential)")
for ax in axes:
    ax.axhline(0, color="gray", linewidth=0.5)
    ax.axvline(0, color="gray", linewidth=0.5)
plt.tight_layout()
plt.show()
- Use JAX's automatic differentiation to compute the derivative of $f(x) = x^3 - 2x + 1$ at several points. Compare with the analytical derivative $f'(x) = 3x^2 - 2$.
import jax
import jax.numpy as jnp
f = lambda x: x**3 - 2*x + 1
df = jax.grad(f)
for x in [0.0, 1.0, 2.0, -1.0]:
    print(f"x={x:5.1f} autodiff: {df(x):.4f} analytical: {3*x**2 - 2:.4f}")
- Verify the chain rule numerically. Define $f(x) = \sin(x^2)$, compute its derivative via `jax.grad`, and compare with the analytical result $2x\cos(x^2)$.
import jax
import jax.numpy as jnp
f = lambda x: jnp.sin(x**2)
df = jax.grad(f)
for x in [0.5, 1.0, 2.0]:
    auto = df(x)
    analytical = 2*x * jnp.cos(x**2)
    print(f"x={x:.1f} autodiff: {auto:.6f} analytical: {analytical:.6f}")
- Visualise the derivative. Plot $f(x) = x^3 - 3x$ and its derivative $f'(x) = 3x^2 - 3$ on the same graph. Notice where $f'(x) = 0$ corresponds to the peaks and valleys of $f$.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
f = lambda x: x**3 - 3*x
# jax.grad works on scalars; jax.vmap vectorises it to operate on an array of inputs at once
df = jax.vmap(jax.grad(f))
x = jnp.linspace(-2.5, 2.5, 200)
plt.plot(x, jax.vmap(f)(x), label="f(x)")
plt.plot(x, df(x), label="f'(x)", linestyle="--")
plt.axhline(0, color="gray", linewidth=0.5)
plt.legend()
plt.title("A function and its derivative")
plt.show()
Integral Calculus
-
Differentiation tells us the rate of change at a single point. Integration goes the other way: it accumulates many tiny pieces to compute a total.
-
If the derivative answers "how fast?", the integral answers "how much?"
-
The simplest way to think about integration is as the area under a curve. If you plot a function $f(x)$ and shade the region between the curve and the x-axis from $x = a$ to $x = b$, the integral gives the signed area of that region.
-
Why "signed"? Regions above the x-axis contribute positive area, regions below contribute negative area. This makes physical sense: if $f(x)$ represents velocity, the integral gives net displacement (forward minus backward), not total distance.
-
To compute this area, imagine slicing the region into $n$ thin vertical rectangles, each of width $\Delta x$. The height of each rectangle is the function value at some point $x_i^\ast$ in that slice. Sum them up:
$$\text{Area} \approx \sum_{i=1}^{n} f(x_i^\ast) \, \Delta x$$
-
As we make the rectangles thinner and thinner ($n \to \infty$, $\Delta x \to 0$), the sum becomes exact. This limiting process defines the definite integral:
$$\int_a^b f(x) \, dx = \lim_{n \to \infty} \sum_{i=1}^{n} f(x_i^\ast) \, \Delta x$$
-
The $\int$ symbol is an elongated "S" for "sum." The $dx$ reminds us that we are summing infinitesimally thin slices along the x-axis.
-
An indefinite integral (or antiderivative) is a function $F(x)$ whose derivative is $f(x)$. We write:
$$\int f(x) \, dx = F(x) + C$$
-
The $+ C$ is the constant of integration. Since the derivative of any constant is zero, there are infinitely many antiderivatives that differ only by a constant. For example, $\int 2x \, dx = x^2 + C$, because the derivative of $x^2 + 7$ or $x^2 - 3$ is still $2x$.
-
The Fundamental Theorem of Calculus is the bridge that connects differentiation and integration. It has two parts:
-
Part 1: If $F(x)$ is an antiderivative of $f(x)$, then the definite integral equals the difference of $F$ at the endpoints:
$$\int_a^b f(x) \, dx = F(b) - F(a)$$
-
This is remarkably practical. Instead of computing a limit of sums (which is hard), we find an antiderivative and evaluate it at two points (which is usually easy).
-
Part 2: If $f$ is continuous and we define $F(x) = \int_a^x f(t) \, dt$, then $F'(x) = f(x)$. Differentiation and integration are inverse operations: they undo each other.
-
For example, to compute $\int_1^3 x^2 \, dx$: the antiderivative of $x^2$ is $\frac{x^3}{3}$. So $\int_1^3 x^2 \, dx = \frac{27}{3} - \frac{1}{3} = \frac{26}{3} \approx 8.67$.
-
Just as differentiation has rules, integration has corresponding rules that reverse them:
| Function | Integral | Condition |
|---|---|---|
| $x^n$ | $\frac{x^{n+1}}{n+1} + C$ | $n \neq -1$ |
| $\frac{1}{x}$ | $\ln\lvert x \rvert + C$ | $x \neq 0$ |
| $e^x$ | $e^x + C$ | |
| $a^x$ | $\frac{a^x}{\ln a} + C$ | $a > 0$, $a \neq 1$ |
| $\sin x$ | $-\cos x + C$ | |
| $\cos x$ | $\sin x + C$ | |
| $k$ (constant) | $kx + C$ | |
-
The sum/difference rule carries over: $\int [f(x) \pm g(x)] \, dx = \int f(x) \, dx \pm \int g(x) \, dx$. Constants can be pulled out: $\int k \, f(x) \, dx = k \int f(x) \, dx$.
-
When a function is too complex to integrate directly, we have techniques to simplify it.
-
u-substitution is the reverse of the chain rule. If you spot a composite function $f(g(x))$ multiplied by $g'(x)$, substitute $u = g(x)$ so that $du = g'(x) \, dx$, and the integral simplifies.
-
For example: $\int 2x \cos(x^2) \, dx$. Let $u = x^2$, so $du = 2x \, dx$. The integral becomes $\int \cos(u) \, du = \sin(u) + C = \sin(x^2) + C$.
-
Integration by parts is the reverse of the product rule. If the integrand is a product of two functions:
$$\int u \, dv = uv - \int v \, du$$
-
Choose $u$ and $dv$ strategically so that the remaining integral $\int v \, du$ is simpler than the original. A common mnemonic for choosing $u$ is LIATE: Logarithmic, Inverse trig, Algebraic, Trigonometric, Exponential (choose as $u$ the factor whose type comes earliest in this list).
-
For example: $\int x \, e^x \, dx$. Let $u = x$ (algebraic) and $dv = e^x \, dx$. Then $du = dx$ and $v = e^x$. So: $\int x \, e^x \, dx = x \, e^x - \int e^x \, dx = x \, e^x - e^x + C = e^x(x - 1) + C$.
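Since differentiation undoes integration, an antiderivative is easy to check: differentiate it and compare with the integrand. A sketch with `jax.grad`, at a few arbitrary points:

```python
import jax
import jax.numpy as jnp

F = lambda x: jnp.exp(x) * (x - 1)  # candidate antiderivative of x * e^x
dF = jax.grad(F)
for x in [-1.0, 0.0, 1.0, 2.0]:
    # By the product rule, F'(x) = e^x (x - 1) + e^x = x e^x
    print(f"x={x:4.1f}  F'(x)={dF(x):.5f}  x*e^x={x * jnp.exp(x):.5f}")
```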
-
In ML, integration appears in probability theory (computing probabilities by integrating density functions), in expected values (weighted averages over continuous distributions), and in computing the area under ROC curves. While we rarely integrate by hand in practice, understanding what integration means helps interpret these quantities.
Coding Tasks (use Colab or a notebook)
- Numerically approximate $\int_0^1 x^2 \, dx$ using a Riemann sum with increasing numbers of rectangles. Compare with the exact answer $\frac{1}{3}$.
import jax.numpy as jnp
for n in [10, 100, 1000, 10000]:
    # Left Riemann sum: n rectangles of width 1/n, heights taken at the left edges
    x = jnp.linspace(0, 1, n, endpoint=False)
    dx = 1.0 / n
    area = jnp.sum(x**2 * dx)
    print(f"n={n:5d} approx: {area:.6f} exact: {1/3:.6f}")
- Verify the Fundamental Theorem of Calculus numerically. Define $F(x) = \int_0^x t^2 \, dt = \frac{x^3}{3}$ and check that its derivative (computed via `jax.grad`) equals $x^2$.
import jax
import jax.numpy as jnp
F = lambda x: x**3 / 3
dF = jax.grad(F)
for x in [0.5, 1.0, 2.0, 3.0]:
    print(f"x={x:.1f} F'(x)={dF(x):.4f} x^2={x**2:.4f}")
- Visualise the area under $f(x) = \sin(x)$ from $0$ to $\pi$. Use `plt.fill_between` to shade the area and compute it numerically with a Riemann sum.
import jax.numpy as jnp
import matplotlib.pyplot as plt
x = jnp.linspace(0, jnp.pi, 500)
y = jnp.sin(x)
# 500 points span [0, pi] with 499 gaps, so the spacing is pi/499 (not pi/500)
dx = x[1] - x[0]
area = jnp.sum(y * dx)
plt.plot(x, y, color="purple", linewidth=2)
plt.fill_between(x, y, alpha=0.2, color="purple")
plt.title(f"Area = {area:.4f} (exact: 2.0)")
plt.show()
Multivariate Calculus
-
So far, our functions have taken a single input $x$ and produced a single output $f(x)$. But in ML, we almost never work with just one variable.
-
Consider a function of two variables, like $f(x, y) = x^2 + y^2$. This defines a surface in 3D space, a bowl shape. We want to know: if we nudge $x$ a little while keeping $y$ fixed, how does $f$ change? That is a partial derivative.
-
The partial derivative of $f$ with respect to $x$, written $\frac{\partial f}{\partial x}$, treats every other variable as a constant and differentiates normally with respect to $x$.
-
For $f(x, y) = x^2y + 3x - 2y$:
$$\frac{\partial f}{\partial x} = 2xy + 3 \qquad \frac{\partial f}{\partial y} = x^2 - 2$$
-
To compute $\frac{\partial f}{\partial x}$, we treated $y$ as a constant, so $x^2y$ differentiated to $2xy$, $3x$ to $3$, and $-2y$ to $0$.
-
To compute $\frac{\partial f}{\partial y}$, we treated $x$ as a constant, so $x^2y$ differentiated to $x^2$, $3x$ to $0$, and $-2y$ to $-2$.
-
Geometrically, taking a partial derivative with respect to $x$ is like slicing the 3D surface with a plane parallel to the $xz$-plane (at a fixed $y$ value) and finding the slope of the resulting curve.
- The gradient collects all the partial derivatives into a single vector:
$$\nabla f = \left(\frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, \ldots, \frac{\partial f}{\partial x_n}\right)$$
-
For $f(x, y) = x^2 + y^2$: $\nabla f(x, y) = (2x, 2y)$. At the point $(1, 2)$: $\nabla f(1, 2) = (2, 4)$.
-
The gradient has two key properties:
-
Direction: it points in the direction of steepest increase. Imagine a hiker on a mountain. The gradient at their position points straight uphill, along the steepest path.
-
Magnitude: $\|\nabla f\|$ gives the rate of increase in that steepest direction. A large gradient means the terrain is steep; a small gradient means it is nearly flat.
-
Since the gradient points uphill, moving in the opposite direction ($-\nabla f$) goes downhill, towards lower values. This simple idea is the basis of gradient descent, an optimisation technique we will explore in detail in later chapters. For now, the key takeaway is that the gradient tells you which way is "up" and how steep the climb is.
-
The directional derivative generalises partial derivatives. Instead of asking "how does $f$ change along the $x$-axis?", it asks "how does $f$ change along any direction $\mathbf{u}$?" It is computed as the dot product of the gradient with a unit vector:
$$D_{\mathbf{u}} f = \nabla f \cdot \mathbf{u}$$
-
For $f(x, y) = x^2 + y^2$ at $(1, 2)$ in the direction of $\mathbf{v} = (3, 4)$: first normalise to get $\mathbf{u} = (3/5, 4/5)$, then $D_{\mathbf{u}} f = (2, 4) \cdot (3/5, 4/5) = 6/5 + 16/5 = 22/5$.
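The worked example can be checked in JAX, as a sketch: once via the gradient and a dot product, and once with `jax.jvp`, which pushes a direction (tangent) vector through the function and returns the directional derivative without forming the full gradient.

```python
import jax
import jax.numpy as jnp

f = lambda p: p[0]**2 + p[1]**2
p = jnp.array([1.0, 2.0])
v = jnp.array([3.0, 4.0])
u = v / jnp.linalg.norm(v)  # unit direction (3/5, 4/5)

via_grad = jnp.dot(jax.grad(f)(p), u)
# jax.jvp returns (f(p), directional derivative of f at p along u)
_, via_jvp = jax.jvp(f, (p,), (u,))
print(f"gradient dot u: {via_grad:.4f}  jvp: {via_jvp:.4f}  expected: {22/5}")
```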
-
Partial derivatives are special cases of directional derivatives where the direction is along a coordinate axis. If the directional derivative is zero in some direction, the function is flat in that direction at that point.
-
Contour lines (or level curves) connect points where a function has the same value. For $f(x, y) = x^2 + y^2$, the contour lines are circles centred at the origin: $x^2 + y^2 = c$ for different values of $c$.
-
Contour lines for different values never cross each other (a point cannot have two different function values).
-
The gradient is always perpendicular to the contour lines, pointing from lower to higher values.
-
Closely spaced contour lines indicate steep terrain; widely spaced lines indicate gentle slopes.
-
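So far, our functions produced a single output. But many functions produce multiple outputs. A function $\mathbf{F}: \mathbb{R}^n \to \mathbb{R}^m$ takes $n$ inputs and produces $m$ outputs. The Jacobian matrix organises all the partial derivatives of such a vector-valued function into an $m \times n$ matrix with entries $J_{ij} = \frac{\partial F_i}{\partial x_j}$:
$$J = \begin{bmatrix} \frac{\partial F_1}{\partial x_1} & \cdots & \frac{\partial F_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial F_m}{\partial x_1} & \cdots & \frac{\partial F_m}{\partial x_n} \end{bmatrix}$$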
-
Each row of the Jacobian is the gradient of one output component. For a function with 3 inputs and 2 outputs, the Jacobian is a $2 \times 3$ matrix.
-
The Jacobian generalises the derivative to vector-valued functions.
-
Just as the derivative of a scalar function tells you how much the output changes per unit input change, the Jacobian tells you how each output changes with respect to each input.
-
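The determinant of the Jacobian (defined when the Jacobian is square, $m = n$) measures how much a transformation locally stretches or compresses space: its absolute value is the local scaling factor for areas or volumes.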
-
If the determinant is 2, small regions double in area. If it is 0, the transformation squashes space to a lower dimension (recall from our chapter on matrices that a zero determinant means a singular, non-invertible transformation).
-
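When several transformations are composed (one feeding into the next), the Jacobian of the overall mapping is the product of the individual Jacobians, each evaluated at its own input: for $\mathbf{h}(\mathbf{x}) = \mathbf{g}(\mathbf{f}(\mathbf{x}))$, $J_{\mathbf{h}}(\mathbf{x}) = J_{\mathbf{g}}(\mathbf{f}(\mathbf{x}))\, J_{\mathbf{f}}(\mathbf{x})$. We will see this idea become central in later chapters.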
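A quick sketch checking this with `jax.jacobian`, using two made-up maps $\mathbf{f}: \mathbb{R}^2 \to \mathbb{R}^3$ and $\mathbf{g}: \mathbb{R}^3 \to \mathbb{R}^2$ and an arbitrary input point: the Jacobian of $\mathbf{g}(\mathbf{f}(\mathbf{x}))$ should equal the $(2 \times 3)$ Jacobian of $\mathbf{g}$ at $\mathbf{f}(\mathbf{x})$ times the $(3 \times 2)$ Jacobian of $\mathbf{f}$ at $\mathbf{x}$.

```python
import jax
import jax.numpy as jnp

# Two hypothetical maps: f: R^2 -> R^3 and g: R^3 -> R^2
f = lambda x: jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])
g = lambda y: jnp.array([y[0] + y[1] * y[2], jnp.exp(y[0])])
h = lambda x: g(f(x))

x = jnp.array([0.5, -1.0])
J_h = jax.jacobian(h)(x)                             # 2 x 2
J_chain = jax.jacobian(g)(f(x)) @ jax.jacobian(f)(x)  # (2 x 3) @ (3 x 2)
print(jnp.allclose(J_h, J_chain, atol=1e-5))
```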
-
Where the gradient captures first-order information (slopes), the Hessian matrix captures second-order information (curvature).
-
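For a scalar function $f(x_1, \ldots, x_n)$, the Hessian is the $n \times n$ matrix of all second partial derivatives, $H_{ij} = \frac{\partial^2 f}{\partial x_i \partial x_j}$:
$$H = \begin{bmatrix} \frac{\partial^2 f}{\partial x_1^2} & \cdots & \frac{\partial^2 f}{\partial x_1 \partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial^2 f}{\partial x_n \partial x_1} & \cdots & \frac{\partial^2 f}{\partial x_n^2} \end{bmatrix}$$
-
For $f(x, y) = x^3 + 2xy^2 - y^3$, the gradient is $(3x^2 + 2y^2,\; 4xy - 3y^2)$, and the Hessian is:
$$H = \begin{bmatrix} 6x & 4y \\ 4y & 4x - 6y \end{bmatrix}$$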
-
The diagonal entries ($6x$ and $4x - 6y$) tell you how the slope in the $x$-direction changes as you move in $x$, and similarly for $y$.
-
The off-diagonal entries ($4y$) tell you how the slope in one direction changes as you move in the other direction.
-
Clairaut's theorem guarantees that for functions with continuous second derivatives, the mixed partial derivatives are equal: $\frac{\partial^2 f}{\partial x \partial y} = \frac{\partial^2 f}{\partial y \partial x}$.
-
This means the Hessian is symmetric, which (as we saw in the matrices chapter) guarantees real eigenvalues and orthogonal eigenvectors.
-
The Hessian tells us about the shape of the function near a critical point (where the gradient is zero):
- If $H$ is positive definite (all eigenvalues positive), the point is a local minimum: the surface curves upward in every direction, like a bowl.
- If $H$ is negative definite (all eigenvalues negative), the point is a local maximum: the surface curves downward, like an inverted bowl.
- If $H$ has both positive and negative eigenvalues, the point is a saddle point: the surface curves up in some directions and down in others, like a mountain pass.
- If some eigenvalues are zero and the nonzero ones all share a sign, the test is inconclusive, and higher-order terms decide.
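The eigenvalue test can be sketched with `jax.hessian` and `jnp.linalg.eigvalsh` (which assumes a symmetric input, as Clairaut's theorem guarantees here). The example function $f(x, y) = x^3 - 3x + y^2$ is an illustrative choice: its gradient $(3x^2 - 3,\; 2y)$ vanishes at $(1, 0)$ and $(-1, 0)$. Comparing eigenvalues against exactly zero is a simplification; real code would use a tolerance.

```python
import jax
import jax.numpy as jnp

def f(p):
    x, y = p[0], p[1]
    return x**3 - 3*x + y**2  # critical points at (1, 0) and (-1, 0)

def classify(p):
    eig = jnp.linalg.eigvalsh(jax.hessian(f)(p))  # Hessian is symmetric
    if jnp.all(eig > 0):
        return "local minimum"
    if jnp.all(eig < 0):
        return "local maximum"
    if jnp.any(eig > 0) and jnp.any(eig < 0):
        return "saddle point"
    return "inconclusive (zero eigenvalue)"

for p in [jnp.array([1.0, 0.0]), jnp.array([-1.0, 0.0])]:
    print(p, "gradient:", jax.grad(f)(p), "->", classify(p))
```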
-
The multivariate chain rule extends the chain rule to functions of several variables. If $z = f(x, y)$ where $x = g(t)$ and $y = h(t)$, then:
$$\frac{dz}{dt} = \frac{\partial f}{\partial x}\frac{dx}{dt} + \frac{\partial f}{\partial y}\frac{dy}{dt}$$
-
Each path from $t$ to $z$ contributes a term: the partial derivative along that path times the derivative of the intermediate variable with respect to $t$.
-
For example, if $z = x^2 y + 3x - y^2$, $x = \cos(t)$, $y = \sin(t)$:
$$\frac{dz}{dt} = (2xy + 3)(-\sin t) + (x^2 - 2y)(\cos t)$$
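A sketch that checks this formula with `jax.grad`, treating $z$ as a function of $t$ alone by substituting $x = \cos t$ and $y = \sin t$ inside the function; the sample values of $t$ are arbitrary.

```python
import jax
import jax.numpy as jnp

def z_of_t(t):
    x, y = jnp.cos(t), jnp.sin(t)
    return x**2 * y + 3*x - y**2

def dz_dt_formula(t):
    # Multivariate chain rule: (dz/dx)(dx/dt) + (dz/dy)(dy/dt)
    x, y = jnp.cos(t), jnp.sin(t)
    return (2*x*y + 3) * (-jnp.sin(t)) + (x**2 - 2*y) * jnp.cos(t)

for t in [0.0, 0.7, 2.0]:
    print(f"t={t}: autodiff {jax.grad(z_of_t)(t):.6f}  formula {dz_dt_formula(t):.6f}")
```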
-
Beyond computing derivatives by hand, there are three ways to compute them by computer:
- Numerical differentiation: approximate $f'(x) \approx \frac{f(x+h) - f(x-h)}{2h}$ for small $h$. Simple, but only approximate: a large $h$ gives truncation error, and a very small $h$ amplifies floating-point round-off.
- Symbolic differentiation: apply differentiation rules algebraically to produce an exact formula. The resulting expressions can grow exponentially large ("expression swell").
- Automatic differentiation (autodiff): tracks the chain of operations and applies the chain rule to compute derivatives efficiently. This is what JAX, PyTorch, and TensorFlow use. It gives derivatives that are exact up to floating-point rounding (there is no step-size approximation) without producing bloated symbolic expressions.
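A sketch of the trade-off for the central difference, using $f(x) = \sin x$ at $x = 1$ (where the exact derivative is $\cos 1$) in float64: the error first falls as $h$ shrinks, then rises again once round-off dominates. The step sizes are arbitrary sample points. The final line compares `jax.grad`, which has no step size to tune; its remaining error comes from JAX's default float32 precision.

```python
import math
import jax
import jax.numpy as jnp

x, exact = 1.0, math.cos(1.0)  # d/dx sin(x) at x = 1
errors = {}
for h in [1e-1, 1e-3, 1e-5, 1e-8, 1e-11]:
    approx = (math.sin(x + h) - math.sin(x - h)) / (2 * h)  # central difference, float64
    errors[h] = abs(approx - exact)
    print(f"h={h:.0e}  error={errors[h]:.1e}")

# Autodiff: no step size; error limited only by float32 rounding
print(f"autodiff error={abs(float(jax.grad(jnp.sin)(x)) - exact):.1e}")
```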
Coding Tasks (use Colab or a notebook)
- Compute the gradient of $f(x, y) = x^2 y + 3x - 2y$ at the point $(1, 2)$ using `jax.grad`. Since $f$ takes two separate scalar arguments, use `argnums` to choose which argument to differentiate with respect to.
import jax
import jax.numpy as jnp
def f(x, y):
    return x**2 * y + 3*x - 2*y
df_dx = jax.grad(f, argnums=0)
df_dy = jax.grad(f, argnums=1)
x, y = 1.0, 2.0
print(f"∂f/∂x = {df_dx(x, y):.4f} (expected: {2*x*y + 3:.4f})")
print(f"∂f/∂y = {df_dy(x, y):.4f} (expected: {x**2 - 2:.4f})")
- Compute the Jacobian of a vector-valued function using `jax.jacobian`. Compare with a manual calculation.
import jax
import jax.numpy as jnp
def F(x):
    return jnp.array([x[0]**2 + x[1], x[0] * x[1]**2])
J = jax.jacobian(F)
x = jnp.array([1.0, 2.0])
print(f"Jacobian at (1,2):\n{J(x)}")
# Expected: [[2*x[0], 1], [x[1]**2, 2*x[0]*x[1]]] = [[2, 1], [4, 4]]
- Compute the Hessian of $f(x, y) = x^3 + 2xy^2 - y^3$ using `jax.hessian` and verify it is symmetric.
import jax
import jax.numpy as jnp
def f(xy):
    x, y = xy[0], xy[1]
    return x**3 + 2*x*y**2 - y**3
H = jax.hessian(f)
point = jnp.array([1.0, 2.0])
hess = H(point)
print(f"Hessian:\n{hess}")
print(f"Symmetric: {jnp.allclose(hess, hess.T)}")
# Expected: [[6x, 4y], [4y, 4x-6y]] = [[6, 8], [8, -8]]
- Build a minimal autodiff engine from scratch.
- Each `Var` tracks its value and how to propagate gradients backward through the chain rule.
- Try extending it with more operations (division, power, etc.).
- This is the core idea behind reverse-mode automatic differentiation in libraries such as JAX and PyTorch.
class Var:
    """A scalar that remembers how it was computed, for reverse-mode autodiff."""
    def __init__(self, val, children=(), backward_fn=None):
        self.val = val
        self.grad = 0.0
        self.children = children        # the Vars this one was computed from
        self.backward_fn = backward_fn  # pushes self.grad back to the children
    def __add__(self, other):
        out = Var(self.val + other.val, children=(self, other))
        def _backward():
            # += accumulates: a Var (like x in x*x) can feed several operations
            self.grad += out.grad  # d(a+b)/da = 1
            other.grad += out.grad  # d(a+b)/db = 1
        out.backward_fn = _backward
        return out
    def __mul__(self, other):
        out = Var(self.val * other.val, children=(self, other))
        def _backward():
            self.grad += other.val * out.grad  # d(a*b)/da = b
            other.grad += self.val * out.grad  # d(a*b)/db = a
        out.backward_fn = _backward
        return out
    def backward(self):
        # topological sort then propagate gradients
        # we will go through this in data structures and algorithms
        order, visited = [], set()
        def topo(v):
            if v not in visited:
                visited.add(v)
                for c in v.children:
                    topo(c)
                order.append(v)
        topo(self)
        self.grad = 1.0
        for v in reversed(order):
            if v.backward_fn:
                v.backward_fn()
# f(x, y) = x*x*y + x at (3, 2)
x = Var(3.0)
y = Var(2.0)
f = x * x * y + x # = 3*3*2 + 3 = 21
f.backward()
print(f"f = {f.val}") # 21.0
print(f"df/dx = {x.grad}") # 2*x*y + 1 = 13.0
print(f"df/dy = {y.grad}") # x*x = 9.0
Function Approximation
-
Many functions we encounter are too complex to work with directly. Computing $e^{0.1}$ by hand and predicting the trajectory of a satellite both involve functions that are hard to evaluate exactly or have no simple closed-form solution.
-
Function approximation replaces a complicated function with a simpler one that is "close enough" over the region we care about.
-
The most natural approximation is a polynomial. Polynomials are just sums of powers of $x$ with coefficients, and they are easy to evaluate, differentiate, and integrate.
-
But why do polynomials work so well as approximators? Consider what each power of $x$ contributes.
- The constant term $a_0$ sets the baseline value.
- The $a_1 x$ term adds a slope.
- The $a_2 x^2$ term adds curvature.
- Each higher power captures finer detail about the function's shape.
-
By choosing the right coefficients, we can match a function's value, slope, curvature, and higher-order behaviour at a point, one piece at a time.
-
With enough terms, the polynomial can mimic almost any smooth function.
-
The question becomes: how do we find the right coefficients?
-
Linearisation is the simplest approximation. Near a point $x = a$, we replace the function with its tangent line:
$$L(x) = f(a) + f'(a)(x - a)$$
-
This is the first-order Taylor approximation. It says: start at the known value $f(a)$, then adjust by the slope times the distance from $a$.
-
For example, linearise $\sin(x)$ at $x = 0$: $f(0) = 0$, $f'(0) = \cos(0) = 1$, so $L(x) = x$. Near zero, $\sin(x) \approx x$. Try it: $\sin(0.1) = 0.0998\ldots \approx 0.1$.
-
But linearisation is only good very close to $a$. Move further away and the approximation falls apart. To do better, we include higher-order terms.
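A quick sketch of how fast the linearisation $\sin(x) \approx x$ degrades as $x$ moves away from $a = 0$ (plain Python floats; the sample points are arbitrary):

```python
import math

xs = [0.1, 0.5, 1.0, 2.0]
# L(x) = x is the linearisation of sin at a = 0
errors = [abs(math.sin(x) - x) for x in xs]
for x, err in zip(xs, errors):
    print(f"x={x}: sin(x)={math.sin(x):.4f}  L(x)={x:.4f}  error={err:.4f}")
```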
-
The Taylor series represents a function as an infinite sum of polynomial terms, each capturing finer detail about the function's behaviour near a point $a$:
$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!}(x - a)^n = f(a) + f'(a)(x-a) + \frac{f''(a)}{2!}(x-a)^2 + \frac{f'''(a)}{3!}(x-a)^3 + \cdots$$
-
Each successive term adds a correction. The first term matches the value, the second matches the slope, the third matches the curvature, and so on. The more terms we include, the larger the region where the approximation is accurate.
-
The $n!$ in the denominator is not arbitrary. When you differentiate $(x - a)^n$ exactly $n$ times, you get $n!$. The factorial cancels this out, ensuring that the $n$-th derivative of the Taylor polynomial equals the $n$-th derivative of the original function at $x = a$.
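You can check the coefficient rule mechanically. The sketch below is one way to do it, assuming JAX. It applies `jax.grad` repeatedly to get $f^{(k)}(a)$, divides by $k!$, and evaluates the resulting polynomial. The helper names `taylor_coefficients` and `taylor_poly` are hypothetical. Nesting `jax.grad` like this is only practical for small orders.

```python
import math
import jax
import jax.numpy as jnp

def taylor_coefficients(f, a, order):
    # Returns [f(a)/0!, f'(a)/1!, ..., f^(order)(a)/order!]
    coeffs = []
    deriv = f
    for k in range(order + 1):
        coeffs.append(deriv(a) / math.factorial(k))
        deriv = jax.grad(deriv)  # differentiate once more for the next coefficient
    return jnp.array(coeffs)

def taylor_poly(coeffs, a, x):
    # Evaluates sum_k coeffs[k] * (x - a)^k
    powers = jnp.arange(len(coeffs))
    return jnp.sum(coeffs * (x - a) ** powers)

a = 0.0
c1 = taylor_coefficients(jnp.sin, a, order=1)  # linearisation: [0, 1]
c5 = taylor_coefficients(jnp.sin, a, order=5)  # [0, 1, 0, -1/6, 0, 1/120]
for x in [0.1, 0.5, 1.0]:
    print(f"x={x}: sin={jnp.sin(x):.6f}  degree 1={taylor_poly(c1, a, x):.6f}  "
          f"degree 5={taylor_poly(c5, a, x):.6f}")
```

For $\sin$ at $a = 0$, the degree-1 coefficients reproduce the linearisation $L(x) = x$. The degree-5 polynomial stays close to $\sin x$ much further from 0.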
-
A Maclaurin series is simply a Taylor series centred at $a = 0$:
$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(0)}{n!} x^n$$
- Some famous Maclaurin series:
$$e^x = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \cdots$$
$$\sin x = x - \frac{x^3}{3!} + \frac{x^5}{5!} - \frac{x^7}{7!} + \cdots$$
$$\cos x = 1 - \frac{x^2}{2!} + \frac{x^4}{4!} - \frac{x^6}{6!} + \cdots$$
-
Notice that $\sin x$ has only odd powers (it is an odd function) and $\cos x$ has only even powers (it is an even function). The alternating signs cause the approximation to oscillate around the true value, converging from both sides.
-
Let us approximate $e^{0.5}$ using four terms: $1 + 0.5 + \frac{0.25}{2} + \frac{0.125}{6} = 1 + 0.5 + 0.125 + 0.02083 \approx 1.6458$. The true value is $1.6487\ldots$. Four terms therefore already give two correct decimal places, with an error of about $0.003$.
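Here is the same arithmetic as a short plain-Python sketch, using only the standard `math` module. It prints each partial sum of the Maclaurin series of $e^{0.5}$ together with its error:

```python
import math

x = 0.5
partial_sums = []
total = 0.0
for n in range(7):
    total += x**n / math.factorial(n)  # add the x^n / n! term
    partial_sums.append(total)
    print(f"{n + 1} terms: {total:.6f}   error: {math.exp(x) - total:.2e}")
```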
-
Not every Taylor series converges everywhere. The radius of convergence tells us how far from the centre $a$ the series gives valid results. Within that radius, the polynomial approximation can be made as accurate as we want by adding more terms. Outside it, the series diverges.
-
A power series is the general form $\sum_{n=0}^{\infty} a_n (x - c)^n$. Taylor series are power series whose coefficients are determined by derivatives. Other power series may define their coefficients by some other rule. The ratio test determines convergence: compute $L = \lim_{n \to \infty} \left|\frac{a_{n+1}}{a_n}\right|$. If this limit exists, the radius of convergence is $R = 1/L$. When $L = 0$, $R = \infty$, and the series converges for every $x$.
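To get a feel for the ratio test, the sketch below evaluates $|a_{n+1}/a_n|$ at a few increasing $n$ for three familiar series. The first is $e^x$, with coefficients $1/n!$. The second is the geometric series $1/(1-x) = \sum x^n$, and the third is $\ln(1+x) = \sum_{n \ge 1} (-1)^{n+1} x^n / n$. The last two are standard examples chosen for illustration. Finite $n$ only hints at the limit; it does not prove it.

```python
import math

def exp_coeff(n):
    return 1 / math.factorial(n)   # e^x = sum x^n / n!

def geometric_coeff(n):
    return 1.0                     # 1/(1-x) = sum x^n

def log_coeff(n):
    return (-1) ** (n + 1) / n     # ln(1+x) = sum_{n>=1} (-1)^(n+1) x^n / n

def ratio(a, n):
    # |a_{n+1} / a_n|: for large n this approximates the limit L in the ratio test
    return abs(a(n + 1) / a(n))

for name, a in [("exp", exp_coeff), ("geometric", geometric_coeff), ("log", log_coeff)]:
    ratios = [ratio(a, n) for n in (10, 50, 150)]
    print(f"{name:>9}: " + "  ".join(f"{r:.4f}" for r in ratios))
```

The ratios for $e^x$ head towards 0, so its series converges everywhere. Both other series approach 1, which gives $R = 1$.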
-
When we truncate a Taylor series after the $(x-a)^n$ term, keeping the degree-$n$ Taylor polynomial, we incur an error. The Lagrange form of the remainder describes this error exactly:
$$R_n(x) = \frac{f^{(n+1)}(c)}{(n+1)!}(x-a)^{n+1}$$
-
Here $c$ is some unknown point between $a$ and $x$. We do not know $c$ exactly, but we can often bound $|f^{(n+1)}(c)|$ to get a worst-case error estimate. The $(n+1)!$ in the denominator grows extremely fast. As a result, the error shrinks rapidly as we add more terms, provided the derivatives themselves do not grow faster than the factorial. This holds for $e^x$, $\sin x$ and $\cos x$ on any bounded interval.
-
For a function of multiple variables, the Taylor expansion includes mixed partial derivatives. The second-order approximation of $f(\mathbf{x})$ around a point $\mathbf{a}$ is:
$$f(\mathbf{x}) \approx f(\mathbf{a}) + \nabla f(\mathbf{a})^T (\mathbf{x} - \mathbf{a}) + \frac{1}{2} (\mathbf{x} - \mathbf{a})^T H(\mathbf{a}) (\mathbf{x} - \mathbf{a})$$
-
The first term is the value, the second uses the gradient (a vector, as we saw in multivariate calculus), and the third uses the Hessian matrix (which captures curvature). This connects our matrices chapter directly to calculus: the Hessian is a matrix of second derivatives that describes the shape of the function's surface.
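Here is a minimal sketch of the second-order approximation, assuming JAX: `jax.grad` supplies the gradient and `jax.hessian` the Hessian at $\mathbf{a}$. The test function $e^{x}\cos y$ and the evaluation points are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Example function of two variables (an arbitrary choice for illustration)
    return jnp.exp(v[0]) * jnp.cos(v[1])

def second_order_taylor(f, a, x):
    # f(a) + grad f(a)^T (x - a) + 1/2 (x - a)^T H(a) (x - a)
    d = x - a
    g = jax.grad(f)(a)
    H = jax.hessian(f)(a)
    return f(a) + g @ d + 0.5 * d @ H @ d

a = jnp.array([0.0, 0.0])
for x in [jnp.array([0.1, 0.2]), jnp.array([0.5, 0.5]), jnp.array([1.0, 1.0])]:
    approx = second_order_taylor(f, a, x)
    print(f"x={x}: exact={f(x):.5f}  quadratic={approx:.5f}")
```

Close to $\mathbf{a} = (0, 0)$, the quadratic model is accurate. At $(1, 1)$, the neglected third-order terms start to matter.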
-
This multivariate second-order approximation is the foundation of Newton's method and other second-order optimisation techniques, which we will see in the next file.
-
Beyond polynomials, there are other approximation methods worth knowing about:
- Spline interpolation: instead of one high-degree polynomial, use many low-degree polynomials stitched together smoothly. This avoids the wild oscillations that high-degree polynomials can produce.
- Fourier series: approximate periodic functions as sums of sines and cosines. Essential in signal processing and audio.
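- Neural networks: universal function approximators. With enough neurons, even a single hidden layer can approximate any continuous function on a closed, bounded domain to arbitrary accuracy (the universal approximation theorem). This is part of the theoretical justification for deep learning. The theorem does not, however, say how to find the right weights.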
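To see the oscillation problem that splines avoid, the sketch below interpolates the Runge function $1/(1 + 25x^2)$ at 11 evenly spaced points. It does this twice: once with a single degree-10 polynomial (written in Lagrange form) and once with `jnp.interp`. Note that `jnp.interp` gives piecewise-linear interpolation. This is the simplest spline: continuous, but with corners at the knots. A cubic spline would be smoother, but the comparison already shows the effect. The function and node count are illustrative choices, assuming JAX.

```python
import jax.numpy as jnp

def runge(x):
    # Classic test function that high-degree polynomial interpolation struggles with
    return 1.0 / (1.0 + 25.0 * x**2)

def lagrange_interpolate(nodes, values, x):
    # Single polynomial through all nodes, in Lagrange form:
    # p(x) = sum_j values[j] * prod_{m != j} (x - nodes[m]) / (nodes[j] - nodes[m])
    result = jnp.zeros_like(x)
    for j in range(len(nodes)):
        basis = jnp.ones_like(x)
        for m in range(len(nodes)):
            if m != j:
                basis = basis * (x - nodes[m]) / (nodes[j] - nodes[m])
        result = result + values[j] * basis
    return result

nodes = jnp.linspace(-1.0, 1.0, 11)  # 11 equally spaced points
values = runge(nodes)
x = jnp.linspace(-1.0, 1.0, 1001)

poly = lagrange_interpolate(nodes, values, x)  # one degree-10 polynomial
pieces = jnp.interp(x, nodes, values)          # piecewise-linear interpolant

poly_err = jnp.max(jnp.abs(poly - runge(x)))
piece_err = jnp.max(jnp.abs(pieces - runge(x)))
print(f"max error, degree-10 polynomial: {poly_err:.3f}")
print(f"max error, piecewise linear:     {piece_err:.3f}")
```

The single polynomial swings wildly near the ends of the interval. The stitched-together low-degree pieces stay close to the function everywhere.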
-
A function is called "well-behaved" if it has properties that make approximation reliable: continuity (no jumps), differentiability (no sharp corners), smoothness (derivatives of all orders exist), and boundedness (outputs stay finite).
-
Polynomials, exponentials, and trigonometric functions are all well-behaved. More precisely, the Lagrange remainder shows that the smaller a function's higher derivatives are over the region you care about, the fewer Taylor terms you need for a good approximation.
Coding Tasks (use CoLab or notebook)
- Approximate $e^x$ using increasing numbers of Taylor terms and visualise how the approximation improves.
```python
import math
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.linspace(-2, 3, 300)
plt.plot(x, jnp.exp(x), "k-", linewidth=2, label="eˣ (exact)")
colors = ["#e74c3c", "#3498db", "#27ae60", "#9b59b6"]
for n, color in zip([1, 2, 4, 8], colors):
    # Degree-n Maclaurin polynomial: sum of x^k / k! for k = 0..n (n + 1 terms)
    approx = sum(x**k / math.factorial(k) for k in range(n + 1))
    plt.plot(x, approx, color=color, linestyle="--", label=f"degree {n} ({n + 1} terms)")
plt.ylim(-2, 15)
plt.legend()
plt.title("Taylor approximation of eˣ")
plt.show()
```
- Compute the Lagrange remainder to bound the error of approximating $\sin(1)$ with different numbers of Taylor terms.
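```python
import math

x = 1.0
exact = math.sin(x)  # double-precision reference value
taylor = 0.0
for n in range(8):
    # Add the next non-zero Maclaurin term: (-1)^n x^(2n+1) / (2n+1)!
    taylor += (-1)**n * x**(2*n + 1) / math.factorial(2*n + 1)
    error = abs(exact - taylor)
    # Lagrange bound: every derivative of sin is bounded by 1, and the x^(2n+2)
    # term is zero, so the remainder is at most |x|^(2n+3) / (2n+3)!
    bound = abs(x)**(2*n + 3) / math.factorial(2*n + 3)
    print(f"terms={n+1} approx={taylor:.10f} error={error:.2e} bound={bound:.2e}")
```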
- Compare linearisation vs quadratic Taylor approximation of $\cos(x)$ near $x = 0$. Plot both approximations alongside the true function and observe the range where each is accurate.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

x = jnp.linspace(-3, 3, 300)
plt.plot(x, jnp.cos(x), "k-", linewidth=2, label="cos(x)")
plt.plot(x, jnp.ones_like(x), "--", color="#e74c3c", label="linear: 1")
plt.plot(x, 1 - x**2/2, "--", color="#3498db", label="quadratic: 1 - x²/2")
plt.plot(x, 1 - x**2/2 + x**4/24, "--", color="#27ae60", label="4th order")
plt.ylim(-2, 2)
plt.legend()
plt.title("Taylor approximations of cos(x)")
plt.show()
```
Optimisation
-
Training a neural network, fitting a regression line, tuning hyperparameters: at the core of almost every ML algorithm is an optimisation problem.
-
We have some function (a loss, a cost, an objective) and we want to find the inputs that make it as small (or large) as possible.
-
Before optimising, we need to understand zeros (or roots) of functions. A zero of $f(x)$ is a value $x$ where $f(x) = 0$. Graphically, these are the x-intercepts.
-
For example, $f(x) = x^2 - 3x + 2 = (x-1)(x-2)$ has zeros at $x = 1$ and $x = 2$. Between the zeros, the function is negative ($f(1.5) = -0.25$); outside the zeros, it is positive. The zeros divide the number line into regions where the function has constant sign.
-
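The multiplicity of a zero is how many times the corresponding factor appears in the factorisation. For example, $(x-1)^2(x+2)$ has a zero of multiplicity 2 at $x = 1$ and a zero of multiplicity 1 at $x = -2$.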
-
At a simple zero (multiplicity 1), the graph crosses the x-axis. At a double zero (multiplicity 2), the graph touches the x-axis but bounces back without crossing, appearing "flat" at that point.
-
Finding zeros matters because the zeros of the derivative $f'(x)$ are the critical points of $f(x)$, the candidates for maxima and minima.
-
At a maximum or minimum, the tangent line is flat (slope = 0), so $f'(x) = 0$.
-
But not every critical point is a maximum or minimum. A point where $f'(x) = 0$ could also be an inflection point (like $x = 0$ for $f(x) = x^3$), where the function flattens momentarily but does not change direction.
-
The second derivative test resolves this. At a critical point $x = c$ where $f'(c) = 0$:
- If $f''(c) > 0$: the curve is concave up (like a bowl), so $c$ is a local minimum.
- If $f''(c) < 0$: the curve is concave down (like a hill), so $c$ is a local maximum.
- If $f''(c) = 0$: the test is inconclusive; higher derivatives or other methods are needed.
-
For example, $f(x) = x^3 - 3x$. The derivative is $f'(x) = 3x^2 - 3 = 3(x-1)(x+1)$, so critical points are at $x = -1$ and $x = 1$. The second derivative is $f''(x) = 6x$. At $x = -1$: $f''(-1) = -6 < 0$ (local max). At $x = 1$: $f''(1) = 6 > 0$ (local min).
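The same second derivative test can be run with automatic differentiation. This sketch assumes JAX. It applies `jax.grad` twice to $f(x) = x^3 - 3x$ at the two critical points found above:

```python
import jax

def f(x):
    return x**3 - 3 * x

df = jax.grad(f)    # f'(x) = 3x^2 - 3
d2f = jax.grad(df)  # f''(x) = 6x

for c in [-1.0, 1.0]:
    # f'(c) should be 0 at a critical point; the sign of f''(c) classifies it
    slope, curvature = df(c), d2f(c)
    kind = "local min" if curvature > 0 else "local max" if curvature < 0 else "inconclusive"
    print(f"x={c}: f'={slope:.1f}, f''={curvature:.1f} -> {kind}")
```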
-
A function is convex if the line segment between any two points on its graph lies above (or on) the graph. Think of it as a bowl shape, curving upward everywhere. For a twice-differentiable function, this is equivalent to $f''(x) \geq 0$ for all $x$.
-
Convexity is powerful because convex functions have a remarkable property: every local minimum is also a global minimum. There are no deceptive local valleys to get trapped in. If you roll a ball into a convex bowl, it will always reach the bottom.
-
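A function is concave (curving downward) if $-f$ is convex. Points where the function switches between concave and convex are inflection points. For a twice-differentiable function, they can only occur where $f''(x) = 0$. However, $f''(x) = 0$ alone is not enough: $f(x) = x^4$ has $f''(0) = 0$ yet is convex everywhere.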
-
Newton's method finds zeros of functions (and by extension, critical points of their derivatives) using tangent lines. Starting from an initial guess $x_0$, it iteratively refines:
$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$
-
The idea: at $x_n$, draw the tangent line and find where it crosses the x-axis. That crossing point becomes $x_{n+1}$. For well-behaved functions with a good starting point, Newton's method converges very quickly (quadratically, meaning the number of correct digits roughly doubles each step).
-
For example, to find $\sqrt{5}$ (a zero of $f(x) = x^2 - 5$): $f'(x) = 2x$, so $x_{n+1} = x_n - \frac{x_n^2 - 5}{2x_n}$. Starting at $x_0 = 2$: $x_1 = 2.25$, $x_2 = 2.2361\ldots$, which is already accurate to four decimal places.
-
Newton's method can fail if the initial guess is far from the root, if $f'(x) = 0$ near the root, or if the function has inflection points nearby. It also requires computing the derivative, which may be expensive.
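A classic illustration of the first failure mode is $f(x) = x^3 - 2x + 2$; it is a standard textbook example, not one from the text above. Started at $x_0 = 0$, Newton's method cycles between 0 and 1 forever. A start near the real root, by contrast, converges quickly. The sketch uses plain Python floats and hand-written derivatives. `newton` is a hypothetical helper name.

```python
def f(x):
    return x**3 - 2 * x + 2

def df(x):
    return 3 * x**2 - 2

def newton(x, steps):
    # Runs Newton's method from x and returns every iterate
    iterates = []
    for _ in range(steps):
        x = x - f(x) / df(x)
        iterates.append(x)
    return iterates

print(newton(0.0, 6))        # bounces between 1.0 and 0.0 forever
root = newton(-2.0, 6)[-1]   # a start near the real root converges quickly
print(f"root = {root:.6f}, f(root) = {f(root):.1e}")
```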
-
For optimisation (finding minima instead of zeros), we apply Newton's method to $f'(x) = 0$, which gives the update:
$$x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)}$$
-
In multiple dimensions, this becomes $\mathbf{x}_{n+1} = \mathbf{x}_n - H^{-1} \nabla f(\mathbf{x}_n)$, where $H$ is the Hessian matrix. This is the second-order Taylor approximation from the previous file in action: approximate the function as a quadratic, jump to the minimum of that quadratic, repeat.
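Here is a minimal sketch of the multivariate update, assuming JAX for the gradient and Hessian. It uses `jnp.linalg.solve` to compute $H^{-1} \nabla f$ rather than forming the inverse explicitly, which is the usual practice. The convex test function $e^x - x + (y - 2)^2$, with its minimum at $(0, 2)$, is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(p):
    # Hypothetical convex test function with its minimum at (0, 2)
    x, y = p
    return jnp.exp(x) - x + (y - 2.0) ** 2

grad_f = jax.grad(f)
hess_f = jax.hessian(f)

p = jnp.array([1.0, -3.0])
for i in range(6):
    # Newton step: solve H step = grad instead of computing H^{-1} explicitly
    step = jnp.linalg.solve(hess_f(p), grad_f(p))
    p = p - step
    print(f"step {i + 1}: x={p[0]:.8f}, y={p[1]:.8f}, |grad|={jnp.linalg.norm(grad_f(p)):.2e}")
```

The $y$-coordinate is exact after one step because $f$ is already quadratic in $y$. The $x$-coordinate shows quadratic convergence, with the number of correct digits roughly doubling per step.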
-
Lagrange multipliers solve constrained optimisation problems: find the optimum of $f(x, y)$ subject to a constraint $g(x, y) = c$. Instead of searching the whole plane $\mathbb{R}^2$ (or $\mathbb{R}^n$ in general), we are restricted to the set where the constraint holds (a curve or surface).
-
The key insight is geometric: at the constrained optimum, the gradient of $f$ must be parallel to the gradient of $g$. If they were not parallel, we could move along the constraint in a direction that still improves $f$, so we would not be at the optimum yet.
-
We introduce a new variable $\lambda$ (the Lagrange multiplier) and define the Lagrangian:
$$\mathcal{L}(x, y, \lambda) = f(x, y) - \lambda(g(x, y) - c)$$
- Setting all partial derivatives to zero gives a system of equations whose solutions are the constrained optima:
$$\frac{\partial \mathcal{L}}{\partial x} = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = 0$$
- For example, maximise $f(x,y) = x^2 y$ subject to $x^2 + y^2 = 1$. The Lagrangian is $\mathcal{L} = x^2 y - \lambda(x^2 + y^2 - 1)$. Taking partials:
$$2xy - 2\lambda x = 0, \quad x^2 - 2\lambda y = 0, \quad x^2 + y^2 = 1$$
-
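From the first equation (assuming $x \neq 0$): $\lambda = y$. Substituting into the second: $x^2 = 2y^2$. Combined with the constraint: $2y^2 + y^2 = 1$, so $y = \pm\frac{1}{\sqrt{3}}$ and $x^2 = \frac{2}{3}$. Since $f = x^2 y$ is positive only when $y > 0$, the maximum is at $y = \frac{1}{\sqrt{3}}$, $x = \pm\sqrt{2/3}$. The maximum value is $f = \frac{2}{3} \cdot \frac{1}{\sqrt{3}} = \frac{2}{3\sqrt{3}}$. The excluded case $x = 0$ gives $f = 0$, which is not the maximum.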
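Here is a quick numerical check of this worked example, assuming JAX. Parameterise the unit circle as $x = \cos t$, $y = \sin t$, and evaluate $f = x^2 y$ on a fine grid of $t$ values. Then compare the largest value with $\frac{2}{3\sqrt{3}} \approx 0.3849$.

```python
import jax.numpy as jnp

# Parameterise the constraint x^2 + y^2 = 1 as x = cos(t), y = sin(t)
t = jnp.linspace(0.0, 2 * jnp.pi, 200_001)
x, y = jnp.cos(t), jnp.sin(t)
f = x**2 * y

best = jnp.argmax(f)  # index of the largest value on the grid
print(f"numerical max: f={f[best]:.6f} at x={x[best]:.4f}, y={y[best]:.4f}")
print(f"Lagrange answer: f={2 / (3 * jnp.sqrt(3.0)):.6f} "
      f"at |x|={jnp.sqrt(2 / 3):.4f}, y={1 / jnp.sqrt(3.0):.4f}")
```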
-
For inequality constraints ($g(x,y) \leq c$ instead of $= c$), the Karush-Kuhn-Tucker (KKT) conditions generalise Lagrange multipliers. The constraint is either active (binding, treated as equality) or inactive (the solution lies in the interior and the constraint is irrelevant).
-
In practice, we rarely optimise by hand. Here are the main algorithmic families:
-
First-order methods (use only gradient): gradient descent, stochastic gradient descent (SGD), Adam. These are cheap per step but can converge slowly, especially on ill-conditioned problems.
-
Second-order methods (use gradient and Hessian): Newton's method converges fast but computing and inverting the Hessian is expensive ($O(n^3)$ for $n$ parameters). Quasi-Newton methods (like BFGS and L-BFGS) approximate the Hessian using only gradient information, achieving faster convergence than first-order methods without the full cost of second-order methods.
-
Conjugate gradient: efficient for minimising large quadratic objectives (equivalently, solving large sparse symmetric positive-definite linear systems). It uses only matrix-vector products instead of storing the full Hessian.
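Here is a sketch of conjugate gradient in action, assuming JAX's `jax.scipy.sparse.linalg.cg`. This function accepts a function computing $A\mathbf{v}$ in place of a matrix. The tridiagonal matrix here (4 on the diagonal, $-1$ beside it) is a hypothetical symmetric positive-definite example. The full $1000 \times 1000$ matrix is never stored.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 1000

def matvec(v):
    # Applies a tridiagonal SPD matrix (4 on the diagonal, -1 next to it)
    # without ever building the n x n matrix
    left = jnp.concatenate([jnp.zeros(1), v[:-1]])
    right = jnp.concatenate([v[1:], jnp.zeros(1)])
    return 4.0 * v - left - right

b = jnp.ones(n)
x, _ = cg(matvec, b)  # solves A x = b, i.e. minimises 1/2 x^T A x - b^T x
residual = jnp.linalg.norm(matvec(x) - b)
print(f"residual norm: {residual:.2e}")
```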
-
Gauss-Newton and Levenberg-Marquardt: specialised for least-squares problems (common in regression). They approximate the Hessian by $J^T J$, where $J$ is the Jacobian of the residuals.
-
Natural gradient descent: accounts for the geometry of the parameter space using the Fisher information matrix, which can be more effective for probabilistic models.
-
The choice of optimiser depends on the problem. For deep learning, first-order methods (especially Adam) dominate because the number of parameters is enormous (millions to billions), making Hessian computation impractical. For smaller problems with smooth objectives, second-order methods can be dramatically faster.
Coding Tasks (use CoLab or notebook)
- Implement Newton's method to find $\sqrt{7}$ (a zero of $f(x) = x^2 - 7$). Observe the rapid convergence.
```python
import math

f = lambda x: x**2 - 7
df = lambda x: 2 * x
x = 3.0  # initial guess
for i in range(6):
    x = x - f(x) / df(x)  # Newton step: jump to where the tangent line crosses zero
    print(f"step {i+1}: x = {x:.10f} (error: {abs(x - math.sqrt(7.0)):.2e})")
```
- Use gradient descent to minimise $f(x, y) = (x - 3)^2 + (y + 1)^2$. The minimum is at $(3, -1)$. Experiment with different learning rates.
```python
import jax
import jax.numpy as jnp

def f(params):
    x, y = params
    return (x - 3)**2 + (y + 1)**2

grad_f = jax.grad(f)
params = jnp.array([0.0, 0.0])
lr = 0.1  # learning rate: try 0.01, 0.5, 1.0 and 1.1
for i in range(20):
    g = grad_f(params)
    params = params - lr * g  # step downhill along the negative gradient
    if i % 5 == 0 or i == 19:
        print(f"step {i:2d}: ({params[0]:.4f}, {params[1]:.4f}) loss={f(params):.6f}")
```
- Solve a constrained optimisation problem numerically. Maximise $f(x,y) = xy$ subject to $x + y = 10$ by parameterising $y = 10 - x$ and finding the optimum of the single-variable function.
```python
import jax

# Substitute the constraint: y = 10 - x, so f = x(10 - x) = 10x - x²
f = lambda x: x * (10 - x)
df = jax.grad(f)
# Gradient ascent (we want a maximum, so we add the gradient)
x = 1.0
lr = 0.1
for i in range(20):
    x = x + lr * df(x)
print(f"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}")  # approaches x=5, y=5, f=25
```
Fundamentals of Statistics
-
Statistics is the science of learning from data. You collect observations, summarise them, and draw conclusions, often about things you cannot measure directly.
-
Imagine you want to know the average height of every adult in a country. You cannot measure everyone, so you measure a sample and use statistics to make an informed guess about the whole population.
-
There are two main branches:
- Descriptive statistics: summarising data you already have (averages, charts, tables)
- Inferential statistics: using a sample to make claims about a larger group
-
The building block of statistics is the distribution, a description of how values are spread out. Everything else, averages, tests, predictions, flows from understanding distributions.
-
A frequency distribution counts how often each value (or range of values) appears in your data. Think of sorting exam scores into bins and counting how many students fall in each bin. The result is a histogram.
-
A probability distribution replaces raw counts with probabilities. Instead of "12 students scored between 70 and 80," it says "there is a 0.24 probability of scoring between 70 and 80." The histogram bars become a smooth curve when the data is continuous.
-
The histogram on the left is built from actual data you collected. The smooth curve on the right is a mathematical model that describes the pattern behind the data. One is empirical, the other is theoretical.
-
To work with distributions mathematically, we need a way to assign numbers to outcomes. That is exactly what a random variable does.
-
A random variable is a function that maps each outcome of an experiment to a real number. Flip a coin: the outcome is "heads" or "tails," but a random variable $X$ converts this to $X(\text{heads}) = 1$ and $X(\text{tails}) = 0$. Now we can do arithmetic.
-
A discrete random variable takes on a countable set of values: the number of heads in 10 flips, the roll of a die, the number of emails you receive in an hour.
-
A continuous random variable can take any value in an interval: your exact height, the time until the next bus arrives, the temperature at noon.
-
The distinction matters because it changes how we compute probabilities. For discrete variables, we sum. For continuous variables, we integrate (recall integrals from Chapter 3).
-
For a discrete random variable, the probability mass function (PMF) gives the probability of each specific value:
$$P(X = x) = p(x), \quad \text{where } \sum_{x} p(x) = 1$$
- For a continuous random variable, the probability density function (PDF) $f(x)$ is a density rather than a probability: integrating it over a range gives the probability of falling within that range. The probability of any single exact value is zero; only intervals have positive probability:
$$P(a \le X \le b) = \int_a^b f(x) \, dx, \quad \text{where } \int_{-\infty}^{\infty} f(x) \, dx = 1$$
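To see the difference between a density and a probability, the sketch below takes the standard normal PDF from `jax.scipy.stats.norm` (an example choice of distribution). It integrates the PDF numerically with a simple trapezoidal rule written out by hand. `integrate` is a hypothetical helper, assuming JAX.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def integrate(f, a, b, n=10_001):
    # Trapezoidal rule on n evenly spaced points
    x = jnp.linspace(a, b, n)
    y = f(x)
    return jnp.sum((y[1:] + y[:-1]) / 2) * (x[1] - x[0])

total = integrate(norm.pdf, -10.0, 10.0)     # total area under the density
p_interval = integrate(norm.pdf, -1.0, 1.0)  # P(-1 <= X <= 1)
print(f"total probability: {total:.4f}")
print(f"P(-1 <= X <= 1):   {p_interval:.4f}")
print(f"same via the CDF:  {norm.cdf(1.0) - norm.cdf(-1.0):.4f}")
print(f"density at 0:      {norm.pdf(0.0):.4f}  (a height, not a probability)")
```

The total area comes out as 1. The area over $[-1, 1]$ is about 0.683, matching the CDF difference.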
-
Now that we can assign numbers to outcomes, the most natural question is: what value do we expect on average?
-
Expectation (or expected value) is the weighted average of all possible values, where the weights are the probabilities. Think of it as the "centre of gravity" of the distribution.
-
If you roll a fair die many times, your average roll converges to 3.5. That is the expected value, even though you can never actually roll a 3.5.
-
For a discrete random variable:
$$E[X] = \sum_{x} x \cdot p(x)$$
- For a continuous random variable (using the integral from Chapter 3):
$$E[X] = \int_{-\infty}^{\infty} x \cdot f(x) \, dx$$
- Example: a fair six-sided die has $p(x) = 1/6$ for $x = 1, 2, 3, 4, 5, 6$.
$$E[X] = 1 \cdot \tfrac{1}{6} + 2 \cdot \tfrac{1}{6} + 3 \cdot \tfrac{1}{6} + 4 \cdot \tfrac{1}{6} + 5 \cdot \tfrac{1}{6} + 6 \cdot \tfrac{1}{6} = \frac{21}{6} = 3.5$$
-
Expectation is linear, meaning $E[aX + b] = aE[X] + b$. This property is extremely useful and shows up constantly in ML loss functions.
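Linearity is easy to check numerically. This sketch, assuming JAX, computes $E[3X - 2]$ for a fair die directly from the PMF. It then compares the result with $3E[X] - 2$. The constants $a = 3$ and $b = -2$ are arbitrary, and `expect` is a hypothetical helper.

```python
import jax.numpy as jnp

faces = jnp.arange(1, 7, dtype=jnp.float32)
p = jnp.full(6, 1 / 6)  # fair die

def expect(values, probs):
    # E[g(X)] = sum over outcomes of g(x) * p(x)
    return jnp.sum(values * probs)

a, b = 3.0, -2.0
lhs = expect(a * faces + b, p)  # E[aX + b], computed directly
rhs = a * expect(faces, p) + b  # a E[X] + b
print(f"E[X] = {expect(faces, p):.4f}")
print(f"E[3X - 2] = {lhs:.4f}, 3E[X] - 2 = {rhs:.4f}")
```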
-
Expectation tells us the centre, but it says nothing about how spread out the values are. To describe the full shape of a distribution, we need moments.
-
A moment is an expectation of a power of $X$. The $k$-th raw moment is:
$$\mu_k' = E[X^k]$$
-
The first raw moment ($k = 1$) is just the mean: $\mu_1' = E[X] = \mu$.
-
Raw moments are measured from zero. Often we care about deviation from the mean instead. The $k$-th central moment centres the measurement:
$$\mu_k = E[(X - \mu)^k]$$
-
The first central moment is always zero (deviations above and below the mean cancel). The second central moment is the variance.
-
To compare distributions on different scales, we standardise by dividing by the appropriate power of the standard deviation $\sigma$:
$$\tilde{\mu}_k = \frac{\mu_k}{\sigma^k}$$
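The three definitions differ only in what is subtracted and what is divided by, so one small function can cover them. This is a sketch assuming JAX. The helper name `moment` and its `kind` argument are hypothetical, and it uses the population standard deviation (dividing by $N$). Applied to the dataset from the worked example below, it returns a mean of 5, a variance of 4, a skewness of about 0.656 and a kurtosis of about 2.781.

```python
import jax.numpy as jnp

def moment(x, k, kind="raw"):
    # kind="raw":          E[X^k]
    # kind="central":      E[(X - mu)^k]
    # kind="standardised": E[(X - mu)^k] / sigma^k
    if kind == "raw":
        return jnp.mean(x**k)
    mu = jnp.mean(x)
    central = jnp.mean((x - mu) ** k)
    if kind == "central":
        return central
    sigma = jnp.sqrt(jnp.mean((x - mu) ** 2))  # population standard deviation
    return central / sigma**k

x = jnp.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=jnp.float32)
print(moment(x, 1), moment(x, 1, "central"), moment(x, 2, "central"))
print(moment(x, 3, "standardised"), moment(x, 4, "standardised"))
```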
- Each moment captures a different aspect of the distribution's shape:
-
1st raw moment (Mean): Where the distribution is centred. The balance point.
-
2nd central moment (Variance): How spread out values are around the mean. Higher variance means wider.
-
3rd standardised moment (Skewness): Whether the distribution leans left or right. A symmetric distribution has zero skewness.
-
4th standardised moment (Kurtosis): How heavy the tails are. Higher kurtosis means extreme values are more likely.
-
Let us work through all four moments for a concrete dataset: $X = \{2, 4, 4, 4, 5, 5, 7, 9\}$.
-
Step 1: Mean (1st raw moment)
$$\mu = \frac{2 + 4 + 4 + 4 + 5 + 5 + 7 + 9}{8} = \frac{40}{8} = 5$$
- Step 2: Variance (2nd central moment). Subtract the mean from each value, square, then average:
$$\sigma^2 = \frac{(2{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (5{-}5)^2 + (5{-}5)^2 + (7{-}5)^2 + (9{-}5)^2}{8}$$
$$= \frac{9 + 1 + 1 + 1 + 0 + 0 + 4 + 16}{8} = \frac{32}{8} = 4$$
-
The standard deviation is $\sigma = \sqrt{4} = 2$.
-
Step 3: Skewness (standardised 3rd central moment). Cube the deviations, average, divide by $\sigma^3$:
$$\tilde{\mu}_3 = \frac{1}{8} \cdot \frac{(-3)^3 + (-1)^3 + (-1)^3 + (-1)^3 + 0^3 + 0^3 + 2^3 + 4^3}{2^3}$$
$$= \frac{1}{8} \cdot \frac{-27 -1 -1 -1 + 0 + 0 + 8 + 64}{8} = \frac{42}{64} = 0.656$$
-
Positive skewness means the right tail is longer, which makes sense since 9 is far above the mean.
-
Step 4: Kurtosis (standardised 4th central moment). Raise deviations to the 4th power:
$$\tilde{\mu}_4 = \frac{1}{8} \cdot \frac{(-3)^4 + (-1)^4 + (-1)^4 + (-1)^4 + 0^4 + 0^4 + 2^4 + 4^4}{2^4}$$
$$= \frac{1}{8} \cdot \frac{81 + 1 + 1 + 1 + 0 + 0 + 16 + 256}{16} = \frac{356}{128} = 2.781$$
- A normal distribution has kurtosis of 3 (called "mesokurtic"). Our value of 2.781 is close, suggesting the tails are roughly normal. Values above 3 ("leptokurtic") signal heavier tails; below 3 ("platykurtic") signal lighter tails. Some formulas report excess kurtosis by subtracting 3, so our excess kurtosis would be $-0.219$.
Coding Tasks (use CoLab or notebook)
- Compute the expected value of a loaded die where face 6 has probability 0.3 and all other faces share the remaining probability equally. Verify by simulating 100,000 rolls.
```python
import jax
import jax.numpy as jnp
# Loaded die: face 6 has p=0.3, others share 0.7 equally
probs = jnp.array([0.14, 0.14, 0.14, 0.14, 0.14, 0.30])
faces = jnp.array([1, 2, 3, 4, 5, 6])
# Analytical expected value
ev = jnp.sum(faces * probs)
print(f"Expected value (formula): {ev:.4f}")
# Simulation
key = jax.random.PRNGKey(42)
rolls = jax.random.choice(key, faces, shape=(100_000,), p=probs)
print(f"Expected value (simulation): {rolls.mean():.4f}")
```
- Compute all four moments (mean, variance, skewness, kurtosis) for the dataset from the worked example, then modify the data and observe how each moment changes.
```python
import jax.numpy as jnp
x = jnp.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=jnp.float32)
mean = jnp.mean(x)
variance = jnp.mean((x - mean) ** 2)
std = jnp.sqrt(variance)
skewness = jnp.mean(((x - mean) / std) ** 3)
kurtosis = jnp.mean(((x - mean) / std) ** 4)
print(f"Mean: {mean:.3f}")
print(f"Variance: {variance:.3f}")
print(f"Std Dev: {std:.3f}")
print(f"Skewness: {skewness:.3f}")
print(f"Kurtosis: {kurtosis:.3f}")
print(f"Excess K: {kurtosis - 3:.3f}")
```
- Visualise a PMF and CDF side by side for a fair die roll. Try changing the probabilities to see how the shapes shift.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
faces = jnp.array([1, 2, 3, 4, 5, 6])
pmf = jnp.ones(6) / 6  # fair die; try changing these!
cdf = jnp.cumsum(pmf)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.bar(faces, pmf, color="#3498db", alpha=0.8)
ax1.set_title("PMF")
ax1.set_xlabel("Face")
ax1.set_ylabel("P(X = x)")
ax1.set_ylim(0, 0.5)
# where="post": the CDF jumps at each face and stays flat until the next one
ax2.step(faces, cdf, where="post", color="#e74c3c", linewidth=2)
ax2.set_title("CDF")
ax2.set_xlabel("Face")
ax2.set_ylabel("P(X ≤ x)")
ax2.set_ylim(0, 1.1)
plt.tight_layout()
plt.show()
```
Statistical Measures
-
In the previous file we introduced moments as a family of summary statistics. Here we unpack the practical tools that flow from them: measures of dispersion, position, shape, and association.
-
Dispersion answers the question: how spread out is the data? Two classrooms can have the same average test score, but very different spreads.
-
The narrow (blue) distribution has low variance: most values cluster tightly around the mean. The wide (red) distribution has high variance: values are scattered further out.
-
Variance is the average squared distance from the mean. We square to avoid positive and negative deviations cancelling each other out.
$$\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$
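- When working with a sample (not the full population), we divide by $N - 1$ instead of $N$. This adjustment is called Bessel's correction. Deviations are measured from the sample mean $\bar{x}$, which sits closer to the data than the true mean $\mu$ does, so dividing by $N$ systematically underestimates the true variance. Dividing by $N - 1$ compensates for this: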
$$s^2 = \frac{1}{N-1} \sum_{i=1}^{N} (x_i - \bar{x})^2$$
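A simulation makes the bias visible. Assuming JAX, the sketch below draws many samples of size $N = 5$ from a standard normal (true variance 1) and averages the two variance estimates. `ddof` is the argument of `jnp.var` that sets the divisor to $N - \text{ddof}$. Dividing by $N$ averages out near $(N-1)/N = 0.8$, while Bessel's correction averages out near 1.

```python
import jax
import jax.numpy as jnp

# 100,000 samples of size 5 from a standard normal, whose true variance is 1
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, shape=(100_000, 5))

biased = jnp.var(samples, axis=1, ddof=0)    # divide by N
unbiased = jnp.var(samples, axis=1, ddof=1)  # divide by N - 1 (Bessel's correction)
print(f"average of divide-by-N estimates:     {biased.mean():.3f}")
print(f"average of divide-by-(N-1) estimates: {unbiased.mean():.3f}")
```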
-
Standard deviation is the square root of variance: $\sigma = \sqrt{\sigma^2}$. It brings the measure back to the original units. If your data is in centimetres, variance is in cm$^2$, but standard deviation is back in cm.
-
Mean Absolute Deviation (MAD) is a simpler alternative. Instead of squaring, take the absolute value of each deviation:
$$\text{MAD} = \frac{1}{N} \sum_{i=1}^{N} |x_i - \mu|$$
-
MAD is more robust to outliers than variance because it does not amplify large deviations by squaring them. However, variance is more mathematically convenient (it decomposes nicely in proofs and ML optimisation).
-
Position answers a different question: where does a specific value sit relative to the rest of the data?
-
Quartiles split sorted data into four equal parts. Q1 (25th percentile) is the value below which 25% of data falls. Q2 is the median (50th percentile). Q3 is the 75th percentile.
-
The Interquartile Range (IQR) is $Q3 - Q1$. It captures the spread of the middle 50% of data, ignoring extremes.
-
The box plot is one of the most useful visualisations in statistics. The box spans Q1 to Q3, the line inside is the median, whiskers extend to the most extreme non-outlier values, and dots beyond the whiskers are outliers.
-
Percentiles generalise quartiles. The $p$-th percentile is the value below which $p\%$ of observations fall. Q1 is the 25th percentile, the median is the 50th, and Q3 is the 75th.
-
The z-score tells you how many standard deviations a value is from the mean:
$$z = \frac{x - \mu}{\sigma}$$
-
A z-score of 2 means the value is 2 standard deviations above the mean. A z-score of $-1.5$ means it is 1.5 standard deviations below. Computing z-scores is also called standardisation and is used heavily in ML for feature scaling. It shifts and rescales a feature to have mean 0 and standard deviation 1 without changing the shape of its distribution.
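Here is a minimal sketch of z-score standardisation, assuming JAX. The heights are made-up values in centimetres, and `std()` here is the population standard deviation.

```python
import jax.numpy as jnp

heights = jnp.array([150.0, 160.0, 165.0, 170.0, 175.0, 190.0])  # hypothetical data in cm
z = (heights - heights.mean()) / heights.std()  # subtract the mean, divide by the std
print(f"z-scores: {z}")
print(f"mean: {z.mean():.4f}, std: {z.std():.4f}")
```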
-
Shape describes the geometry of a distribution beyond its centre and spread.
-
Skewness (the standardised 3rd moment from the previous file) measures asymmetry. A perfectly symmetric distribution like the normal curve has skewness of zero. Positive skewness means a longer right tail (e.g. income distributions). Negative skewness means a longer left tail (e.g. age at retirement).
$$\text{Skewness} = \frac{1}{N} \sum_{i=1}^{N} \left(\frac{x_i - \mu}{\sigma}\right)^3$$
- Kurtosis (the standardised 4th moment) measures tail heaviness. The normal distribution has kurtosis of 3. Distributions with heavier tails (more prone to outliers) have kurtosis greater than 3.
$$\text{Kurtosis} = \frac{1}{N} \sum_{i=1}^{N} \left(\frac{x_i - \mu}{\sigma}\right)^4$$
- Correlation measures the strength and direction of a relationship between two variables. It answers: when one variable goes up, does the other tend to go up, go down, or do nothing?
- Pearson correlation ($r$) measures linear association. It ranges from $-1$ (perfect negative) through $0$ (none) to $+1$ (perfect positive).
$$r = \frac{\sum_{i=1}^{N} (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum (x_i - \bar{x})^2} \cdot \sqrt{\sum (y_i - \bar{y})^2}}$$
-
If you recall dot products from Chapter 1, Pearson correlation is exactly the cosine similarity between the mean-centred versions of $\mathbf{x}$ and $\mathbf{y}$. The numerator is their dot product, and the denominator is the product of their norms.
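Here is a quick check of the cosine-similarity view, assuming JAX. The cosine between the centred vectors matches the Pearson coefficient from `jnp.corrcoef`. The paired data are made up for illustration.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = jnp.array([2.0, 1.0, 4.0, 3.0, 6.0])  # hypothetical paired data

def cosine(u, v):
    # Cosine similarity: dot product divided by the product of the norms
    return jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))

r_cosine = cosine(x - x.mean(), y - y.mean())  # cosine of the centred vectors
r_builtin = jnp.corrcoef(x, y)[0, 1]           # Pearson r from jnp.corrcoef
print(f"cosine of centred vectors: {r_cosine:.6f}")
print(f"jnp.corrcoef:              {r_builtin:.6f}")
```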
-
Spearman correlation ($\rho$) measures monotonic association. Instead of using raw values, it ranks them first and then computes Pearson correlation on the ranks. This makes it robust to outliers and works even when the relationship is nonlinear, as long as it is consistently increasing or decreasing.
-
Geometric mean is the appropriate average when values multiply together, like growth rates. If your investment grows by 10%, then 20%, then 30%, the average growth factor is not the arithmetic mean of those rates. Instead:
$$\bar{x}_{\text{geo}} = \left(\prod_{i=1}^{N} x_i\right)^{1/N}$$
-
For growth rates specifically, convert percentages to factors first (1.10, 1.20, 1.30), compute the geometric mean, then subtract 1.
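The sketch below works through the 10%, 20%, 30% example, assuming JAX. It computes the geometric mean as $\exp$ of the mean of the logarithms. This is mathematically the same as the product formula above but avoids overflow on long series. The result, roughly 19.72% per period, compounds back to the same total growth. The arithmetic mean of 20% would overstate it.

```python
import jax.numpy as jnp

factors = jnp.array([1.10, 1.20, 1.30])  # +10%, +20%, +30% as growth factors
# exp(mean(log x)) equals (prod x)^(1/N) but avoids overflow on long series
geo = jnp.exp(jnp.mean(jnp.log(factors)))
arith = jnp.mean(factors)
print(f"geometric mean growth:  {100 * (geo - 1):.2f}% per period")
print(f"arithmetic mean growth: {100 * (arith - 1):.2f}% per period")
print(f"three periods at the geometric mean: {geo**3:.4f}, actual total: {jnp.prod(factors):.4f}")
```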
-
Exponential Moving Average (EMA) gives more weight to recent observations. Unlike a simple moving average where all points in the window are equally weighted, EMA decays exponentially:
$$\text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1}$$
-
The smoothing factor $\alpha$ (between 0 and 1) controls how quickly old observations lose influence. Higher $\alpha$ means more responsive to recent changes, lower $\alpha$ means smoother. In ML, EMA is used in optimisers like Adam and in batch normalisation's running statistics.
-
Outlier detection identifies data points that are unusually far from the rest. Two common methods:
- IQR method: a point is an outlier if it falls below $Q1 - 1.5 \times \text{IQR}$ or above $Q3 + 1.5 \times \text{IQR}$
- Z-score method: a point is an outlier if $|z| > 3$ (more than 3 standard deviations from the mean)
-
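The IQR method is more robust for two reasons. It does not assume a normal distribution, and quartiles, unlike the mean and standard deviation, are barely affected by the outliers themselves. The z-score method works well when data is approximately normal. It can fail when the distribution is heavily skewed, or when a large outlier inflates the standard deviation enough to hide itself.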
Coding Tasks (use CoLab or notebook)
- Compute variance, standard deviation, and MAD for a dataset and compare them. Observe what happens when you add an extreme outlier.
```python
import jax.numpy as jnp
data = jnp.array([4, 8, 6, 5, 3, 7, 9, 5, 6, 7], dtype=jnp.float32)
mean = jnp.mean(data)
variance = jnp.var(data)
std = jnp.std(data)
mad = jnp.mean(jnp.abs(data - mean))
print("Original data:")
print(f" Variance: {variance:.3f}, Std: {std:.3f}, MAD: {mad:.3f}")
# Add an outlier and recompute
data_outlier = jnp.append(data, 100.0)
mean2 = jnp.mean(data_outlier)
print(f"\nWith outlier (100):")
print(f" Variance: {jnp.var(data_outlier):.3f}, Std: {jnp.std(data_outlier):.3f}, MAD: {jnp.mean(jnp.abs(data_outlier - mean2)):.3f}")
```
- Compute Pearson and Spearman correlation between two variables. Experiment with different relationships.
```python
import jax.numpy as jnp

# Perfect linear relationship
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=jnp.float32)
y = 2 * x + 1  # try changing this!

def pearson(a, b):
    # Dot product of the centred vectors divided by the product of their norms
    a_c = a - jnp.mean(a)
    b_c = b - jnp.mean(b)
    return jnp.sum(a_c * b_c) / (jnp.sqrt(jnp.sum(a_c**2)) * jnp.sqrt(jnp.sum(b_c**2)))

def spearman(a, b):
    # Rank each variable (argsort of argsort gives 0-based ranks), then apply Pearson.
    # Note: ties get distinct ranks here rather than averaged ranks.
    rank_a = jnp.argsort(jnp.argsort(a)).astype(jnp.float32)
    rank_b = jnp.argsort(jnp.argsort(b)).astype(jnp.float32)
    return pearson(rank_a, rank_b)

print(f"Pearson r: {pearson(x, y):.4f}")
print(f"Spearman ρ: {spearman(x, y):.4f}")
```
- Implement outlier detection using both the IQR and z-score methods, then compare their results on skewed data.
```python
import jax.numpy as jnp
data = jnp.array([2, 3, 3, 4, 5, 5, 5, 6, 6, 7, 50], dtype=jnp.float32)
# IQR method
q1, q3 = jnp.percentile(data, 25), jnp.percentile(data, 75)
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
iqr_outliers = data[(data < lower) | (data > upper)]
print(f"IQR bounds: [{lower:.1f}, {upper:.1f}]")
print(f"IQR outliers: {iqr_outliers}")
# Z-score method
z_scores = (data - jnp.mean(data)) / jnp.std(data)
z_outliers = data[jnp.abs(z_scores) > 3]
print(f"\nZ-scores: {z_scores}")
print(f"Z-score outliers (|z| > 3): {z_outliers}")
```
- Compute and plot an Exponential Moving Average with different smoothing factors on noisy data.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Generate noisy data: a linear trend plus standard normal noise
key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, shape=(50,))
signal = jnp.linspace(0, 5, 50) + noise

def ema(data, alpha):
    # EMA_0 = x_0, then EMA_t = alpha * x_t + (1 - alpha) * EMA_{t-1}
    result = jnp.zeros_like(data)
    result = result.at[0].set(data[0])
    for t in range(1, len(data)):
        result = result.at[t].set(alpha * data[t] + (1 - alpha) * result[t - 1])
    return result

plt.figure(figsize=(10, 4))
plt.plot(signal, "o", alpha=0.3, label="raw data", color="#999")
for alpha, color in [(0.1, "#e74c3c"), (0.3, "#3498db"), (0.7, "#27ae60")]:
    plt.plot(ema(signal, alpha), label=f"α={alpha}", color=color, linewidth=2)
plt.legend()
plt.title("EMA with different smoothing factors")
plt.show()
```
Sampling
-
In an ideal world, you would measure every single member of the group you care about. In practice, that is almost never possible. You cannot survey every voter, test every light bulb, or scan every patient. So you take a sample and use it to learn about the whole.
-
The population is the complete set of individuals or items you want to study. The sample is the subset you actually observe.
-
A parameter is a number that describes the population (e.g. the true average height of all adults in a country).
-
A statistic is a number computed from your sample (e.g. the average height of the 500 people you measured). Statistics are used to estimate parameters.
-
The quality of your conclusions depends entirely on how you select your sample. A biased sample leads to biased conclusions, no matter how sophisticated your analysis.
-
The sampling frame is the list of all individuals from which you actually draw your sample. Ideally this matches the population perfectly, but in practice there are gaps.
-
For instance, if you survey people by phone, you miss everyone without a phone. The difference between the frame and the population is called coverage error.
-
Sampling error is the natural discrepancy between a sample statistic and the population parameter.
-
Even a perfectly random sample will not match the population exactly. Larger samples reduce sampling error.
-
There are two broad families of sampling: probability and non-probability.
-
Probability sampling means every member of the population has a known, nonzero chance of being selected. This lets you quantify uncertainty and generalise results.
-
Simple random sampling: every individual has an equal chance of being selected, and every possible sample of size $n$ is equally likely. Think of putting every name in a hat and drawing blindly.
-
Stratified sampling: divide the population into non-overlapping groups (strata) based on a shared characteristic (e.g. age group, region), then randomly sample from each stratum. This guarantees representation from every group and reduces variance when strata differ from each other.
-
Cluster sampling: divide the population into groups (clusters), randomly select some clusters, then include everyone in the chosen clusters. This is practical when the population is spread out geographically, like sampling entire schools rather than individual students across a district.
-
Systematic sampling: pick a random starting point, then select every $k$-th individual from the list. For example, start at person 7 and then take every 10th person (7, 17, 27, ...). Simple to implement but can introduce bias if the list has a hidden pattern.
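The four probability-sampling schemes above can be sketched in a few lines of JAX. This is a minimal illustration, not a survey toolkit: the population (1,000 IDs), the 4 regions used as strata, and the clusters (blocks of 25 consecutive IDs) are all made-up assumptions chosen so that each scheme is easy to see.

```python
import jax
import jax.numpy as jnp

# Hypothetical population: 1,000 people with an ID and a region label (0-3)
N = 1_000
ids = jnp.arange(N)
region = ids % 4  # 4 strata of 250 people each (illustrative assumption)

key = jax.random.PRNGKey(0)
k_srs, k_sys, k_strat, k_clus = jax.random.split(key, 4)

# Simple random sampling: 40 distinct IDs, every subset of size 40 equally likely
srs = jax.random.choice(k_srs, ids, shape=(40,), replace=False)

# Systematic sampling: random start in [0, k), then every k-th ID
k = N // 40
start = jax.random.randint(k_sys, shape=(), minval=0, maxval=k)
systematic = start + k * jnp.arange(40)

# Stratified sampling: 10 random IDs from each region
strat_keys = jax.random.split(k_strat, 4)
stratified = jnp.concatenate([
    jax.random.choice(sk, ids[region == r], shape=(10,), replace=False)
    for r, sk in enumerate(strat_keys)
])

# Cluster sampling: treat blocks of 25 consecutive IDs as clusters,
# pick 2 clusters at random and keep everyone inside them
clusters = ids.reshape(40, 25)
chosen = jax.random.choice(k_clus, 40, shape=(2,), replace=False)
cluster_sample = clusters[chosen].ravel()

for name, s in [("SRS", srs), ("Systematic", systematic),
                ("Stratified", stratified), ("Cluster", cluster_sample)]:
    print(f"{name:<11} n={s.shape[0]:>3}  first IDs: {s[:5]}")
```

Note how the cluster sample is 50 consecutive-looking IDs from just two blocks, while the stratified sample is guaranteed to contain exactly 10 people from every region.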
-
Non-probability sampling does not give every member a known chance of selection. Results cannot be rigorously generalised, but these methods are often faster and cheaper.
-
Convenience sampling: select whoever is easiest to reach. Surveying people at a shopping mall is convenient but misses those who do not shop there.
-
Quota sampling: like stratified sampling, but without randomness. The researcher fills quotas (e.g. 50 men and 50 women) by picking accessible individuals from each group.
-
Snowball sampling: start with a few participants and ask them to recruit others. Useful for hard-to-reach populations (e.g. studying rare diseases), but heavily biased toward connected individuals.
-
Once you have a sampling method, a natural question arises: if I took a different sample, would I get a different statistic? Almost certainly yes. The sampling distribution is the distribution of a statistic (like the sample mean) across all possible samples of the same size.
-
Imagine drawing 1,000 different samples of 30 people and computing the mean height of each. Those 1,000 means form a distribution. Some will be a bit above the true population mean, some a bit below, and most will cluster around the true value.
-
The standard deviation of this sampling distribution is called the standard error:
$$SE = \frac{\sigma}{\sqrt{n}}$$
-
Notice that the standard error shrinks as $n$ grows. Larger samples give more precise estimates. Quadrupling the sample size halves the standard error.
-
The most important result in statistics is the Central Limit Theorem (CLT). It says: no matter what the shape of the original population, the distribution of sample means approaches a normal distribution as the sample size increases.
- More precisely, if $X_1, X_2, \ldots, X_n$ are independent observations from any distribution with mean $\mu$ and finite variance $\sigma^2$, then as $n$ grows:
$$\bar{X} \approx \text{Normal}\!\left(\mu, \frac{\sigma^2}{n}\right)$$
-
The CLT is what makes most of inferential statistics work. It lets us use the normal distribution as an approximation even when the underlying data is not normal, as long as the sample is large enough.
-
How large is "large enough"? A common rule of thumb is $n \ge 30$, but this depends on how non-normal the population is. For highly skewed distributions, you may need more. For roughly symmetric populations, even $n = 10$ can be sufficient.
-
The CLT has three key conditions:
- Independence: each observation must not influence the others
- Finite variance: the population variance must exist (rules out some exotic distributions)
- Identical distribution: all observations come from the same distribution
Coding Tasks (use Colab or a notebook)
- Demonstrate the CLT visually: draw samples from a highly skewed distribution, compute sample means, and watch the histogram of means become bell-shaped.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(0)
pop_key, sample_key = jax.random.split(key)
# Exponential distribution (very skewed)
population = jax.random.exponential(pop_key, shape=(100_000,))

fig, axes = plt.subplots(1, 4, figsize=(14, 3))
sample_sizes = [1, 5, 30, 100]
for ax, n in zip(axes, sample_sizes):
    keys = jax.random.split(sample_key, 2000)
    # Draw 2,000 samples of size n (vectorised with vmap) and keep each sample's mean
    means = jax.vmap(lambda k: jax.random.choice(k, population, shape=(n,)).mean())(keys)
    ax.hist(means, bins=40, color="#3498db", alpha=0.7, density=True)
    ax.set_title(f"n = {n}")
    ax.set_xlim(0, 4)
fig.suptitle("CLT: sample means become normal as n increases", fontsize=13)
plt.tight_layout()
plt.show()
```
- Compare simple random sampling with stratified sampling. Create a population with distinct groups and show that stratified sampling gives lower variance in estimates.
```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
key, k_a, k_b = jax.random.split(key, 3)  # never reuse a key once it has been consumed
# Population: two distinct groups
group_a = jax.random.normal(k_a, shape=(500,)) + 10  # mean ~10
group_b = jax.random.normal(k_b, shape=(500,)) + 20  # mean ~20
population = jnp.concatenate([group_a, group_b])

# Simple random sampling: 1000 trials, sample size 20
srs_means = []
for i in range(1000):
    key, subkey = jax.random.split(key)
    sample = jax.random.choice(subkey, population, shape=(20,), replace=False)
    srs_means.append(sample.mean())
srs_means = jnp.array(srs_means)

# Stratified sampling: 10 from each group (proportional to the 50/50 group sizes)
strat_means = []
for i in range(1000):
    key, k1, k2 = jax.random.split(key, 3)
    s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False)
    s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False)
    strat_means.append(jnp.concatenate([s_a, s_b]).mean())
strat_means = jnp.array(strat_means)

print(f"Simple Random - Mean: {srs_means.mean():.3f}, Std: {srs_means.std():.3f}")
print(f"Stratified - Mean: {strat_means.mean():.3f}, Std: {strat_means.std():.3f}")
print(f"Stratified sampling reduced variance by {(1 - strat_means.var()/srs_means.var())*100:.1f}%")
```
- Explore how sample size affects standard error. Plot the standard error against sample size and confirm the $1/\sqrt{n}$ relationship.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(7)
key, pop_key = jax.random.split(key)
population = jax.random.normal(pop_key, shape=(50_000,)) * 10 + 50
sample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000]
std_errors = []
for n in sample_sizes:
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, 500)
    # 500 samples of size n (vectorised); the std of their means is the observed SE
    means = jax.vmap(lambda k: jax.random.choice(k, population, shape=(n,)).mean())(keys)
    std_errors.append(means.std())

plt.figure(figsize=(8, 4))
plt.plot(sample_sizes, std_errors, "o-", color="#e74c3c", label="Observed SE")
theoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32))
plt.plot(sample_sizes, theoretical, "--", color="#3498db", label="σ/√n (theoretical)")
plt.xlabel("Sample size (n)")
plt.ylabel("Standard error")
plt.legend()
plt.title("Standard error shrinks with larger samples")
plt.show()
```
Hypothesis Testing
-
Statistics is not just about describing data. Often you need to make a decision: does a new drug work? Is one algorithm faster than another? Has the average changed? Hypothesis testing gives you a structured framework for answering these questions using data.
-
The idea is simple: assume nothing has changed (the "null hypothesis"), then check whether the data is so extreme that this assumption becomes hard to believe.
-
The null hypothesis ($H_0$) is the default claim, usually a statement of "no effect" or "no difference." For example: "the average delivery time is still 30 minutes" or "the new model is no better than the old one."
-
The alternative hypothesis ($H_1$ or $H_a$) is what you suspect might be true instead: "the average delivery time has changed" or "the new model is better."
-
You never prove $H_1$ directly. Instead, you ask: if $H_0$ were true, how likely is it that I would see data this extreme? If it is very unlikely, you reject $H_0$ in favour of $H_1$.
-
The test statistic is a single number that summarises how far your sample result is from what $H_0$ predicts. Different tests use different formulas, but the logic is always the same: measure the distance between observed and expected.
-
The p-value is the probability of observing a test statistic at least as extreme as yours, assuming $H_0$ is true. A small p-value means the data is surprising under $H_0$.
-
The significance level ($\alpha$) is the threshold you set before looking at the data. If $p \le \alpha$, you reject $H_0$. Common choices are $\alpha = 0.05$ (5%) and $\alpha = 0.01$ (1%).
-
The shaded tails are the rejection regions. If your test statistic lands there, the data is surprising enough under $H_0$ that you reject it. The green area shows the p-value for a particular test statistic.
-
Here is the step-by-step procedure:
- Step 1: State $H_0$ and $H_1$
- Step 2: Choose a significance level $\alpha$
- Step 3: Collect data and compute the test statistic
- Step 4: Find the p-value (or compare the test statistic to a critical value)
- Step 5: If $p \le \alpha$, reject $H_0$. Otherwise, fail to reject $H_0$
-
Worked example: A factory claims their bolts have a mean length of 10 cm. You measure 36 bolts and find a sample mean of 10.3 cm. The known population standard deviation is 0.9 cm. Is there evidence that the mean has changed?
-
$H_0$: $\mu = 10$, $H_1$: $\mu \neq 10$, $\alpha = 0.05$
-
Test statistic (z-test, since $\sigma$ is known and $n$ is large):
$$z = \frac{\bar{x} - \mu_0}{\sigma / \sqrt{n}} = \frac{10.3 - 10}{0.9 / \sqrt{36}} = \frac{0.3}{0.15} = 2.0$$
-
For a two-tailed test at $\alpha = 0.05$, the critical values are $\pm 1.96$. Our $z = 2.0 > 1.96$, so we reject $H_0$. The p-value is approximately 0.046, which is less than 0.05.
-
Conclusion: there is statistically significant evidence that the mean bolt length differs from 10 cm.
-
A one-tailed test checks for an effect in one specific direction ($H_1$: $\mu > 10$ or $\mu < 10$). The entire $\alpha$ goes into one tail, making it easier to reject $H_0$ in that direction but impossible to detect an effect in the opposite direction.
-
A two-tailed test checks for any difference ($H_1$: $\mu \neq 10$). The $\alpha$ is split between both tails ($\alpha/2$ each). This is more conservative but catches effects in either direction.
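To see how the choice of tail changes the conclusion, here is a small sketch that reuses $z = 2.0$ from the bolt example. With $H_1: \mu > 10$ the p-value is half the two-tailed one; with $H_1: \mu < 10$ the same data gives no evidence at all. It relies only on `jax.scipy.stats.norm.cdf` and `norm.ppf`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

z = 2.0  # test statistic from the bolt example

p_upper = norm.cdf(-z)               # one-tailed, H1: mu > 10
p_lower = norm.cdf(z)                # one-tailed, H1: mu < 10
p_two = 2 * norm.cdf(-jnp.abs(z))    # two-tailed, H1: mu != 10

print(f"one-tailed (mu > 10):  p = {p_upper:.4f}")
print(f"one-tailed (mu < 10):  p = {p_lower:.4f}")
print(f"two-tailed (mu != 10): p = {p_two:.4f}")
# Critical values: all of alpha in one tail vs alpha/2 in each tail
print(f"critical z at alpha=0.05: one-tailed {norm.ppf(0.95):.3f}, two-tailed {norm.ppf(0.975):.3f}")
```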
-
Even with a good procedure, mistakes happen. There are exactly two types of errors:
-
Type I Error (false positive): you reject $H_0$ when it is actually true. The probability of this is $\alpha$, which you control by choosing your significance level. Like a fire alarm going off when there is no fire.
-
Type II Error (false negative): you fail to reject $H_0$ when it is actually false. The probability of this is $\beta$. Like a fire alarm staying silent during a real fire.
-
Power is $1 - \beta$, the probability of correctly rejecting a false $H_0$. Higher power means you are better at detecting real effects. Power increases when:
- The true effect size is larger (bigger differences are easier to detect)
- The sample size is larger (more data = more precision)
- The significance level $\alpha$ is larger (but this raises Type I error risk)
- The variability is lower (less noise)
-
There is a tension between Type I and Type II errors. Lowering $\alpha$ (being more cautious about false positives) increases $\beta$ (more false negatives). You cannot minimise both simultaneously with a fixed sample size.
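A minimal sketch of this trade-off, assuming the bolt example's numbers ($\sigma = 0.9$, $n = 36$) and, hypothetically, that the true mean really is 10.3 so that $H_0: \mu = 10$ is false. For a two-tailed z-test the power is the chance that the statistic lands in either rejection region when the truth sits $\delta\sqrt{n}/\sigma$ standard errors from $\mu_0$.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Assumed scenario: true mean 10.3, H0 says 10, sigma 0.9, n = 36
delta, sigma, n = 0.3, 0.9, 36
shift = delta * jnp.sqrt(n) / sigma  # distance of the truth from mu_0, in standard errors

betas = {}
for alpha in [0.10, 0.05, 0.01]:
    z_crit = norm.ppf(1 - alpha / 2)
    # Power of a two-tailed z-test: probability of landing in either rejection region
    power = norm.cdf(shift - z_crit) + norm.cdf(-shift - z_crit)
    betas[alpha] = float(1 - power)
    print(f"alpha={alpha:.2f}  beta={1 - power:.3f}  power={power:.3f}")
```

Shrinking $\alpha$ from 0.10 to 0.01 pushes $\beta$ up sharply with the sample size held fixed.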
-
Parametric tests assume the data follows a specific distribution (usually normal). They are more powerful when the assumptions hold.
-
Z-test: compares a sample mean to a known value when $\sigma$ is known and $n$ is large ($n \ge 30$). Test statistic:
$$z = \frac{\bar{x} - \mu_0}{\sigma / \sqrt{n}}$$
- T-test: like the z-test, but for when $\sigma$ is unknown (estimated from the sample) or $n$ is small. Uses the t-distribution, which has heavier tails than the normal. The heavier tails account for the extra uncertainty from estimating $\sigma$.
$$t = \frac{\bar{x} - \mu_0}{s / \sqrt{n}}$$
-
The t-distribution has a parameter called degrees of freedom ($df$); for a one-sample t-test, $df = n - 1$. As $df$ increases, the t-distribution approaches the normal distribution.
-
There are several flavours of t-test:
- One-sample t-test: is the sample mean different from a specific value?
- Independent two-sample t-test: are the means of two separate groups different?
- Paired t-test: are the means of two related measurements different (e.g. before and after treatment on the same subjects)?
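Here is a minimal JAX sketch of the one-sample and paired t-tests. The data are hypothetical draws, and the helper names `t_test_p_value` and `one_sample_t` are illustrative. The two-sided p-value uses the identity $P(|T| > t) = I_{df/(df+t^2)}(df/2, 1/2)$, where $I$ is the regularised incomplete beta function (`jax.scipy.special.betainc`). A paired t-test is simply a one-sample t-test on the differences. If SciPy is available, `scipy.stats.ttest_1samp` and `scipy.stats.ttest_rel` should give the same numbers.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc

def t_test_p_value(t, df):
    """Two-sided p-value for a t statistic, via the regularised incomplete beta function."""
    return betainc(df / 2.0, 0.5, df / (df + t**2))

def one_sample_t(x, mu_0):
    n = x.shape[0]
    s = jnp.std(x, ddof=1)  # sample standard deviation (divides by n - 1)
    t = (jnp.mean(x) - mu_0) / (s / jnp.sqrt(n))
    return t, t_test_p_value(t, n - 1)

key = jax.random.PRNGKey(1)
k1, k2, k3 = jax.random.split(key, 3)

# Hypothetical data: 12 measurements with true mean 10.3 and sd 0.9
x = 10.3 + 0.9 * jax.random.normal(k1, shape=(12,))
t, p = one_sample_t(x, 10.0)
print(f"one-sample: t = {t:.3f}, p = {p:.4f}")

# Paired t-test = one-sample t-test on the differences (after - before)
before = 120 + 10 * jax.random.normal(k2, shape=(15,))
after = before - 3 + 2 * jax.random.normal(k3, shape=(15,))
t_pair, p_pair = one_sample_t(after - before, 0.0)
print(f"paired:     t = {t_pair:.3f}, p = {p_pair:.4f}")
```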
-
ANOVA (Analysis of Variance): tests whether three or more group means are equal. Instead of running multiple t-tests (which inflates the Type I error rate), ANOVA does a single test by comparing the variance between groups to the variance within groups.
$$F = \frac{\text{variance between groups}}{\text{variance within groups}}$$
-
A large $F$ ratio means the groups differ more than you would expect from random variation alone.
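A minimal sketch of a one-way ANOVA in JAX, assuming three hypothetical groups of 20. The between-group mean square is divided by the within-group mean square, and the upper-tail F probability is computed with the regularised incomplete beta function, using $P(F > f) = I_{d_2/(d_2 + d_1 f)}(d_2/2, d_1/2)$. The helper name `one_way_anova` is illustrative; `scipy.stats.f_oneway` is the usual library route if SciPy is installed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc

def one_way_anova(groups):
    """Return the F statistic and p-value for a one-way ANOVA."""
    k = len(groups)
    n_total = sum(g.shape[0] for g in groups)
    grand_mean = jnp.concatenate(groups).mean()
    # Between-group sum of squares: how far each group mean sits from the grand mean
    ss_between = sum(g.shape[0] * (g.mean() - grand_mean) ** 2 for g in groups)
    # Within-group sum of squares: spread of each point around its own group mean
    ss_within = sum(((g - g.mean()) ** 2).sum() for g in groups)
    df_between, df_within = k - 1, n_total - k
    f = (ss_between / df_between) / (ss_within / df_within)
    # Upper tail of the F(df_between, df_within) distribution
    p = betainc(df_within / 2.0, df_between / 2.0, df_within / (df_within + df_between * f))
    return f, p

# Hypothetical data: three groups with means 50, 52, 55 and sd 5
keys = jax.random.split(jax.random.PRNGKey(3), 3)
groups = [m + 5 * jax.random.normal(kk, shape=(20,)) for m, kk in zip([50.0, 52.0, 55.0], keys)]
f, p = one_way_anova(groups)
print(f"F = {f:.3f}, p = {p:.4g}")
```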
-
Non-parametric tests make fewer assumptions about the data distribution. Many of them work on ranks rather than raw values, making them robust to outliers and non-normality.
-
Chi-square test ($\chi^2$): tests whether observed frequencies match expected frequencies. Used for categorical data. For example: do the proportions of red, blue, and green cars match the manufacturer's claimed proportions?
$$\chi^2 = \sum \frac{(O_i - E_i)^2}{E_i}$$
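The car-colour example can be worked through directly. The counts and claimed proportions below are hypothetical numbers chosen for illustration. The test has $k - 1$ degrees of freedom for $k$ categories, and the p-value is the chi-square upper tail, $Q(df/2, \chi^2/2)$, available as `jax.scipy.special.gammaincc`.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaincc

# Hypothetical numbers: the manufacturer claims 30% red, 50% blue, 20% green,
# and we count the colours of 200 cars.
observed = jnp.array([70.0, 95.0, 35.0])
claimed = jnp.array([0.30, 0.50, 0.20])
expected = claimed * observed.sum()  # 60, 100, 40

chi2 = jnp.sum((observed - expected) ** 2 / expected)
df = observed.shape[0] - 1  # number of categories minus one
# Upper tail of the chi-square distribution with df degrees of freedom
p = gammaincc(df / 2.0, chi2 / 2.0)
print(f"chi2 = {chi2:.3f}, df = {df}, p = {p:.4f}")
```

Here $\chi^2 \approx 2.54$ and $p \approx 0.28$, so these counts give no real evidence against the manufacturer's claim.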
-
Mann-Whitney U test: the non-parametric alternative to the independent two-sample t-test. It tests whether one group tends to have larger values than the other by comparing ranks.
-
Wilcoxon signed-rank test: the non-parametric alternative to the paired t-test. Compares paired observations by looking at the magnitude and direction of differences.
-
Kruskal-Wallis test: the non-parametric alternative to one-way ANOVA. Tests whether multiple groups come from the same distribution by comparing ranks across all groups.
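A minimal sketch of the Kruskal-Wallis statistic, assuming there are no tied values (the usual tie correction is omitted). All observations are ranked together, and then $H = \frac{12}{N(N+1)} \sum_i \frac{R_i^2}{n_i} - 3(N+1)$, where $R_i$ is the rank sum of group $i$. For moderately sized groups, $H$ is compared to a chi-square distribution with $k - 1$ degrees of freedom. The function name and data are illustrative; `scipy.stats.kruskal` handles ties if SciPy is available.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc

def kruskal_wallis(groups):
    """H statistic and chi-square p-value (assumes no tied values)."""
    pooled = jnp.concatenate(groups)
    n_total = pooled.shape[0]
    ranks = jnp.argsort(jnp.argsort(pooled)) + 1.0  # ranks 1..N across all groups
    h_sum, start = 0.0, 0
    for g in groups:
        r = ranks[start:start + g.shape[0]]  # this group's slice of the pooled ranks
        h_sum += r.sum() ** 2 / g.shape[0]
        start += g.shape[0]
    h = 12.0 / (n_total * (n_total + 1)) * h_sum - 3.0 * (n_total + 1)
    p = gammaincc((len(groups) - 1) / 2.0, h / 2.0)  # chi-square upper tail, k - 1 df
    return h, p

# Hypothetical skewed data: three groups of 15 exponential draws with different scales
keys = jax.random.split(jax.random.PRNGKey(8), 3)
groups = [s * jax.random.exponential(kk, shape=(15,)) for s, kk in zip([1.0, 1.5, 3.0], keys)]
h, p = kruskal_wallis(groups)
print(f"H = {h:.3f}, p = {p:.4f}")
```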
-
Goodness-of-fit tests check whether your data follows a specific theoretical distribution. The chi-square goodness-of-fit test compares observed bin counts to expected counts under the hypothesised distribution.
-
Normality tests specifically check whether data is normally distributed. Common ones include the Shapiro-Wilk test (powerful for small samples) and the Kolmogorov-Smirnov test (compares the sample CDF to the theoretical CDF).
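The Kolmogorov-Smirnov idea can be sketched in a few lines: sort the data and measure the largest vertical gap between the empirical step CDF and the theoretical CDF. This minimal version tests against a fully specified standard normal. If the mean and variance were estimated from the same data, the standard KS critical values would no longer apply (that case needs a corrected test such as Lilliefors). The data and function name are illustrative, and the critical value $1.36/\sqrt{n}$ is the usual large-sample approximation at $\alpha = 0.05$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def ks_statistic_normal(x):
    """One-sample Kolmogorov-Smirnov D against a standard normal CDF."""
    x = jnp.sort(x)
    n = x.shape[0]
    cdf = norm.cdf(x)
    i = jnp.arange(1, n + 1)
    # Largest gap between the empirical step function and the theoretical CDF,
    # checked just after (i/n) and just before ((i-1)/n) each jump
    d_plus = jnp.max(i / n - cdf)
    d_minus = jnp.max(cdf - (i - 1) / n)
    return jnp.maximum(d_plus, d_minus)

k1, k2 = jax.random.split(jax.random.PRNGKey(5))
normal_data = jax.random.normal(k1, shape=(200,))
skewed_data = jax.random.exponential(k2, shape=(200,)) - 1.0  # mean 0, variance 1, but skewed
print(f"D (normal data): {ks_statistic_normal(normal_data):.3f}")
print(f"D (skewed data): {ks_statistic_normal(skewed_data):.3f}")
print(f"approx. 5% critical value: {1.36 / jnp.sqrt(200):.3f}")
```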
-
In ML, hypothesis testing appears when you compare model performance. If model A achieves 92% accuracy and model B achieves 91%, is the difference real or just noise? A paired t-test on cross-validation scores can answer this.
Coding Tasks (use Colab or a notebook)
- Perform a z-test for the bolt factory example from the text. Compute the test statistic, p-value, and make a decision.
```python
import jax.numpy as jnp
from jax.scipy.stats import norm

x_bar = 10.3  # sample mean
mu_0 = 10.0   # null hypothesis value
sigma = 0.9   # known population std
n = 36        # sample size
alpha = 0.05

# Test statistic
z = (x_bar - mu_0) / (sigma / jnp.sqrt(n))
print(f"z = {z:.4f}")

# Two-tailed p-value from the standard normal CDF: P(|Z| >= |z|) = 2 * Phi(-|z|)
# For |z| = 2.0, p ≈ 0.0455
p_value = 2 * norm.cdf(-jnp.abs(z))
print(f"p-value = {p_value:.4f}")
print(f"Reject H₀? {p_value <= alpha}")
```
- Simulate Type I error: when $H_0$ is true, how often do we mistakenly reject it? Run 10,000 experiments and check that the rejection rate matches $\alpha$.
```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(0)
mu_0 = 50.0
sigma = 10.0
n = 30
alpha = 0.05
n_experiments = 10_000

# Every experiment samples from the null distribution, so H0 is true by construction
samples = mu_0 + sigma * jax.random.normal(key, shape=(n_experiments, n))
z = (samples.mean(axis=1) - mu_0) / (sigma / jnp.sqrt(n))
p_values = 2 * norm.cdf(-jnp.abs(z))
rejections = jnp.sum(p_values <= alpha)

print(f"Rejection rate: {rejections/n_experiments:.4f}")
print(f"Expected (α): {alpha}")
```
- Compare a t-test and a Mann-Whitney U test on two groups. Generate data where one group has a slightly higher mean and see which test detects the difference.
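```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc
from jax.scipy.stats import norm

key = jax.random.PRNGKey(99)
k1, k2 = jax.random.split(key)
group_a = jax.random.normal(k1, shape=(25,)) * 5 + 100
group_b = jax.random.normal(k2, shape=(25,)) * 5 + 103  # slightly higher mean

# Two-sample t-test (equal variance assumed); ddof=1 gives the sample variance s^2
n_a, n_b = len(group_a), len(group_b)
mean_a, mean_b = group_a.mean(), group_b.mean()
pooled_var = ((n_a - 1) * group_a.var(ddof=1) + (n_b - 1) * group_b.var(ddof=1)) / (n_a + n_b - 2)
se = jnp.sqrt(pooled_var * (1/n_a + 1/n_b))
t_stat = (mean_a - mean_b) / se
df = n_a + n_b - 2
# Two-sided p-value of the t distribution via the regularised incomplete beta function
t_p = betainc(df / 2.0, 0.5, df / (df + t_stat**2))
print(f"T-test statistic: {t_stat:.4f}, p-value: {t_p:.4f}")

# Mann-Whitney U: count the pairs (a, b) in which the group_b value beats the group_a value
u_stat = jnp.sum(group_a[:, None] < group_b[None, :])
# Normal approximation to U under H0 (no ties, which holds for continuous data)
mu_u = n_a * n_b / 2
sigma_u = jnp.sqrt(n_a * n_b * (n_a + n_b + 1) / 12)
u_p = 2 * norm.cdf(-jnp.abs((u_stat - mu_u) / sigma_u))
print(f"Mann-Whitney U: {u_stat}, p-value: {u_p:.4f}")
print(f"\nGroup A mean: {mean_a:.2f}, Group B mean: {mean_b:.2f}")
```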
Statistical Inference
-
Hypothesis testing gives you a yes/no decision: reject or fail to reject. But often you want something more informative: a range of plausible values for the parameter you are estimating. That is what confidence intervals provide.
-
A point estimate is a single number computed from your sample, like the sample mean $\bar{x}$. It is your best guess for the population parameter, but on its own it gives no sense of how precise the estimate is.
-
A confidence interval wraps that point estimate with a range that reflects uncertainty. It takes the form:
$$\text{CI} = \bar{x} \pm \text{ME}$$
- The margin of error (ME) depends on three things: how confident you want to be, how much variability is in the data, and how large your sample is:
$$\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}$$
- Here $z^\ast$ is the critical value from the normal distribution that matches your desired confidence level. For 95% confidence, $z^\ast = 1.96$. For 99% confidence, $z^\ast = 2.576$.
-
A 95% confidence interval means: if you repeated the experiment many times and built an interval each time, about 95% of those intervals would contain the true population parameter. It does not mean there is a 95% probability the parameter is in this specific interval. The parameter is fixed; the intervals are what vary.
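The "repeat the experiment many times" interpretation can be checked by simulation. This is a minimal sketch under assumed values (true mean 170, known $\sigma = 8$, $n = 50$): it builds 10,000 intervals and counts how many contain the fixed true mean.

```python
import jax
import jax.numpy as jnp

# Assumed setup: true mean 170, sigma 8 (known), n = 50, 10,000 repeated experiments
mu_true, sigma, n, n_experiments = 170.0, 8.0, 50, 10_000
z_star = 1.96

key = jax.random.PRNGKey(0)
samples = mu_true + sigma * jax.random.normal(key, shape=(n_experiments, n))
x_bars = samples.mean(axis=1)  # one sample mean (and one interval) per experiment
me = z_star * sigma / jnp.sqrt(n)
covered = (x_bars - me <= mu_true) & (mu_true <= x_bars + me)
coverage = covered.mean()
print(f"Fraction of intervals containing the true mean: {coverage:.3f}")
```

The true mean never moves; only the intervals do, and roughly 95% of them catch it.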
-
Worked example: You measure the heights of 50 people and find $\bar{x} = 170$ cm with $\sigma = 8$ cm. Construct a 95% confidence interval.
$$\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}$$
$$\text{CI} = [170 - 2.22,\; 170 + 2.22] = [167.78,\; 172.22]$$
-
You can say with 95% confidence that the true mean height lies between 167.78 and 172.22 cm.
-
When $\sigma$ is unknown (the usual case), use the sample standard deviation $s$ and the t-distribution instead:
$$\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}$$
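JAX does not obviously expose a t-distribution quantile function, so this minimal sketch finds $t^\ast_{n-1}$ by bisection on the two-sided tail probability $I_{df/(df+t^2)}(df/2, 1/2)$ (via `jax.scipy.special.betainc`). It then builds a t-based interval for a hypothetical sample of 50 heights. The helper `t_critical` and its bracket `[0, 50]` are assumptions made for this illustration; `scipy.stats.t.ppf` is the usual shortcut if SciPy is available.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc

def t_critical(conf, df, lo=0.0, hi=50.0, iters=60):
    """Two-sided critical value t* with P(|T| > t*) = 1 - conf, found by bisection."""
    target = 1.0 - conf
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        p_two_sided = betainc(df / 2.0, 0.5, df / (df + mid**2))
        # The tail probability falls as t grows, so shrink the bracket accordingly
        if p_two_sided > target:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2.0

key = jax.random.PRNGKey(11)
heights = 170 + 8 * jax.random.normal(key, shape=(50,))  # hypothetical sample
x_bar = heights.mean()
s = heights.std(ddof=1)  # sample standard deviation replaces the unknown sigma
n = heights.shape[0]
t_star = t_critical(0.95, n - 1)
me = t_star * s / jnp.sqrt(n)
print(f"t* = {t_star:.4f}")
print(f"95% CI: [{x_bar - me:.2f}, {x_bar + me:.2f}]")
```

With 49 degrees of freedom, $t^\ast \approx 2.01$, slightly wider than the normal's 1.96, reflecting the extra uncertainty from estimating $\sigma$.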
-
Wider intervals are more confident but less precise. Narrower intervals are more precise but less confident. You can narrow an interval without losing confidence by increasing the sample size.
-
Power analysis helps you plan an experiment before you run it. The question is: how large a sample do I need to detect an effect of a given size with a specified power?
-
Recall from the previous section that power = $1 - \beta$, the probability of correctly rejecting a false $H_0$. A common target is 80% power.
-
The required sample size for a one-sample z-test to detect a shift of $\delta$ from the null mean, with significance $\alpha$ and power $1-\beta$, is:
$$n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2$$
- For example, to detect a 2 cm difference in mean height ($\sigma = 8$) with $\alpha = 0.05$ and 80% power ($z_{0.025} = 1.96$, $z_{0.20} = 0.84$):
$$n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 \approx 126$$
-
You would need about 126 people. If you instead compare two independent groups (e.g. treatment vs control), the difference of two sample means has variance $2\sigma^2/n$, so each group needs twice as many: $n = 2 \times 11.2^2 \approx 251$ per group.
-
Power analysis prevents two common mistakes: running an experiment too small to detect a real effect (underpowered), or wasting resources on an experiment far larger than necessary (overpowered).
-
Monte Carlo methods use random sampling to solve problems that are difficult or impossible to solve analytically. The core idea: if you cannot compute something exactly, simulate it many times and use the results as an approximation.
-
The name comes from the Monte Carlo casino, a nod to the role of randomness. These methods are workhorses in ML for tasks like estimating integrals, evaluating model uncertainty, and approximating complex distributions.
-
The general Monte Carlo recipe:
- Define a domain of possible inputs
- Generate random inputs from that domain
- Evaluate a function on each input
- Aggregate the results (average, count, etc.)
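The four-step recipe maps directly onto code. As a minimal sketch, here it estimates $\int_0^\pi \sin x \, dx$, whose exact value is 2: sample $x$ uniformly on $[0, \pi]$, evaluate $\sin x$, and multiply the average height by the width of the interval. The standard error shrinks like $1/\sqrt{n}$, just as in the sampling section.

```python
import jax
import jax.numpy as jnp

# Monte Carlo estimate of the integral of sin(x) over [0, pi]; the exact answer is 2
key = jax.random.PRNGKey(0)
a, b = 0.0, jnp.pi  # 1. domain of possible inputs
for n in [100, 10_000, 1_000_000]:
    subkey = jax.random.fold_in(key, n)  # a fresh key for each run
    x = jax.random.uniform(subkey, shape=(n,), minval=a, maxval=b)  # 2. random inputs
    fx = jnp.sin(x)                      # 3. evaluate the function
    estimate = (b - a) * fx.mean()       # 4. aggregate: width times average height
    std_err = (b - a) * fx.std() / jnp.sqrt(n)
    print(f"n={n:>9,}  estimate={estimate:.4f}  ± {std_err:.4f}")
```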
-
A classic example is estimating $\pi$. Imagine a square with side length 2, centred at the origin, with a circle of radius 1 inscribed inside it. The area of the square is 4, and the area of the circle is $\pi$.
- Drop random points uniformly in the square. The fraction that land inside the circle approximates $\pi/4$:
$$\pi \approx 4 \times \frac{\text{points inside circle}}{\text{total points}}$$
-
A point $(x, y)$ is inside the circle if $x^2 + y^2 \le 1$. The more points you throw, the closer your estimate gets to the true value of $\pi$.
-
In ML, Monte Carlo methods appear in:
- Monte Carlo dropout: run inference multiple times with dropout enabled to estimate prediction uncertainty
- MCMC (Markov Chain Monte Carlo): sample from complex posterior distributions in Bayesian models
- Policy gradient methods: estimate gradients in reinforcement learning by sampling trajectories
-
Factor analysis is a technique for discovering hidden (latent) variables that explain the correlations among observed variables. If 10 personality survey questions can be explained by 3 underlying traits (extraversion, agreeableness, conscientiousness), factor analysis finds those traits.
-
The model assumes each observed variable $x_i$ is a linear combination of a few latent factors $f_j$ plus noise:
$$x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i$$
-
The $\lambda$ values are called factor loadings and tell you how strongly each observed variable relates to each factor. This connects directly to the matrix decompositions from Chapter 2; factor analysis is closely related to eigenvalue decomposition and SVD.
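This sketch does not fit a factor analysis. It only simulates the model above to show what it implies. Data are generated from a hypothetical 2-factor model with 6 observed variables. The sample covariance then matches $\Lambda\Lambda^\top + \Psi$ (here $\Psi = 0.4^2 I$), and its eigenvalues show two large values, one per latent factor, followed by small ones. Estimating $\Lambda$ from real data would need a dedicated routine, such as scikit-learn's `FactorAnalysis`.

```python
import jax
import jax.numpy as jnp

# Hypothetical loadings (Lambda): 6 observed variables x 2 latent factors
loadings = jnp.array([[0.9, 0.0],
                      [0.8, 0.1],
                      [0.7, 0.0],
                      [0.0, 0.9],
                      [0.1, 0.8],
                      [0.0, 0.7]])
noise_sd = 0.4
n = 5_000

k_f, k_e = jax.random.split(jax.random.PRNGKey(0))
f = jax.random.normal(k_f, shape=(n, 2))               # latent factors
eps = noise_sd * jax.random.normal(k_e, shape=(n, 6))  # unique noise per variable
x = f @ loadings.T + eps                               # x_i = sum_j lambda_ij f_j + eps_i

sample_cov = jnp.cov(x, rowvar=False)
model_cov = loadings @ loadings.T + noise_sd**2 * jnp.eye(6)  # Lambda Lambda^T + Psi
print("max |sample - model| covariance entry:", jnp.abs(sample_cov - model_cov).max())

# Eigenvalues in descending order: two large ones, four small ones
ev = jnp.linalg.eigvalsh(sample_cov)[::-1]
print("eigenvalues:", ev)
```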
-
Experimental design is the art of structuring an experiment so that you can draw valid conclusions. Poor design can make even a large dataset useless.
-
Key components of a well-designed experiment:
- Independent variable (IV): what you manipulate (e.g. drug dose, model architecture)
- Dependent variable (DV): what you measure (e.g. recovery time, accuracy)
- Control group: receives no treatment (or a placebo), providing a baseline for comparison
- Random assignment: participants are assigned to groups randomly, which balances out confounding variables you did not measure
-
Common experimental designs:
- Completely randomised design: subjects are randomly assigned to treatment groups. Simple and effective when groups are comparable.
- Randomised block design: subjects are first grouped into blocks (e.g. by age), then randomly assigned to treatments within each block. This reduces variability from the blocking factor, similar in spirit to stratified sampling.
- Factorial design: tests multiple IVs simultaneously. A $2 \times 3$ factorial design has 2 levels of one variable and 3 of another, giving 6 treatment combinations. This lets you detect interactions, where the effect of one variable depends on the level of another.
- Crossover design: each subject receives all treatments in sequence (with washout periods in between). Every subject serves as their own control, reducing the effect of individual differences.
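The difference between a completely randomised design and a randomised block design comes down to where you shuffle. This minimal sketch uses a hypothetical study of 12 subjects in 3 age blocks with 2 treatments. The completely randomised design shuffles one balanced list of labels over everyone, so a block can end up lopsided. The block design shuffles separately inside each block, which guarantees 2 treated subjects per block.

```python
import jax
import jax.numpy as jnp

# Hypothetical study: 12 subjects, treatment labels 0 = control, 1 = treatment,
# and an age block per subject (0 = young, 1 = middle, 2 = old)
age_block = jnp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
k_crd, k_block = jax.random.split(jax.random.PRNGKey(2024))

# Completely randomised design: shuffle a balanced list of labels over all subjects
crd = jax.random.permutation(k_crd, jnp.array([0, 1] * 6))

# Randomised block design: shuffle labels separately inside each block
rbd = jnp.zeros_like(age_block)
for b, kb in enumerate(jax.random.split(k_block, 3)):
    idx = jnp.where(age_block == b)[0]
    rbd = rbd.at[idx].set(jax.random.permutation(kb, jnp.array([0, 0, 1, 1])))

for name, assign in [("completely randomised", crd), ("randomised block", rbd)]:
    per_block = [int(assign[age_block == b].sum()) for b in range(3)]
    print(f"{name:>22}: {assign}  treated per block: {per_block}")
```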
-
In ML experiments, these principles are critical. When comparing models, you should control for random seed, dataset split, and hardware. Cross-validation borrows the logic of a crossover design: every data point takes a turn in both the training and the evaluation role. Ablation studies, where you remove one component at a time, follow the logic of factorial designs.
Coding Tasks (use Colab or a notebook)
- Construct a 95% confidence interval for the height example, then experiment with different confidence levels and sample sizes.
```python
import jax.numpy as jnp

x_bar = 170.0  # sample mean
sigma = 8.0    # population std (known)
n = 50         # sample size

# Critical values for common confidence levels
z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
for conf, z_star in z_stars.items():
    me = z_star * (sigma / jnp.sqrt(n))
    lower, upper = x_bar - me, x_bar + me
    print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}] (ME = {me:.2f})")
```
- Estimate $\pi$ using Monte Carlo simulation. Plot how the estimate converges as you increase the number of points.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(42)

# Generate random points in [-1, 1] x [-1, 1]
n_points = 100_000
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)

# Check which points are inside the unit circle
inside = (x**2 + y**2) <= 1.0
cumulative_inside = jnp.cumsum(inside)
counts = jnp.arange(1, n_points + 1)
pi_estimates = 4.0 * cumulative_inside / counts

plt.figure(figsize=(10, 4))
plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)
plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
plt.xlabel("Number of points")
plt.ylabel("Estimate of π")
plt.title("Monte Carlo estimation of π")
plt.legend()
plt.ylim(2.8, 3.5)
plt.show()

print(f"Final estimate: {pi_estimates[-1]:.6f}")
print(f"True value: {jnp.pi:.6f}")
print(f"Error: {abs(pi_estimates[-1] - jnp.pi):.6f}")
```
- Perform a simple power analysis: for a given effect size and standard deviation, compute the required sample size and verify it by simulation.
```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Parameters
delta = 2.0   # effect size (difference in means)
sigma = 8.0   # population std
alpha = 0.05
power_target = 0.80

# Analytical sample size for comparing two independent groups:
# the difference of two means has variance 2*sigma^2/n, hence the factor 2
z_alpha = 1.96  # two-tailed, alpha=0.05
z_beta = 0.84   # power=0.80
n_required = 2 * ((z_alpha + z_beta) * sigma / delta) ** 2
n = int(jnp.ceil(n_required))
print(f"Required n per group: {n}")

# Verify by simulation: 5000 two-group experiments where the true difference is delta
key = jax.random.PRNGKey(7)
n_sims = 5000
k1, k2 = jax.random.split(key)
group_a = jax.random.normal(k1, shape=(n_sims, n)) * sigma + 50
group_b = jax.random.normal(k2, shape=(n_sims, n)) * sigma + 50 + delta
pooled_se = jnp.sqrt(2 * sigma**2 / n)
z = (group_b.mean(axis=1) - group_a.mean(axis=1)) / pooled_se
p = 2 * norm.cdf(-jnp.abs(z))
rejections = jnp.sum(p <= alpha)

print(f"Simulated power: {rejections/n_sims:.3f}")
print(f"Target power: {power_target:.3f}")
```
- Visualise how confidence interval width changes with sample size. This shows why collecting more data gives more precise estimates.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

sigma = 8.0
z_star = 1.96  # 95% confidence
sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
margins = z_star * sigma / jnp.sqrt(sample_sizes)

plt.figure(figsize=(8, 4))
plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
plt.xlabel("Sample size")
plt.ylabel("Margin of error (cm)")
plt.title("95% CI margin of error shrinks with larger samples")
plt.show()
```
Counting
-
Before we can compute probabilities, we need to count outcomes. If you want to know the chance of drawing a winning hand in poker, you first need to know how many possible hands exist and how many of those are winners. Counting is the machinery that makes probability precise.
-
The simplest counting principle is the multiplication rule. If one decision has $m$ options and a second independent decision has $n$ options, the total number of combined outcomes is $m \times n$.
-
Picture getting dressed in the morning. You have 3 shirts and 4 pants. Each shirt can pair with every pant, giving you $3 \times 4 = 12$ possible outfits.
-
The multiplication rule extends to any number of choices. If you also have 2 pairs of shoes, the total outfits become $3 \times 4 \times 2 = 24$. Each new independent choice multiplies the count.
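Enumerating every combination confirms the multiplication rule. The item names below are made up; only the counts (3 shirts, 4 pants, 2 pairs of shoes) come from the text. `itertools.product` yields one tuple per combined choice.

```python
from itertools import product

shirts = ["white", "blue", "striped"]
pants = ["jeans", "chinos", "shorts", "suit"]
shoes = ["sneakers", "boots"]

outfits = list(product(shirts, pants, shoes))  # every (shirt, pants, shoes) triple
print(len(outfits))  # 3 * 4 * 2
print(outfits[:3])
```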
-
The addition rule handles "or" scenarios. If event A can happen in $m$ ways and event B can happen in $n$ ways, and they cannot both happen at the same time (mutually exclusive), the total number of ways is $m + n$.
-
Suppose you can travel from city X to city Y by car (3 routes) or by train (2 routes). You cannot take both simultaneously, so the total options are $3 + 2 = 5$.
-
When events overlap, you need to subtract the double-counted outcomes. If $A$ and $B$ can co-occur, the count is $|A \cup B| = |A| + |B| - |A \cap B|$. This is the inclusion-exclusion principle, and it will reappear when we discuss probability addition rules.
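A quick check of inclusion-exclusion with Python sets, using a small example not taken from the text: how many numbers from 1 to 100 are divisible by 2 or by 3? Adding the two counts double-counts the multiples of 6, so they are subtracted once.

```python
# Numbers from 1 to 100 divisible by 2 (set A) or by 3 (set B)
A = {k for k in range(1, 101) if k % 2 == 0}  # 50 numbers
B = {k for k in range(1, 101) if k % 3 == 0}  # 33 numbers
print(len(A), len(B), len(A & B))             # the overlap is the multiples of 6
print(len(A | B), len(A) + len(B) - len(A & B))  # both ways of counting agree
```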
-
The factorial of a non-negative integer $n$ is the product of all positive integers up to $n$:
$$n! = n \times (n-1) \times (n-2) \times \cdots \times 2 \times 1$$
-
Think of the factorial as answering: in how many ways can you arrange $n$ distinct objects in a line? Three books on a shelf can be arranged in $3! = 3 \times 2 \times 1 = 6$ ways. By convention, $0! = 1$.
-
Factorials grow extremely fast. $10! = 3{,}628{,}800$ and $20!$ is already over $2.4 \times 10^{18}$. This explosive growth is why brute-force search becomes impractical in combinatorial problems.
-
A permutation is an ordered arrangement of objects. When you pick $r$ items from $n$ distinct objects and the order matters, the number of permutations is:
$$P(n, r) = \frac{n!}{(n - r)!}$$
-
Imagine picking a president, vice president, and treasurer from a club of 10 people. The first role has 10 candidates, the second has 9 remaining, the third has 8. That gives $P(10, 3) = 10 \times 9 \times 8 = 720$. The formula confirms this: $\frac{10!}{7!} = 720$.
-
A combination is an unordered selection. When you pick $r$ items from $n$ and the order does not matter, we divide out the redundant orderings:
$$C(n, r) = \binom{n}{r} = \frac{n!}{r!(n - r)!}$$
- The notation $\binom{n}{r}$ is read "n choose r." The key insight is that every combination corresponds to $r!$ permutations (the $r$ chosen items can be rearranged in $r!$ ways), so we divide the permutation count by $r!$.
- Example: from a group of 10 people, how many ways can you form a committee of 3? Order does not matter (there is no president or vice president, just members), so we use combinations:
$$\binom{10}{3} = \frac{10!}{3! \cdot 7!} = \frac{10 \times 9 \times 8}{3 \times 2 \times 1} = 120$$
-
The same 10 people produce 720 permutations but only 120 combinations, because each group of 3 has $3! = 6$ internal orderings.
-
Combinations are central to probability. The binomial coefficient $\binom{n}{r}$ counts the number of ways to get exactly $r$ successes in $n$ trials, which is the heart of the binomial distribution (covered in file 03).
-
Let us work through a classic committee problem that combines multiple counting ideas.
-
Problem: A club has 8 men and 6 women. How many ways can you form a committee of 5 that includes exactly 3 men and 2 women?
-
Step 1: Choose 3 men from 8.
$$\binom{8}{3} = \frac{8!}{3! \cdot 5!} = \frac{8 \times 7 \times 6}{3 \times 2 \times 1} = 56$$
-
Step 2: Choose 2 women from 6.
$$\binom{6}{2} = \frac{6!}{2! \cdot 4!} = \frac{6 \times 5}{2 \times 1} = 15$$
-
Step 3: Apply the multiplication rule. Each selection of men can pair with each selection of women:
$$56 \times 15 = 840 \text{ committees}$$
-
This pattern, breaking a complex counting problem into independent sub-choices and multiplying, is the standard approach in combinatorics.
-
There are also permutations with repetition. When items can repeat, choosing $r$ items from $n$ types gives $n^r$ outcomes. A 4-digit PIN using digits 0-9 has $10^4 = 10{,}000$ possibilities. Each position has 10 options, and the multiplication rule handles the rest.
-
Combinations with repetition (also called "stars and bars") count how many ways to choose $r$ items from $n$ types when repeats are allowed and order does not matter:
$$\binom{n + r - 1}{r} = \frac{(n + r - 1)!}{r!(n - 1)!}$$
-
Example: choosing 3 scoops from 4 ice cream flavours (repeats allowed) gives $\binom{4 + 3 - 1}{3} = \binom{6}{3} = 20$ options.
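Here is a brute-force check of the stars-and-bars count for the ice cream example. `itertools.combinations_with_replacement` lists each unordered selection with repeats exactly once. The flavour names are placeholders.

```python
from itertools import combinations_with_replacement
from math import comb

flavours = ["vanilla", "chocolate", "strawberry", "mint"]   # n = 4 types (placeholder names)
r = 3                                                       # scoops per bowl
# Each multiset of 3 scoops appears once, because order is ignored
bowls = list(combinations_with_replacement(flavours, r))
n = len(flavours)
print(len(bowls), comb(n + r - 1, r))   # 20 20
```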
-
To summarise the counting toolkit:
| Scenario | Formula |
|---|---|
| Ordered, no repetition (permutation) | $P(n,r) = \frac{n!}{(n-r)!}$ |
| Unordered, no repetition (combination) | $\binom{n}{r} = \frac{n!}{r!(n-r)!}$ |
| Ordered, with repetition | $n^r$ |
| Unordered, with repetition | $\binom{n+r-1}{r}$ |
-
Every probability calculation involving equally likely outcomes uses the formula $P(\text{event}) = \frac{\text{favourable outcomes}}{\text{total outcomes}}$. Counting gives us both numbers. With this foundation, we are ready to formalise probability itself in the next file.
Coding Tasks (use Colab or a notebook)
- Compute $P(10, 3)$ and $\binom{10}{3}$ using both the factorial formula and direct computation. Verify that the permutation count is always $r!$ times the combination count.
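from math import comb, factorial, perm
n, r = 10, 3
# Factorial formulas
perm_f = factorial(n) // factorial(n - r)
comb_f = factorial(n) // (factorial(r) * factorial(n - r))
# Direct computation with the standard library (math.perm and math.comb, Python 3.8+)
print(f"P({n},{r}) = {perm_f} (math.perm: {perm(n, r)})")
print(f"C({n},{r}) = {comb_f} (math.comb: {comb(n, r)})")
print(f"P / C = {perm_f // comb_f} (should equal {r}! = {factorial(r)})")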
- Solve the committee problem (3 men from 8, 2 women from 6) programmatically and verify by enumerating all valid committees.
from itertools import combinations
from math import factorial
def comb_count(n, r):
    return factorial(n) // (factorial(r) * factorial(n - r))
# Formula approach
men_ways = comb_count(8, 3)
women_ways = comb_count(6, 2)
print(f"Formula: {men_ways} × {women_ways} = {men_ways * women_ways}")
# Enumeration approach
men = [f"M{i}" for i in range(1, 9)]
women = [f"W{i}" for i in range(1, 7)]
count = sum(1 for _ in combinations(men, 3) for _ in combinations(women, 2))
print(f"Enumeration: {count}")
- Count how many 4-character passwords can be made from 26 lowercase letters (with repetition allowed). Then count how many contain no repeated letters.
from math import factorial
n = 26
r = 4
with_rep = n ** r
without_rep = factorial(n) // factorial(n - r)
print(f"With repetition: {with_rep:>10,}")
print(f"Without repetition: {without_rep:>10,}")
print(f"Fraction with repeats: {1 - without_rep/with_rep:.2%}")
- Simulate the birthday problem: in a group of $k$ people, what is the probability that at least two share a birthday? Plot the probability for $k = 1$ to $60$ and find where it crosses 50%.
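import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def birthday_prob_exact(k):
    """Probability of at least one shared birthday in a group of k (365 equally likely days)."""
    p_no_match = 1.0
    for i in range(k):
        p_no_match *= (365 - i) / 365   # person i+1 avoids the i birthdays already taken
    return 1 - p_no_match
def birthday_prob_sim(key, k, trials=20_000):
    """Monte Carlo estimate: draw k random birthdays per trial and check for a repeat."""
    days = jnp.sort(jax.random.randint(key, shape=(trials, k), minval=0, maxval=365), axis=1)
    has_repeat = (days[:, 1:] == days[:, :-1]).any(axis=1)   # after sorting, repeats are adjacent
    return float(has_repeat.mean())
ks = list(range(1, 61))
probs = [birthday_prob_exact(k) for k in ks]
keys = jax.random.split(jax.random.PRNGKey(0), len(ks))
sims = [birthday_prob_sim(key, k) for key, k in zip(keys, ks)]
plt.figure(figsize=(8, 4))
plt.plot(ks, probs, color="#3498db", linewidth=2, label="exact")
plt.plot(ks, sims, "o", color="#2c3e50", markersize=3, label="simulated")
plt.axhline(y=0.5, color="#e74c3c", linestyle="--", alpha=0.7, label="50%")
cross = next(k for k, p in zip(ks, probs) if p >= 0.5)
plt.axvline(x=cross, color="#e74c3c", linestyle="--", alpha=0.7)
plt.xlabel("Group size (k)")
plt.ylabel("P(at least one shared birthday)")
plt.title(f"Birthday Problem (crosses 50% at k={cross})")
plt.legend()
plt.grid(alpha=0.3)
plt.show()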
Probability Concepts
-
Probability assigns a number between 0 and 1 to an event, measuring how likely it is to happen.
-
A probability of 0 means impossible, 1 means certain, and 0.5 means as likely as not, like a fair coin toss.
-
There are two main interpretations. The frequentist view says probability is the long-run relative frequency: flip a fair coin 10,000 times and heads will appear roughly 50% of the time.
-
The Bayesian view says probability is a degree of belief: you might say there is a 70% chance it rains tomorrow, even though tomorrow only happens once.
-
Both interpretations use the same mathematical rules. The difference is philosophical, but it matters in ML. Frequentist methods typically give you point estimates of parameters (with confidence intervals to express uncertainty). Bayesian methods give you full probability distributions over parameters.
-
The sample space $S$ is the set of all possible outcomes of an experiment. Flip a coin: $S = \{H, T\}$. Roll a die: $S = \{1, 2, 3, 4, 5, 6\}$.
-
An event is any subset of the sample space. "Rolling an even number" is the event $A = \{2, 4, 6\}$, which is a subset of $S$.
-
The probability of an event when all outcomes are equally likely is simply counting (from file 01):
$$P(A) = \frac{|A|}{|S|} = \frac{\text{favourable outcomes}}{\text{total outcomes}}$$
-
For the even-number example: $P(\text{even}) = \frac{3}{6} = 0.5$.
-
The complement of event $A$, written $A'$ or $A^c$, is everything in $S$ that is not in $A$. Since every outcome is either in $A$ or not:
$$P(A') = 1 - P(A)$$
-
Complements are often the easier route. Instead of counting all the ways to get at least one head in 5 coin flips, count the one way to get no heads and subtract: $P(\text{at least one head}) = 1 - P(\text{all tails}) = 1 - (0.5)^5 = 0.969$.
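The complement shortcut can be checked by brute force: list all $2^5 = 32$ equally likely flip sequences and count those with at least one head. Here is a minimal sketch using exact fractions.

```python
from fractions import Fraction
from itertools import product

# All 2^5 equally likely sequences of 5 fair coin flips
outcomes = list(product("HT", repeat=5))
at_least_one_head = sum(1 for seq in outcomes if "H" in seq)

direct = Fraction(at_least_one_head, len(outcomes))
via_complement = 1 - Fraction(1, 2) ** 5       # 1 - P(all tails)
print(direct, via_complement, float(direct))   # 31/32 31/32 0.96875
```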
-
Two events are mutually exclusive (disjoint) if they cannot both happen: $A \cap B = \emptyset$. Rolling a 2 and rolling a 5 on a single die are mutually exclusive.
-
The addition rule for mutually exclusive events is straightforward:
$$P(A \cup B) = P(A) + P(B) \quad \text{(if } A \cap B = \emptyset\text{)}$$
-
When events can overlap, you need the general addition rule to avoid double-counting the intersection:
$$P(A \cup B) = P(A) + P(B) - P(A \cap B)$$
-
This mirrors the inclusion-exclusion principle from counting. The Venn diagram above shows why: the purple region (intersection) gets counted once in $P(A)$ and again in $P(B)$, so we subtract it once.
-
Joint probability $P(A \cap B)$ is the probability that both $A$ and $B$ occur. In a deck of cards, $P(\text{red} \cap \text{king}) = \frac{2}{52}$ because there are 2 red kings.
-
Marginal probability is the probability of a single event regardless of others. $P(\text{red}) = \frac{26}{52} = 0.5$ is a marginal probability. If you have a joint distribution over two variables, the marginal is obtained by summing (or integrating) over the other variable.
-
Conditional probability answers: given that $B$ has already happened, what is the probability of $A$? We shrink the sample space from $S$ down to $B$, and ask what fraction of $B$ also belongs to $A$:
$$P(A | B) = \frac{P(A \cap B)}{P(B)}, \quad P(B) > 0$$
-
Example: you draw a card and someone tells you it is red. What is the probability it is a king? There are 26 red cards and 2 of them are kings, so $P(\text{king} | \text{red}) = \frac{2}{26} = \frac{1}{13}$. Using the formula: $P(\text{king} \cap \text{red}) / P(\text{red}) = \frac{2/52}{26/52} = \frac{1}{13}$.
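The card examples for joint, marginal and conditional probability (and the general addition rule) can all be verified by enumerating a 52-card deck. This is a minimal sketch using exact fractions. The (rank, suit) representation is just a convenient choice.

```python
from fractions import Fraction
from itertools import product

# A standard 52-card deck as (rank, suit) pairs
ranks = ["A", "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K"]
suits = ["hearts", "diamonds", "clubs", "spades"]
deck = list(product(ranks, suits))

def prob(event):
    """P(event) for a single card drawn uniformly at random: favourable / total."""
    return Fraction(sum(1 for card in deck if event(card)), len(deck))

def is_red(card):
    return card[1] in ("hearts", "diamonds")

def is_king(card):
    return card[0] == "K"

p_red = prob(is_red)                                                 # marginal
p_red_and_king = prob(lambda card: is_red(card) and is_king(card))  # joint
p_king_given_red = p_red_and_king / p_red                           # conditional
p_red_or_king = prob(lambda card: is_red(card) or is_king(card))    # union

print(p_red, p_red_and_king, p_king_given_red)                # 1/2 1/26 1/13
print(p_red_or_king, p_red + prob(is_king) - p_red_and_king)  # 7/13 7/13
```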
-
Two events are independent if knowing one happened tells you nothing about the other. Formally:
$$P(A \cap B) = P(A) \cdot P(B)$$
-
Equivalently, $P(A | B) = P(A)$ (when $P(B) > 0$). The outcomes of two separate coin flips are independent. Two cards drawn without replacement are not independent: the first draw changes what remains in the deck.
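Both claims can be checked by enumeration. In the sketch below, the dice events are illustrative and not from the text. "First die is even" and "second die shows 5 or 6" satisfy the product rule. For two cards drawn without replacement, $P(\text{2nd king} | \text{1st king}) = 3/51$ differs from $P(\text{2nd king}) = 1/13$.

```python
from fractions import Fraction
from itertools import permutations, product

# Two fair dice: all 36 ordered outcomes are equally likely
rolls = list(product(range(1, 7), repeat=2))

def prob(event):
    """P(event) when every roll in `rolls` is equally likely."""
    return Fraction(sum(1 for roll in rolls if event(roll)), len(rolls))

p_a = prob(lambda roll: roll[0] % 2 == 0)                  # first die is even
p_b = prob(lambda roll: roll[1] > 4)                       # second die shows 5 or 6
p_ab = prob(lambda roll: roll[0] % 2 == 0 and roll[1] > 4)
print(p_ab, p_a * p_b)   # 1/6 1/6 -> independent

# Two cards without replacement: label cards 0-3 as the kings (an arbitrary choice)
kings = set(range(4))
draws = list(permutations(range(52), 2))                   # ordered pairs of distinct cards
p_second_king = Fraction(sum(1 for a, b in draws if b in kings), len(draws))
p_both_kings = Fraction(sum(1 for a, b in draws if a in kings and b in kings), len(draws))
p_second_given_first = p_both_kings / Fraction(4, 52)      # P(2nd king | 1st king)
print(p_second_king, p_second_given_first)   # 1/13 1/17 -> dependent
```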
-
Independence is a massive simplifier. For independent events, joint probabilities factor into products, which makes computation tractable. Many ML models make independence assumptions precisely because of this simplification. Naive Bayes, for example, assumes the features are conditionally independent given the class label (conditional independence is defined below).
-
The multiplication rule for any two events rearranges the conditional probability formula:
$$P(A \cap B) = P(A | B) \cdot P(B) = P(B | A) \cdot P(A)$$
-
For independent events, this simplifies to $P(A \cap B) = P(A) \cdot P(B)$ since the conditional equals the marginal.
-
Bayes' theorem is one of the most important results in probability and the foundation of Bayesian ML. It lets you reverse the direction of a conditional probability:
$$P(A | B) = \frac{P(B | A) \cdot P(A)}{P(B)}$$
-
The theorem follows directly from writing $P(A \cap B)$ two ways: $P(B|A) \cdot P(A) = P(A|B) \cdot P(B)$, then solving for $P(A|B)$.
-
Each component has a name:
- Prior $P(A)$: your initial belief before seeing evidence
- Likelihood $P(B|A)$: how probable the evidence is, assuming $A$ is true
- Evidence $P(B)$: the total probability of seeing the evidence, acts as a normaliser
- Posterior $P(A|B)$: your updated belief after seeing the evidence
-
Let us work through the classic medical diagnosis example. Suppose a disease affects 1% of the population. A test for the disease is described as "95% accurate", meaning it correctly identifies 95% of sick people (sensitivity). It also correctly identifies 90% of healthy people (specificity).
-
You test positive. What is the probability you actually have the disease?
-
Let $D$ = having the disease, $+$ = testing positive.
- Prior: $P(D) = 0.01$
- Likelihood: $P(+ | D) = 0.95$
- False positive rate: $P(+ | D') = 0.10$
-
We need $P(+)$. By the law of total probability:
$$\begin{aligned} P(+) &= P(+ | D) \cdot P(D) + P(+ | D') \cdot P(D') \\ &= 0.95 \times 0.01 + 0.10 \times 0.99 = 0.0095 + 0.099 = 0.1085 \end{aligned}$$
-
Now apply Bayes' theorem:
$$P(D | +) = \frac{P(+ | D) \cdot P(D)}{P(+)} = \frac{0.95 \times 0.01}{0.1085} \approx 0.088$$
-
Despite the test being "95% accurate," a positive result only gives you about an 8.8% chance of having the disease. The prior matters enormously. Because the disease is rare, most positive results are false positives. This is a crucial insight for any classification problem in ML: when classes are imbalanced, accuracy alone is misleading.
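To see how strongly the prior drives the answer, the sketch below wraps the calculation in a function and varies the prevalence. It keeps the example's sensitivity (0.95) and specificity (0.90). The prevalence values other than 1% are illustrative.

```python
def posterior_positive(prevalence, sensitivity=0.95, specificity=0.90):
    """P(disease | positive test) via Bayes' theorem and the law of total probability."""
    p_pos_given_d = sensitivity
    p_pos_given_not_d = 1 - specificity          # false positive rate
    p_pos = p_pos_given_d * prevalence + p_pos_given_not_d * (1 - prevalence)
    return p_pos_given_d * prevalence / p_pos

for prevalence in [0.001, 0.01, 0.10, 0.50]:
    print(f"prevalence {prevalence:>5.1%}: P(D | +) = {posterior_positive(prevalence):.3f}")
```

With a 1% prevalence this reproduces the 0.088 above. The same test becomes far more informative once the condition is common.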
-
The law of total probability partitions the sample space into mutually exclusive, exhaustive events $B_1, B_2, \ldots, B_n$ and expresses any event $A$ as:
$$P(A) = \sum_{i=1}^{n} P(A | B_i) \cdot P(B_i)$$
-
This is exactly what we used to compute $P(+)$ in the medical example: we split the population into "has disease" and "does not have disease."
-
The chain rule of probability generalises the multiplication rule to any number of events:
$$P(A_1 \cap A_2 \cap \cdots \cap A_n) = P(A_1) \cdot P(A_2 | A_1) \cdot P(A_3 | A_1 \cap A_2) \cdots P(A_n | A_1 \cap \cdots \cap A_{n-1})$$
-
Each factor conditions on everything that came before. This is the backbone of autoregressive language models: the probability of a sentence is the product of each word's probability given all previous words.
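As a small worked example of the chain rule (not from the text), consider the probability of drawing three aces in a row without replacement. The sketch cross-checks it against the counting approach from file 01.

```python
from fractions import Fraction
from math import comb

# Chain rule: P(A1 ∩ A2 ∩ A3) = P(A1) · P(A2 | A1) · P(A3 | A1 ∩ A2)
# A_i = "the i-th card drawn (without replacement) is an ace"
p_chain = Fraction(4, 52) * Fraction(3, 51) * Fraction(2, 50)

# Counting cross-check: 3-card hands made of aces / all 3-card hands
p_count = Fraction(comb(4, 3), comb(52, 3))
print(p_chain, p_count)   # 1/5525 1/5525
```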
-
Conditional independence means two events are independent given a third. $A$ and $B$ are conditionally independent given $C$ if:
$$P(A \cap B | C) = P(A | C) \cdot P(B | C)$$
-
Events can be marginally dependent but conditionally independent, or vice versa. For example, two students' exam scores may be correlated (both depend on the difficulty of the exam), but given the exam difficulty, their scores are independent.
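The students example can be made concrete with made-up numbers. In the sketch below, exam difficulty $C$ is easy or hard with equal probability. Each student passes with probability 0.9 on an easy exam and 0.3 on a hard one, independently given $C$. All of these numbers are hypothetical. The two passes are conditionally independent by construction, yet marginally dependent.

```python
from fractions import Fraction

# Hypothetical model, for illustration only
p_c = {"easy": Fraction(1, 2), "hard": Fraction(1, 2)}        # P(C): exam difficulty
p_pass = {"easy": Fraction(9, 10), "hard": Fraction(3, 10)}   # P(A | C) = P(B | C)

# Given C, the students pass independently: P(A ∩ B | C) = P(A | C) P(B | C)
p_a = sum(p_c[c] * p_pass[c] for c in p_c)                # law of total probability
p_b = p_a                                                 # same model for student B
p_ab = sum(p_c[c] * p_pass[c] * p_pass[c] for c in p_c)

print(p_a, p_ab, p_a * p_b)   # 3/5 9/20 9/25 -> P(A ∩ B) != P(A) P(B)
```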
-
Conditional independence is the key assumption behind graphical models like Bayesian networks. It lets you factorise complex joint distributions into manageable pieces, making inference computationally feasible.
Coding Tasks (use Colab or a notebook)
- Simulate the medical diagnosis problem. Generate a population of 100,000 people, apply the disease prevalence and test accuracy, and verify that Bayes' theorem gives the correct posterior.
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(42)
n = 100_000
# Generate population
k1, k2 = jax.random.split(key)
has_disease = jax.random.bernoulli(k1, p=0.01, shape=(n,))
# Generate test results
k3, k4 = jax.random.split(k2)
# Sensitivity: P(+|D) = 0.95, Specificity: P(-|D') = 0.90
test_positive = jnp.where(
    has_disease,
    jax.random.bernoulli(k3, p=0.95, shape=(n,)),
    jax.random.bernoulli(k4, p=0.10, shape=(n,))
)
# Among those who tested positive, what fraction actually has the disease?
positives = test_positive.astype(bool)
true_positives = (has_disease & positives).sum()
total_positives = positives.sum()
print(f"Total positive tests: {total_positives}")
print(f"True positives: {true_positives}")
print(f"P(Disease | Positive) = {true_positives / total_positives:.4f}")
print(f"Bayes' formula: {0.95 * 0.01 / 0.1085:.4f}")
- Verify the addition rule by simulation. Generate random events A and B with known probabilities and overlap, then check that $P(A \cup B) = P(A) + P(B) - P(A \cap B)$.
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
n = 200_000
k1, k2 = jax.random.split(key)
# Events: A = {u1 < 0.4}, B = {u2 < 0.6} built from two independent uniforms,
# so the overlap is P(A ∩ B) = 0.4 × 0.6 = 0.24 and P(A ∪ B) = 0.76
vals_a = jax.random.uniform(k1, shape=(n,))
vals_b = jax.random.uniform(k2, shape=(n,))
A = vals_a < 0.4
B = vals_b < 0.6
p_a = A.mean()
p_b = B.mean()
p_a_and_b = (A & B).mean()
p_a_or_b = (A | B).mean()
print(f"P(A) = {p_a:.4f}")
print(f"P(B) = {p_b:.4f}")
print(f"P(A ∩ B) = {p_a_and_b:.4f}")
print(f"P(A ∪ B) simulated = {p_a_or_b:.4f}")
print(f"P(A) + P(B) - P(A∩B) = {p_a + p_b - p_a_and_b:.4f}")
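- Demonstrate that conditional probability can change with evidence. Simulate rolling two dice and compute $P(\text{sum} = 8)$, then $P(\text{sum} = 8 | \text{first die} = 3)$. As a contrast, check that $P(\text{sum} = 7 | \text{first die} = 3) = P(\text{sum} = 7)$: whatever the first die shows, exactly one value of the second die makes the sum 7, so these two events are independent.
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(1)
n = 500_000
k1, k2 = jax.random.split(key)
d1 = jax.random.randint(k1, shape=(n,), minval=1, maxval=7)
d2 = jax.random.randint(k2, shape=(n,), minval=1, maxval=7)
total = d1 + d2
# Unconditional: 5 of the 36 outcomes sum to 8
p_sum8 = (total == 8).mean()
print(f"P(sum=8) = {p_sum8:.4f} (exact: {5/36:.4f})")
# Conditional on first die = 3: only d2 = 5 works, so the evidence raises the probability
mask = d1 == 3
p_sum8_given_d1_3 = (total[mask] == 8).mean()
print(f"P(sum=8 | d1=3) = {p_sum8_given_d1_3:.4f} (exact: {1/6:.4f})")
# Contrast: for sum = 7 the evidence changes nothing
p_sum7 = (total == 7).mean()
p_sum7_given_d1_3 = (total[mask] == 7).mean()
print(f"P(sum=7) = {p_sum7:.4f}, P(sum=7 | d1=3) = {p_sum7_given_d1_3:.4f} (both exact: {1/6:.4f})")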
- Implement Bayes' theorem as a function and use it to update beliefs iteratively. Start with a uniform prior over a coin's bias and update after observing each flip.
import jax.numpy as jnp
import matplotlib.pyplot as plt
def bayes_update(prior, likelihood):
    """Multiply prior by likelihood and normalise."""
    posterior = prior * likelihood
    return posterior / posterior.sum()
# Discretise possible bias values
theta = jnp.linspace(0, 1, 200)
prior = jnp.ones_like(theta) # uniform prior
prior = prior / prior.sum()
# Observed flips: 1=heads, 0=tails
flips = [1, 1, 0, 1, 1, 1, 0, 1, 0, 1]
plt.figure(figsize=(10, 5))
plt.plot(theta, prior, "--", color="#999", label="prior")
for i, flip in enumerate(flips):
    likelihood = theta if flip == 1 else (1 - theta)
    prior = bayes_update(prior, likelihood)
    if i in [0, 2, 4, 9]:
        plt.plot(theta, prior, label=f"after {i+1} flips", linewidth=2)
plt.xlabel("Coin bias θ")
plt.ylabel("Belief (normalised)")
plt.title("Bayesian updating: belief about coin bias")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
Probability Distributions
-
In Chapter 4 we introduced random variables, PMFs, PDFs, and CDFs. Here we catalogue the most important probability distributions you will encounter in ML and statistics, giving the intuition, formula, mean, and variance for each.
-
Quick recap of the three core functions (see Chapter 4 for full definitions):
- PMF $P(X = x)$: gives the probability of each discrete outcome. The bars in a bar chart.
- PDF $f(x)$: gives the density at each point for continuous variables. The area under the curve between two points is the probability.
- CDF $F(x) = P(X \le x)$: the cumulative probability up to $x$. Always goes from 0 to 1 and never decreases.
-
The support of a distribution is the set of values where the PMF or PDF is positive. For a die roll, the support is $\{1,2,3,4,5,6\}$. For the normal distribution, the support is all real numbers $(-\infty, \infty)$.
-
Distributions divide cleanly into two families: discrete (countable outcomes, use PMFs) and continuous (uncountable outcomes, use PDFs).
-
Bernoulli distribution: the simplest distribution. A single trial with two outcomes: success (1) with probability $p$ and failure (0) with probability $1-p$.
$$P(X = x) = p^x (1 - p)^{1-x}, \quad x \in \{0, 1\}$$
-
Mean: $E[X] = p$. Variance: $\text{Var}(X) = p(1-p)$.
-
Every coin flip, every yes/no classification, every binary outcome is a Bernoulli trial. In ML, a binary classifier with a sigmoid output (e.g. logistic regression) treats that output as the $p$ parameter of a Bernoulli distribution.
-
Binomial distribution: count the number of successes in $n$ independent Bernoulli trials, each with the same probability $p$.
$$P(X = k) = \binom{n}{k} p^k (1-p)^{n-k}, \quad k = 0, 1, \ldots, n$$
-
The binomial coefficient $\binom{n}{k}$ from file 01 counts how many ways to arrange $k$ successes among $n$ trials.
-
Mean: $E[X] = np$. Variance: $\text{Var}(X) = np(1-p)$.
-
Example: flip a biased coin ($p = 0.7$) eight times. The probability of getting exactly 6 heads is $\binom{8}{6}(0.7)^6(0.3)^2 = 28 \times 0.1176 \times 0.09 \approx 0.296$.
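The worked example and the mean/variance formulas can be verified by summing the PMF over its support. This is a minimal sketch using `math.comb` (Python 3.8+), with $n = 8$ and $p = 0.7$ as in the example.

```python
from math import comb

def binom_pmf(k, n, p):
    """P(X = k) for X ~ Binomial(n, p)."""
    return comb(n, k) * p**k * (1 - p) ** (n - k)

n, p = 8, 0.7
print(f"P(X = 6) = {binom_pmf(6, n, p):.4f}")

# Check total mass, mean and variance by summing over the whole support
pmf = [binom_pmf(k, n, p) for k in range(n + 1)]
mean = sum(k * pk for k, pk in enumerate(pmf))
var = sum((k - mean) ** 2 * pk for k, pk in enumerate(pmf))
print(f"total = {sum(pmf):.4f}, mean = {mean:.4f} (np = {n * p:.4f}), "
      f"var = {var:.4f} (np(1-p) = {n * p * (1 - p):.4f})")
```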
-
Poisson distribution: counts the number of events in a fixed interval of time or space, given a known average rate $\lambda$. Useful when events are rare and independent.
$$P(X = k) = \frac{\lambda^k e^{-\lambda}}{k!}, \quad k = 0, 1, 2, \ldots$$
-
Mean: $E[X] = \lambda$. Variance: $\text{Var}(X) = \lambda$. The mean equals the variance, which is a signature property.
-
Examples: emails per hour ($\lambda = 5$), typos per page, server requests per second. In ML, Poisson regression models count data directly, whereas an ordinary linear model could predict negative counts.
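The signature property (mean equals variance) can be checked by summing over the PMF. This minimal sketch uses $\lambda = 5$ (the emails-per-hour example) and truncates the infinite support at $k = 59$. For this $\lambda$, the mass beyond that point is negligible.

```python
from math import exp, factorial

def poisson_pmf(k, lam):
    """P(X = k) for X ~ Poisson(lam)."""
    return lam**k * exp(-lam) / factorial(k)

lam = 5.0                 # e.g. an average of 5 emails per hour
ks = range(0, 60)         # truncated support
pmf = [poisson_pmf(k, lam) for k in ks]
mean = sum(k * pk for k, pk in zip(ks, pmf))
var = sum((k - mean) ** 2 * pk for k, pk in zip(ks, pmf))
print(f"total mass = {sum(pmf):.6f}, mean = {mean:.4f}, variance = {var:.4f}")
print(f"P(no emails in an hour) = {poisson_pmf(0, lam):.4f}")
```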
-
As $n \to \infty$ and $p \to 0$ with $np = \lambda$ held constant, the Binomial$(n,p)$ converges to Poisson$(\lambda)$. This is why the Poisson works well for rare events in large populations.
-
Geometric distribution: counts the number of trials up to and including the first success. "How many times do I flip a coin to get my first heads?"
$$P(X = k) = (1-p)^{k-1} p, \quad k = 1, 2, 3, \ldots$$
-
Mean: $E[X] = 1/p$. Variance: $\text{Var}(X) = (1-p)/p^2$.
-
The geometric distribution is memoryless: $P(X > m + k | X > m) = P(X > k)$. The probability of needing more than $k$ further trials does not depend on how many failures you have already seen. It is the only distribution on $\{1, 2, 3, \ldots\}$ with this property.
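Memorylessness follows from the survival function $P(X > k) = (1-p)^k$ (the first $k$ trials all fail). The sketch below checks the conditional identity for a few values of $m$ and approximates the mean $1/p$ by a truncated sum. The value $p = 0.3$ is an arbitrary choice.

```python
p = 0.3   # illustrative success probability

def survival(k, p):
    """P(X > k) for X ~ Geometric(p) on {1, 2, ...}: the first k trials all fail."""
    return (1 - p) ** k

# Memorylessness: P(X > m + k | X > m) = P(X > k), whatever m is
k = 4
for m in [0, 3, 10]:
    conditional = survival(m + k, p) / survival(m, p)
    print(f"m = {m:>2}: P(X > m+k | X > m) = {conditional:.4f}, P(X > k) = {survival(k, p):.4f}")

# Mean from the PMF, truncated far into the tail
mean = sum(j * (1 - p) ** (j - 1) * p for j in range(1, 2000))
print(f"mean = {mean:.4f}, 1/p = {1 / p:.4f}")
```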
-
Negative Binomial distribution: generalises the geometric by counting trials until the $r$-th success (geometric is the special case $r=1$).
$$P(X = k) = \binom{k-1}{r-1} p^r (1-p)^{k-r}, \quad k = r, r+1, r+2, \ldots$$
-
Mean: $E[X] = r/p$. Variance: $\text{Var}(X) = r(1-p)/p^2$.
-
The Negative Binomial is also used in practice to model overdispersed count data (where the variance exceeds the mean), which the Poisson cannot handle.
-
Now we move to continuous distributions.
-
Uniform distribution: all values in an interval $[a, b]$ are equally likely. The PDF is a flat rectangle.
$$f(x) = \frac{1}{b - a}, \quad a \le x \le b$$
-
Mean: $E[X] = \frac{a+b}{2}$. Variance: $\text{Var}(X) = \frac{(b-a)^2}{12}$.
-
Random number generators typically produce Uniform(0,1) samples as their starting point. Other distributions are generated by transforming these uniform samples, for example by passing them through an inverse CDF (inverse transform sampling).
-
Normal (Gaussian) distribution: the most important distribution in statistics. It arises naturally from the Central Limit Theorem (see Chapter 4): averages of many independent random variables with finite variance tend toward a normal distribution regardless of the original distribution.
$$f(x) = \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
-
Mean: $E[X] = \mu$. Variance: $\text{Var}(X) = \sigma^2$.
-
The standard normal has $\mu = 0$ and $\sigma = 1$. Any normal variable $X$ can be standardised to a standard normal $Z$ using $Z = (X - \mu)/\sigma$.
-
The empirical rule (68-95-99.7 rule) says:
- About 68% of data falls within $\pm 1\sigma$ of the mean
- About 95% falls within $\pm 2\sigma$
- About 99.7% falls within $\pm 3\sigma$
-
In ML, normal distributions appear everywhere: weight initialisation, noise in data augmentation, MSE loss (minimising it is equivalent to maximum likelihood under Gaussian errors), and the reparameterisation trick in variational autoencoders.
-
Exponential distribution: models the time between events in a Poisson process. If events arrive at rate $\lambda$, the waiting time between them follows Exponential$(\lambda)$.
$$f(x) = \lambda e^{-\lambda x}, \quad x \ge 0$$
-
Mean: $E[X] = 1/\lambda$. Variance: $\text{Var}(X) = 1/\lambda^2$.
-
Like the geometric distribution for discrete variables, the exponential is memoryless: $P(X > s + t | X > s) = P(X > t)$. The probability of waiting another $t$ units does not depend on how long you have already waited.
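The earlier point, that other distributions are built by transforming Uniform(0,1) samples, can be made concrete with the exponential. The sketch below uses inverse transform sampling. It is one standard technique and not necessarily how any particular library does it. The CDF is $F(x) = 1 - e^{-\lambda x}$, and solving $u = F(x)$ gives $x = -\ln(1-u)/\lambda$. The rate of 2.0 and the seed are arbitrary choices.

```python
import math
import random

def sample_exponential(rate, rng):
    """Inverse transform sampling: if U ~ Uniform(0, 1), then -ln(1 - U) / rate ~ Exponential(rate)."""
    u = rng.random()                 # uniform on [0, 1), so 1 - u is in (0, 1]
    return -math.log(1.0 - u) / rate

rng = random.Random(0)               # seeded for reproducibility
rate = 2.0
samples = [sample_exponential(rate, rng) for _ in range(200_000)]
mean = sum(samples) / len(samples)
print(f"sample mean = {mean:.4f}, theoretical 1/rate = {1 / rate:.4f}")
```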
-
Gamma distribution: generalises the exponential. When $\alpha$ is a positive integer, it models the time until the $\alpha$-th event in a Poisson process (the exponential is the case $\alpha = 1$).
$$f(x) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{\alpha - 1} e^{-\beta x}, \quad x > 0$$
-
Here $\alpha$ (shape) controls the shape, and $\beta$ (rate) is the inverse of the scale: a larger $\beta$ squeezes the distribution toward 0. $\Gamma(\alpha)$ is the gamma function, which extends factorials to real numbers: $\Gamma(n) = (n-1)!$ for positive integers.
-
Mean: $E[X] = \alpha/\beta$. Variance: $\text{Var}(X) = \alpha/\beta^2$.
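Two facts from this section can be checked numerically. The first is $\Gamma(n) = (n-1)!$, using Python's `math.gamma`. The second is that, for an integer shape, a Gamma$(\alpha, \beta)$ waiting time is a sum of $\alpha$ independent Exponential$(\beta)$ waiting times. This minimal simulation sketch uses $\alpha = 3$, $\beta = 2$ and a seed chosen arbitrarily. Note that `random.expovariate` takes the rate as its argument.

```python
import math
import random

# Γ(n) = (n - 1)! for positive integers
for n in range(1, 7):
    print(n, math.gamma(n), math.factorial(n - 1))

# Integer shape: the sum of alpha Exponential(beta) waiting times is Gamma(alpha, beta)
rng = random.Random(1)
alpha, beta = 3, 2.0    # illustrative shape and rate
samples = [sum(rng.expovariate(beta) for _ in range(alpha)) for _ in range(100_000)]
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / (len(samples) - 1)
print(f"mean = {mean:.3f} (alpha/beta = {alpha / beta:.3f}), "
      f"var = {var:.3f} (alpha/beta^2 = {alpha / beta**2:.3f})")
```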
-
Beta distribution: defined on the interval $[0, 1]$, making it perfect for modelling probabilities, proportions, and rates.
$$f(x) = \frac{x^{\alpha - 1}(1 - x)^{\beta - 1}}{B(\alpha, \beta)}, \quad 0 \le x \le 1$$
-
The denominator $B(\alpha, \beta) = \frac{\Gamma(\alpha)\Gamma(\beta)}{\Gamma(\alpha + \beta)}$ is the beta function, a normalising constant.
-
Mean: $E[X] = \frac{\alpha}{\alpha + \beta}$. Variance: $\text{Var}(X) = \frac{\alpha\beta}{(\alpha+\beta)^2(\alpha+\beta+1)}$.
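The mean and variance formulas can be sanity-checked by numerically integrating the density. This minimal sketch uses a midpoint rule on $[0, 1]$. The parameters $\alpha = 2$, $\beta = 5$ are an arbitrary choice.

```python
import math

def beta_pdf(x, a, b):
    """Beta(a, b) density, normalised by B(a, b) = Γ(a)Γ(b) / Γ(a + b)."""
    B = math.gamma(a) * math.gamma(b) / math.gamma(a + b)
    return x ** (a - 1) * (1 - x) ** (b - 1) / B

a, b = 2.0, 5.0                      # illustrative parameters
N = 100_000
w = 1.0 / N                          # width of each midpoint-rule cell on [0, 1]
xs = [(i + 0.5) * w for i in range(N)]
dens = [beta_pdf(x, a, b) for x in xs]

mass = w * sum(dens)
mean = w * sum(x * d for x, d in zip(xs, dens))
var = w * sum((x - mean) ** 2 * d for x, d in zip(xs, dens))
print(f"total mass = {mass:.5f}")
print(f"mean = {mean:.5f}  formula: {a / (a + b):.5f}")
print(f"var  = {var:.5f}  formula: {a * b / ((a + b) ** 2 * (a + b + 1)):.5f}")
```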
-
The Beta distribution is the conjugate prior for the Bernoulli and Binomial likelihoods. This means if your prior is Beta and your data is Bernoulli, the posterior is also Beta, which makes Bayesian updating analytically tractable. We will use this in file 04.
-
Chi-squared distribution ($\chi^2$): if you take $k$ independent standard normal random variables and sum their squares, the result follows a $\chi^2$ distribution with $k$ degrees of freedom.
$$f(x) = \frac{1}{2^{k/2}\Gamma(k/2)} x^{k/2 - 1} e^{-x/2}, \quad x > 0$$
-
Mean: $E[X] = k$. Variance: $\text{Var}(X) = 2k$.
-
The $\chi^2$ distribution is actually a special case of the Gamma distribution with $\alpha = k/2$ and $\beta = 1/2$. It appears in hypothesis testing (the chi-squared test from Chapter 4), goodness-of-fit tests, and in computing confidence intervals for variance.
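The definition of $\chi^2$ as a sum of squared standard normals translates directly into a simulation. The sketch below uses $k = 4$ and a seed chosen arbitrarily, and checks the mean $k$ and variance $2k$ empirically.

```python
import random

rng = random.Random(2)
k = 4            # degrees of freedom (illustrative)
n = 100_000
# Each sample: the sum of squares of k independent standard normals
samples = [sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(k)) for _ in range(n)]
mean = sum(samples) / n
var = sum((x - mean) ** 2 for x in samples) / (n - 1)
print(f"mean = {mean:.3f} (k = {k}), var = {var:.3f} (2k = {2 * k})")
```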
-
Student's t-distribution: looks like a normal distribution but with heavier tails. It arises when you estimate the mean of a normally distributed population using a small sample and the population variance is unknown.
$$f(x) = \frac{\Gamma\!\left(\frac{\nu+1}{2}\right)}{\sqrt{\nu\pi}\,\Gamma\!\left(\frac{\nu}{2}\right)} \left(1 + \frac{x^2}{\nu}\right)^{-(\nu+1)/2}$$
-
The parameter $\nu$ (nu) is the degrees of freedom. As $\nu \to \infty$, the t-distribution converges to the standard normal. With small $\nu$, the heavier tails give more probability to extreme values, reflecting the extra uncertainty from a small sample.
-
Mean: $E[X] = 0$ (for $\nu > 1$). Variance: $\text{Var}(X) = \frac{\nu}{\nu - 2}$ (for $\nu > 2$).
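To see the heavier tails concretely, the sketch below compares $P(|T| > 3)$ for a few degrees of freedom with $P(|Z| > 3)$ for the standard normal. It integrates the t density above numerically with a midpoint rule, an approximation chosen to keep the example dependency-free. It uses `math.erf` for the normal tail.

```python
import math

def t_pdf(x, nu):
    """Student's t density with nu degrees of freedom (the formula above)."""
    c = math.gamma((nu + 1) / 2) / (math.sqrt(nu * math.pi) * math.gamma(nu / 2))
    return c * (1 + x * x / nu) ** (-(nu + 1) / 2)

def t_tail(threshold, nu, steps=100_000):
    """Approximate P(|T| > threshold) as 1 minus a midpoint-rule integral over [-threshold, threshold]."""
    h = 2 * threshold / steps
    inside = h * sum(t_pdf(-threshold + (i + 0.5) * h, nu) for i in range(steps))
    return 1 - inside

normal_tail = 1 - math.erf(3 / math.sqrt(2))   # P(|Z| > 3) for a standard normal
print(f"P(|Z| > 3) = {normal_tail:.4f}")
for nu in [3, 10, 30]:
    print(f"P(|T| > 3), nu = {nu:>2}: {t_tail(3, nu):.4f}")
```

As $\nu$ grows, the tail probability shrinks toward the normal value, matching the convergence described above.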
-
The t-distribution is used in t-tests (Chapter 4) and shows up in Bayesian inference as a marginal distribution when integrating out unknown variance.
-
To summarise the key distributions:
| Distribution | Type | Support | Mean | Variance |
|---|---|---|---|---|
| Bernoulli$(p)$ | Discrete | $\{0,1\}$ | $p$ | $p(1-p)$ |
| Binomial$(n,p)$ | Discrete | $\{0,\ldots,n\}$ | $np$ | $np(1-p)$ |
| Poisson$(\lambda)$ | Discrete | $\{0,1,2,\ldots\}$ | $\lambda$ | $\lambda$ |
| Geometric$(p)$ | Discrete | $\{1,2,3,\ldots\}$ | $1/p$ | $(1-p)/p^2$ |
| Uniform$(a,b)$ | Continuous | $[a,b]$ | $(a+b)/2$ | $(b-a)^2/12$ |
| Normal$(\mu,\sigma^2)$ | Continuous | $(-\infty,\infty)$ | $\mu$ | $\sigma^2$ |
| Exponential$(\lambda)$ | Continuous | $[0,\infty)$ | $1/\lambda$ | $1/\lambda^2$ |
| Gamma$(\alpha,\beta)$ | Continuous | $(0,\infty)$ | $\alpha/\beta$ | $\alpha/\beta^2$ |
| Beta$(\alpha,\beta)$ | Continuous | $[0,1]$ | $\alpha/(\alpha+\beta)$ | $\alpha\beta/((\alpha+\beta)^2(\alpha+\beta+1))$ |
| $\chi^2(k)$ | Continuous | $(0,\infty)$ | $k$ | $2k$ |
| Student's $t(\nu)$ | Continuous | $(-\infty,\infty)$ | $0$ (for $\nu > 1$) | $\nu/(\nu-2)$ (for $\nu > 2$) |
Coding Tasks (use Colab or a notebook)
- Plot the Binomial PMF for $n=20$ with several values of $p$. Observe how the shape shifts from right-skewed ($p=0.2$) to symmetric ($p=0.5$) to left-skewed ($p=0.8$).
import jax.numpy as jnp
import matplotlib.pyplot as plt
from math import comb
n = 20
ks = jnp.arange(0, n + 1)
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for ax, p, color in zip(axes, [0.2, 0.5, 0.8], ["#e74c3c", "#3498db", "#27ae60"]):
    pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks])
    ax.bar(ks, pmf, color=color, alpha=0.7)
    ax.set_title(f"Binomial(n={n}, p={p})")
    ax.set_xlabel("k")
axes[0].set_ylabel("P(X = k)")
plt.tight_layout()
plt.show()
- Verify the Poisson approximation to the Binomial. Set $n = 1000$, $p = 0.003$, and compare Binomial$(n, p)$ with Poisson$(\lambda = np)$.
import jax.numpy as jnp
import matplotlib.pyplot as plt
from math import comb, factorial, exp
n, p = 1000, 0.003
lam = n * p
ks = jnp.arange(0, 15)
binom_pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks])
poisson_pmf = jnp.array([lam**k * exp(-lam) / factorial(int(k)) for k in ks])
plt.figure(figsize=(8, 4))
plt.bar(ks - 0.15, binom_pmf, width=0.3, color="#3498db", alpha=0.7, label=f"Binomial({n},{p})")
plt.bar(ks + 0.15, poisson_pmf, width=0.3, color="#e74c3c", alpha=0.7, label=f"Poisson({lam})")
plt.xlabel("k")
plt.ylabel("P(X = k)")
plt.title("Poisson approximation to Binomial")
plt.legend()
plt.show()
- Sample from a Normal distribution and verify the empirical rule. Count what fraction of samples fall within 1, 2, and 3 standard deviations.
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(42)
mu, sigma = 5.0, 2.0
samples = mu + sigma * jax.random.normal(key, shape=(100_000,))
for k in [1, 2, 3]:
    within = jnp.abs(samples - mu) <= k * sigma
    print(f"Within {k}σ: {within.mean():.4f} (expected: {[0.6827, 0.9545, 0.9973][k-1]:.4f})")
- Explore the Beta distribution by varying $\alpha$ and $\beta$. Plot several shapes and see how the distribution changes from uniform to skewed to concentrated.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
x = jnp.linspace(0.01, 0.99, 200)
def beta_pdf(x, a, b):
    # Unnormalised density: the B(a, b) constant is dropped because we normalise numerically below
    return x**(a-1) * (1-x)**(b-1)
plt.figure(figsize=(10, 5))
params = [(1,1,"Uniform"), (2,5,"Right skew"), (5,2,"Left skew"),
          (5,5,"Symmetric"), (0.5,0.5,"U-shape")]
colors = ["#999", "#e74c3c", "#3498db", "#27ae60", "#9b59b6"]
for (a, b, label), color in zip(params, colors):
    y = beta_pdf(x, a, b)
    y = y / jnp.trapezoid(y, x)  # normalise numerically on the plotted grid
    plt.plot(x, y, label=f"α={a}, β={b} ({label})", color=color, linewidth=2)
plt.xlabel("x")
plt.ylabel("Density")
plt.title("Beta distribution shapes")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
Bayesian Methods and Sequential Models
-
So far we have described distributions and how to compute probabilities. Now we tackle the question at the heart of ML: given observed data, how do we find the best parameters for our model?
-
Maximum Likelihood Estimation (MLE) answers this directly. Pick the parameter values that make the observed data most probable.
-
Formally, given data $D = \{x_1, x_2, \ldots, x_n\}$ and a model with parameter $\theta$, the likelihood function is:
$$L(\theta | D) = P(D | \theta) = \prod_{i=1}^{n} P(x_i | \theta)$$
-
The product assumes the data points are independent and identically distributed (i.i.d.). The MLE estimate is:
$$\hat{\theta}_{\text{MLE}} = \arg\max_\theta L(\theta | D)$$
-
In practice we maximise the log-likelihood instead, because the log turns products into sums and prevents numerical underflow:
$$\ell(\theta) = \log L(\theta | D) = \sum_{i=1}^{n} \log P(x_i | \theta)$$
-
Since $\log$ is monotonically increasing, the $\theta$ that maximises $\ell(\theta)$ also maximises $L(\theta)$.
-
Coin toss example: you flip a coin 10 times and get 7 heads. What is the MLE estimate of the coin's bias $p$ (probability of heads)?
-
Each flip is Bernoulli($p$), so the likelihood of 7 heads in 10 flips is:
$$L(p) = \binom{10}{7} p^7 (1-p)^3$$
-
Taking the log and differentiating: $\frac{d\ell}{dp} = \frac{7}{p} - \frac{3}{1-p} = 0$, which gives $\hat{p}_{\text{MLE}} = 7/10 = 0.7$.
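Here is a brute-force check of the calculus: evaluate the log-likelihood on a fine grid of $p$ values and keep the best one. This sketch is only for intuition. The grid spacing of 0.001 is an arbitrary choice, and real models are optimised with gradient methods rather than grids.

```python
import math

heads, tails = 7, 3

def log_likelihood(p):
    """Log-likelihood of 7 heads and 3 tails (the constant log C(10, 7) is dropped)."""
    return heads * math.log(p) + tails * math.log(1 - p)

grid = [i / 1000 for i in range(1, 1000)]   # excludes p = 0 and p = 1, where the log is undefined
p_hat = max(grid, key=log_likelihood)
print(f"grid MLE = {p_hat:.3f}, closed form = {heads / (heads + tails):.3f}")
```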
-
MLE is intuitive and simple. If you got 7 heads in 10 flips, the most likely bias is 0.7. But notice the problem: if you got 10 heads in 10 flips, MLE says $\hat{p} = 1$, meaning the coin will always land heads. That seems overconfident given only 10 observations.
-
Maximum A Posteriori (MAP) estimation fixes this by adding prior beliefs. Instead of maximising just the likelihood, MAP maximises the posterior:
$$\hat{\theta}_{\text{MAP}} = \arg\max_\theta P(\theta | D) = \arg\max_\theta P(D | \theta) \cdot P(\theta)$$
-
We dropped $P(D)$ from the denominator because it does not depend on $\theta$ and does not affect the argmax.
-
The prior $P(\theta)$ encodes what we believed about $\theta$ before seeing data. If we use a Beta(2, 2) prior for our coin bias (expressing a mild belief that the coin is roughly fair), the MAP estimate is no longer simply the proportion of heads. It gets pulled toward 0.5.
-
With a Beta($\alpha$, $\beta$) prior and observing $h$ heads and $t$ tails, the posterior is Beta($\alpha + h$, $\beta + t$), and the MAP estimate is:
$$\hat{p}_{\text{MAP}} = \frac{\alpha + h - 1}{\alpha + \beta + h + t - 2}$$
-
For our example with Beta(2,2) prior, 7 heads, 3 tails: $\hat{p}_{\text{MAP}} = \frac{2 + 7 - 1}{2 + 2 + 10 - 2} = \frac{8}{12} = 0.667$.
-
Notice how the MAP estimate (0.667) is pulled toward 0.5 compared to the MLE (0.7). The prior acts as regularisation. In ML, L2 regularisation is equivalent to MAP estimation with a zero-mean Gaussian prior on the weights, and weight decay coincides with L2 regularisation under plain SGD.
-
Full Bayesian inference goes further than MAP. Instead of finding a single best $\theta$, it maintains the entire posterior distribution $P(\theta | D)$. This gives you not just a point estimate but a measure of uncertainty.
-
For the biased coin with Beta(2,2) prior and 7 heads, 3 tails, the full posterior is Beta(9, 5). The mean of this distribution is $9/14 \approx 0.643$, and its spread tells us how confident we are. With more data, the posterior narrows.
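The three estimates for the coin can be computed side by side. The sketch below uses the Beta(2, 2) prior from the text. It adds a hypothetical larger dataset (70 heads, 30 tails, the same proportion) to show the posterior standard deviation shrinking as data accumulates. The Beta mean and variance formulas are the ones from the distributions section.

```python
import math

def beta_summary(a, b):
    """Mean and standard deviation of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)

alpha, beta = 2, 2   # Beta(2, 2) prior from the text
for h, t in [(7, 3), (70, 30)]:   # (70, 30) is a hypothetical larger dataset
    mle = h / (h + t)
    map_est = (alpha + h - 1) / (alpha + beta + h + t - 2)
    post_mean, post_sd = beta_summary(alpha + h, beta + t)
    print(f"{h}H/{t}T: MLE = {mle:.3f}, MAP = {map_est:.3f}, "
          f"posterior mean = {post_mean:.3f}, posterior sd = {post_sd:.3f}")
```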
-
The three approaches form a spectrum:
- MLE: no prior, just data. Fast, but can overfit with little data.
- MAP: point estimate with prior regularisation. Adds robustness.
- Full Bayesian: entire posterior distribution. Most informative, but often computationally expensive.
-
Markov chains model sequences where the next state depends only on the current state, not on the history. This "memorylessness" is called the Markov property:
$$P(X_{t+1} | X_t, X_{t-1}, \ldots, X_1) = P(X_{t+1} | X_t)$$
-
Think of weather. Tomorrow's weather depends on today's weather, but not on last week's (a simplification, but surprisingly useful).
-
A Markov chain has a finite set of states and a transition matrix $T$ where entry $T_{ij}$ gives the probability of moving from state $i$ to state $j$. Each row sums to 1.
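- For the weather example above, with three states ordered (Rainy, Sunny, Cloudy), the transition matrix is:
$$T = \begin{bmatrix} 0.3 & 0.4 & 0.3 \\ 0.2 & 0.5 & 0.3 \\ 0.4 & 0.3 & 0.3 \end{bmatrix}$$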
-
If today is rainy (state vector $\mathbf{s}_0 = [1, 0, 0]$), the probability distribution over tomorrow's weather is $\mathbf{s}_1 = \mathbf{s}_0 T = [0.3, 0.4, 0.3]$. Two days from now: $\mathbf{s}_2 = \mathbf{s}_0 T^2$. This uses matrix multiplication from Chapter 2.
-
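Under mild conditions, a finite Markov chain converges to a unique stationary distribution $\pi$ such that $\pi T = \pi$. The chain must be irreducible (every state can eventually reach every other) and aperiodic (it does not cycle with a fixed period). No matter where you start, after enough steps the distribution over states settles into $\pi$. This property is the foundation of MCMC (Markov Chain Monte Carlo), a widely used sampling technique in Bayesian ML.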
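A minimal plain-Python sketch of both ideas, using the weather matrix. The vector-times-matrix product is written out by hand rather than with a library. The code propagates a starting distribution step by step, then checks that different starting states end up at the same distribution:

```python
# Weather transition matrix, rows/columns ordered (Rainy, Sunny, Cloudy),
# the same numbers as in the coding task below.
T = [[0.3, 0.4, 0.3],
     [0.2, 0.5, 0.3],
     [0.4, 0.3, 0.3]]

def step(s, T):
    """One step of the chain: the row vector s times the matrix T."""
    n = len(T)
    return [sum(s[i] * T[i][j] for i in range(n)) for j in range(n)]

def propagate(s0, T, n_steps):
    """Distribution after n_steps steps, i.e. s0 T^n_steps."""
    s = list(s0)
    for _ in range(n_steps):
        s = step(s, T)
    return s

rainy = [1.0, 0.0, 0.0]
print("s1 =", [round(p, 4) for p in propagate(rainy, T, 1)])
print("s2 =", [round(p, 4) for p in propagate(rainy, T, 2)])

# Different starting states end up at the same stationary distribution
for start in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
    pi = propagate(start, T, 50)
    print("start", start, "->", [round(p, 4) for p in pi])
```

Solving $\pi T = \pi$ with $\sum_i \pi_i = 1$ by hand gives $\pi = (13/45,\ 37/90,\ 3/10) \approx (0.289, 0.411, 0.300)$, which the iterates should match. Convergence is fast here because the other eigenvalues of this $T$ are 0.1 and 0.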
-
Hidden Markov Models (HMMs) extend Markov chains by adding a layer of indirection. The true states are hidden (unobserved), and at each time step the hidden state emits an observable signal.
-
An HMM has three components:
- Transition probabilities $P(z_t | z_{t-1})$: how hidden states evolve (the Markov chain)
- Emission probabilities $P(x_t | z_t)$: what each hidden state produces as observable output
- Initial distribution $P(z_1)$: the starting hidden state probabilities
-
Umbrella example: suppose you cannot see the weather directly but you can observe whether your friend carries an umbrella. The hidden states are {Rainy, Sunny} and the observation is {Umbrella, No umbrella}.
-
Transition probabilities: $P(\text{Rainy}|\text{Rainy}) = 0.7$, $P(\text{Sunny}|\text{Rainy}) = 0.3$, $P(\text{Rainy}|\text{Sunny}) = 0.4$, $P(\text{Sunny}|\text{Sunny}) = 0.6$.
-
Emission probabilities: $P(\text{Umbrella}|\text{Rainy}) = 0.9$, $P(\text{No umbrella}|\text{Rainy}) = 0.1$, $P(\text{Umbrella}|\text{Sunny}) = 0.2$, $P(\text{No umbrella}|\text{Sunny}) = 0.8$.
-
The key questions for HMMs are:
- Decoding: given observations, what is the most likely sequence of hidden states? Solved by the Viterbi algorithm.
- Evaluation: what is the probability of an observation sequence? Solved by the Forward algorithm.
- Learning: given observations, what are the best model parameters? Solved by the Baum-Welch algorithm (an instance of Expectation-Maximisation).
-
Viterbi walkthrough: suppose you observe [Umbrella, Umbrella, No umbrella] and want to find the most likely weather sequence.
-
Start with initial probabilities. Assume $P(R) = 0.5$, $P(S) = 0.5$.
-
Day 1 (observe Umbrella):
- $V_1(R) = P(R) \cdot P(U|R) = 0.5 \times 0.9 = 0.45$
- $V_1(S) = P(S) \cdot P(U|S) = 0.5 \times 0.2 = 0.10$
-
Day 2 (observe Umbrella):
- $V_2(R) = \max(V_1(R) \cdot P(R|R), V_1(S) \cdot P(R|S)) \cdot P(U|R)$
- $= \max(0.45 \times 0.7, 0.10 \times 0.4) \times 0.9 = \max(0.315, 0.04) \times 0.9 = 0.2835$
- $V_2(S) = \max(V_1(R) \cdot P(S|R), V_1(S) \cdot P(S|S)) \cdot P(U|S)$
- $= \max(0.45 \times 0.3, 0.10 \times 0.6) \times 0.2 = \max(0.135, 0.06) \times 0.2 = 0.027$
-
Day 3 (observe No umbrella):
- $V_3(R) = \max(0.2835 \times 0.7, 0.027 \times 0.4) \times 0.1 = 0.19845 \times 0.1 = 0.019845$
- $V_3(S) = \max(0.2835 \times 0.3, 0.027 \times 0.6) \times 0.8 = 0.08505 \times 0.8 = 0.06804$
-
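Day 3's maximum is at Sunny. Backtracking through the stored choices, Sunny on Day 3 was best reached from Rainy on Day 2 ($0.08505 > 0.0162$). Rainy on Day 2 was in turn best reached from Rainy on Day 1 ($0.315 > 0.04$). Most likely sequence: Rainy, Rainy, Sunny.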
-
The Forward-Backward algorithm computes the probability of being in each hidden state at each time step, given the entire observation sequence. The forward pass computes $P(z_t, x_{1:t})$ and the backward pass computes $P(x_{t+1:T} | z_t)$. Their product is $P(z_t, x_{1:T})$. Dividing it by $P(x_{1:T})$ gives the smoothed state probabilities $P(z_t | x_{1:T})$.
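A plain-Python sketch of forward-backward on the umbrella HMM from above, with the same uniform initial distribution as the Viterbi walkthrough. It also returns $P(x_{1:T})$, which is exactly what the Forward algorithm computes for the evaluation question:

```python
# Umbrella HMM from the text: state 0 = Rainy, 1 = Sunny; observation 0 = Umbrella, 1 = No umbrella
init = [0.5, 0.5]
trans = [[0.7, 0.3],
         [0.4, 0.6]]
emit = [[0.9, 0.1],
        [0.2, 0.8]]

def forward_backward(obs, init, trans, emit):
    n, T = len(init), len(obs)
    # Forward pass: alpha[t][i] = P(z_t = i, x_1..x_t)
    alpha = [[init[i] * emit[i][obs[0]] for i in range(n)]]
    for t in range(1, T):
        alpha.append([sum(alpha[t - 1][i] * trans[i][j] for i in range(n)) * emit[j][obs[t]]
                      for j in range(n)])
    # Backward pass: beta[t][i] = P(x_{t+1}..x_T | z_t = i), equal to 1 at the last step
    beta = [[1.0] * n for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [sum(trans[i][j] * emit[j][obs[t + 1]] * beta[t + 1][j] for j in range(n))
                   for i in range(n)]
    evidence = sum(alpha[-1])  # P(x_1..x_T): the Forward algorithm's answer
    # Smoothed posteriors P(z_t = i | x_1..x_T)
    gamma = [[alpha[t][i] * beta[t][i] / evidence for i in range(n)] for t in range(T)]
    return evidence, gamma

evidence, gamma = forward_backward([0, 0, 1], init, trans, emit)
print(f"P(observations) = {evidence:.6f}")
for day, (p_rain, p_sun) in enumerate(gamma, start=1):
    print(f"Day {day}: P(Rainy)={p_rain:.3f}  P(Sunny)={p_sun:.3f}")
```

For [Umbrella, Umbrella, No umbrella] the evidence is 0.119325. The smoothed probabilities favour Rainy on days 1 and 2 and Sunny on day 3, which agrees with the Viterbi path here. In general, the individually most likely states need not form the most likely sequence, which is why Viterbi and forward-backward answer different questions.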
-
The Baum-Welch algorithm learns HMM parameters from data when the hidden states are unobserved. It is an Expectation-Maximisation (EM) algorithm: the E-step uses forward-backward to estimate which hidden states generated the observations, and the M-step updates the transition and emission probabilities.
-
HMMs were historically dominant in speech recognition (hidden phoneme states emit acoustic signals) and bioinformatics (hidden gene states emit DNA base pairs). While deep learning has largely superseded HMMs in these fields, the ideas of hidden states, emissions, and sequential inference remain central to sequence models.
-
Conditional Random Fields (CRFs) improve on HMMs by removing the independence assumption on emissions. In an HMM, the observation at time $t$ depends only on the hidden state at time $t$. CRFs allow the label at position $t$ to depend on the entire input sequence.
-
A linear-chain CRF models the conditional probability of a label sequence $\mathbf{y}$ given an input sequence $\mathbf{x}$:
$$P(\mathbf{y} | \mathbf{x}) = \frac{1}{Z(\mathbf{x})} \exp\!\left(\sum_t \left[\sum_k \lambda_k f_k(y_t, y_{t-1}, \mathbf{x}, t)\right]\right)$$
-
Here $f_k$ are feature functions (which can look at any part of the input), $\lambda_k$ are learned weights, and $Z(\mathbf{x})$ is a normalising constant.
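To make the formula concrete, here is a tiny brute-force sketch in plain Python. The labels, feature templates and weights $\lambda_k$ below are all hypothetical; a real CRF learns the weights from data. $Z(\mathbf{x})$ is computed by enumerating every label sequence, which is only feasible for toy inputs. Practical implementations compute it with a forward-algorithm-style dynamic program:

```python
import itertools
import math

LABELS = ["NOUN", "VERB"]
# Hypothetical weights lambda_k, one per (feature, label) pair
WEIGHTS = {
    ("word=dogs", "NOUN"): 2.0,
    ("word=bark", "VERB"): 1.5,
    ("word=bark", "NOUN"): 0.5,
    ("next_word=bark", "NOUN"): 0.7,   # looks beyond position t, which a CRF allows
    ("prev_label=NOUN", "VERB"): 1.0,  # transition feature on (y_{t-1}, y_t)
}

def features(x, t, prev_label):
    """Feature names that fire at position t (hypothetical templates)."""
    feats = [f"word={x[t]}"]
    if t + 1 < len(x):
        feats.append(f"next_word={x[t + 1]}")
    if prev_label is not None:
        feats.append(f"prev_label={prev_label}")
    return feats

def sequence_score(x, y):
    """sum_t sum_k lambda_k f_k(y_t, y_{t-1}, x, t)."""
    total = 0.0
    for t in range(len(x)):
        prev = y[t - 1] if t > 0 else None
        total += sum(WEIGHTS.get((f, y[t]), 0.0) for f in features(x, t, prev))
    return total

def crf_probabilities(x):
    """P(y | x) for every label sequence, normalising by a brute-force Z(x)."""
    seqs = list(itertools.product(LABELS, repeat=len(x)))
    scores = [sequence_score(x, y) for y in seqs]
    log_z = math.log(sum(math.exp(s) for s in scores))
    return {y: math.exp(s - log_z) for y, s in zip(seqs, scores)}

probs = crf_probabilities(["dogs", "bark"])
for y, p in sorted(probs.items(), key=lambda item: -item[1]):
    print(y, f"{p:.3f}")
```

Note the `next_word=...` feature. It lets the label at position $t$ look at the input at $t + 1$, something an HMM's emission model cannot do.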
-
CRFs are discriminative models (they model $P(\mathbf{y}|\mathbf{x})$ directly) while HMMs are generative (they model $P(\mathbf{x}, \mathbf{y})$). This distinction is the same as logistic regression (discriminative) vs Naive Bayes (generative).
-
In modern NLP, CRF layers are often added on top of neural networks (BiLSTM-CRF, BERT-CRF) for tasks like named entity recognition and part-of-speech tagging, where capturing label dependencies is important.
Coding Tasks (use CoLab or notebook)
- Implement MLE and MAP for a coin toss experiment. Observe how the MAP estimate changes with different priors and different amounts of data.
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Data: observed coin flips
heads, tails = 7, 3
# MLE
p_mle = heads / (heads + tails)
print(f"MLE: {p_mle:.4f}")
# MAP with Beta prior
for alpha, beta in [(1,1), (2,2), (5,5), (10,10)]:
    p_map = (alpha + heads - 1) / (alpha + beta + heads + tails - 2)
    print(f"MAP (Beta({alpha},{beta})): {p_map:.4f}")
# Visualise posterior for Beta(2,2) prior
theta = jnp.linspace(0.01, 0.99, 200)
# Posterior is Beta(alpha+heads, beta+tails)
a_post, b_post = 2 + heads, 2 + tails
posterior = theta**(a_post-1) * (1-theta)**(b_post-1)
posterior = posterior / jnp.trapezoid(posterior, theta)
plt.figure(figsize=(8, 4))
plt.plot(theta, posterior, color="#e74c3c", linewidth=2, label=f"Posterior Beta({a_post},{b_post})")
plt.axvline(p_mle, color="#3498db", linestyle="--", label=f"MLE = {p_mle:.2f}")
plt.axvline((a_post-1)/(a_post+b_post-2), color="#e74c3c", linestyle="--", label=f"MAP = {(a_post-1)/(a_post+b_post-2):.3f}")
plt.xlabel("θ (coin bias)")
plt.ylabel("Density")
plt.title("Posterior distribution after 7H, 3T with Beta(2,2) prior")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
- Build a Markov chain for the weather model and simulate it. Compute the stationary distribution both by simulation and by solving $\pi T = \pi$.
import jax
import jax.numpy as jnp
# Transition matrix: R, S, C
T = jnp.array([
[0.3, 0.4, 0.3],
[0.2, 0.5, 0.3],
[0.4, 0.3, 0.3]
])
states = ["Rainy", "Sunny", "Cloudy"]
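# Simulate 100,000 steps with jax.lax.scan (a Python loop of eager JAX calls is very slow)
n_steps = 100_000
keys = jax.random.split(jax.random.PRNGKey(42), n_steps)
def step(state, key):
    # Sample the next state from the row of T for the current state
    next_state = jax.random.choice(key, 3, p=T[state]).astype(jnp.int32)
    return next_state, next_state
_, visited = jax.lax.scan(step, jnp.int32(0), keys)  # start rainy (state 0)
counts = jnp.bincount(visited, length=3)
sim_stationary = counts / n_steps
print("Simulated stationary distribution:")
for s, p in zip(states, sim_stationary):
    print(f" {s}: {p:.4f}")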
# Analytical: find left eigenvector with eigenvalue 1
eigenvalues, eigenvectors = jnp.linalg.eig(T.T)
idx = jnp.argmin(jnp.abs(eigenvalues - 1.0))
pi = jnp.real(eigenvectors[:, idx])
pi = pi / pi.sum()
print("\nAnalytical stationary distribution:")
for s, p in zip(states, pi):
    print(f" {s}: {p:.4f}")
- Implement the Viterbi algorithm for the umbrella HMM and decode a sequence of observations.
import jax.numpy as jnp
# HMM parameters
states = ["Rainy", "Sunny"]
obs_names = ["Umbrella", "No umbrella"]
trans = jnp.array([[0.7, 0.3], # R->R, R->S
[0.4, 0.6]]) # S->R, S->S
emit = jnp.array([[0.9, 0.1], # R->U, R->noU
[0.2, 0.8]]) # S->U, S->noU
init = jnp.array([0.5, 0.5])
# Observations: U=0, noU=1
observations = [0, 0, 1] # Umbrella, Umbrella, No umbrella
def viterbi(obs, init, trans, emit):
    n_states = len(init)
    T = len(obs)
    V = jnp.zeros((T, n_states))
    # path[t, j] = best previous state when in state j at time t
    path = jnp.zeros((T, n_states), dtype=int)
    # Initialisation
    V = V.at[0].set(init * emit[:, obs[0]])
    # Recursion
    for t in range(1, T):
        for j in range(n_states):
            probs = V[t-1] * trans[:, j]
            V = V.at[t, j].set(jnp.max(probs) * emit[j, obs[t]])
            path = path.at[t, j].set(jnp.argmax(probs))
    # Backtrack
    best = [int(jnp.argmax(V[-1]))]
    for t in range(T-1, 0, -1):
        best.insert(0, int(path[t, best[0]]))
    return best, V
decoded, scores = viterbi(observations, init, trans, emit)
print("Observations:", [obs_names[o] for o in observations])
print("Decoded: ", [states[s] for s in decoded])
- Visualise how the posterior evolves as you observe more coin flips. Start with a Beta(1,1) prior (uniform) and update after each flip.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
theta = jnp.linspace(0.01, 0.99, 300)
key = jax.random.PRNGKey(7)
# True bias = 0.65
flips = jax.random.bernoulli(key, p=0.65, shape=(50,))
plt.figure(figsize=(10, 5))
a, b = 1, 1 # Beta(1,1) = uniform
for n_obs in [0, 1, 5, 10, 25, 50]:
    h = int(flips[:n_obs].sum())
    t = n_obs - h
    a_post = a + h
    b_post = b + t
    y = theta**(a_post-1) * (1-theta)**(b_post-1)
    y = y / jnp.trapezoid(y, theta)
    plt.plot(theta, y, linewidth=2, label=f"n={n_obs} (h={h})")
plt.axvline(0.65, color="black", linestyle=":", alpha=0.5, label="true p=0.65")
plt.xlabel("θ")
plt.ylabel("Density")
plt.title("Bayesian updating: posterior narrows with more data")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
Information Theory
-
Information theory, founded by Claude Shannon in 1948, gives us a mathematical framework for quantifying information. It answers questions like: how surprised should you be by an event? How much information does a message carry? How different are two probability distributions?
-
These questions sound abstract, but they are the foundation of ML loss functions, data compression, and communication systems. Cross-entropy loss, the most common loss function in classification, comes directly from information theory.
-
Start with the simplest question: how much information does a single event carry?
-
Surprisal (also called self-information) measures how surprising an event is. If something very likely happens, you learn almost nothing. If something rare happens, you learn a lot.
-
If you live in a desert and someone tells you it is sunny, that is not very informative. If they tell you it is snowing, that is extremely informative. Surprisal formalises this intuition:
$$I(x) = \log_2 \frac{1}{p(x)} = -\log_2 p(x)$$
-
The unit is bits when we use $\log_2$. A fair coin flip has surprisal $-\log_2(0.5) = 1$ bit. An event with probability $1/8$ has surprisal $\log_2(8) = 3$ bits.
-
Why logarithm and not just $1/p$? Three reasons:
- A certain event ($p = 1$) should give zero information: $\log(1) = 0$ but $1/1 = 1$.
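- Independent events should have additive information: $\log \frac{1}{p_1 p_2} = \log \frac{1}{p_1} + \log \frac{1}{p_2}$, whereas $\frac{1}{p_1 p_2}$ multiplies.
- We want a measure that grows gently with rarity. As $p \to 0$, $1/p$ blows up rapidly, while $\log(1/p)$ grows far more slowly.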
-
Entropy is the expected surprisal, the average amount of information you get per event sampled from a distribution. It measures the uncertainty or "unpredictability" of the distribution:
$$H(X) = E[I(X)] = -\sum_{x} p(x) \log_2 p(x)$$
-
A fair coin has entropy $H = -0.5\log_2(0.5) - 0.5\log_2(0.5) = 1$ bit. Maximum uncertainty.
-
A biased coin with $p = 0.9$ has entropy $H = -0.9\log_2(0.9) - 0.1\log_2(0.1) \approx 0.469$ bits. Less uncertain, so less entropy.
-
A deterministic event ($p = 1$) has entropy $H = 0$. No uncertainty at all.
-
Entropy is maximised when all outcomes are equally likely. For $n$ equally likely outcomes, $H = \log_2 n$. A fair die has entropy $\log_2 6 \approx 2.585$ bits.
-
The practical meaning of entropy is compression. Shannon's source coding theorem says you cannot losslessly compress data below its entropy (on average, in bits per symbol). An image whose pixels are independent and uniformly random (maximum entropy) cannot be compressed. An image that is mostly white (low entropy) compresses well.
-
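For a quick sense of scale: a grayscale pixel (256 values) has a maximum entropy of 8 bits. A 1080p grayscale image therefore carries at most $1920 \times 1080 \times 8 \approx 16.6$ million bits. Real images have much lower entropy because neighbouring pixels are correlated. This is a large part of why image compression works: lossless formats such as PNG exploit the correlation directly, and JPEG additionally discards detail the eye barely notices.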
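A quick way to see the compression claim is to feed two synthetic "images" to a general-purpose compressor. This plain-Python sketch uses the standard library's `zlib` (DEFLATE) as a stand-in for an entropy coder and flat byte strings as hypothetical images. The byte-histogram entropy treats bytes as independent, so it ignores correlations between neighbouring pixels:

```python
import math
import random
import zlib
from collections import Counter

def empirical_entropy_bits(data):
    """Entropy (bits per byte) of the byte-value histogram, treating bytes as i.i.d."""
    counts = Counter(data)
    n = len(data)
    return -sum(c / n * math.log2(c / n) for c in counts.values())

rng = random.Random(0)
n = 100_000
# "Noise image": every byte independent and uniform over 0..255
noise = bytes(rng.getrandbits(8) for _ in range(n))
# "Mostly white image": 255 everywhere except roughly 1% random pixels
mostly_white = bytes(rng.getrandbits(8) if rng.random() < 0.01 else 255 for _ in range(n))

for name, data in [("noise", noise), ("mostly white", mostly_white)]:
    compressed = zlib.compress(data, 9)
    print(f"{name:>12}: entropy={empirical_entropy_bits(data):.3f} bits/byte, "
          f"{len(data)} -> {len(compressed)} bytes")
```

The noise bytes sit near 8 bits per byte and do not shrink at all (if anything, zlib's framing makes the output slightly larger). The mostly-white bytes shrink dramatically.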
-
For continuous random variables, the discrete sum becomes an integral. Differential entropy is:
$$h(X) = -\int_{-\infty}^{\infty} f(x) \log f(x) \, dx$$
-
A Gaussian with variance $\sigma^2$ has differential entropy $h = \frac{1}{2}\log_2(2\pi e \sigma^2)$. Among all distributions with the same variance, the Gaussian has the maximum entropy. This is one reason the Gaussian is so common in modelling: it makes the fewest assumptions beyond the specified mean and variance.
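A Monte Carlo check of the Gaussian formula in plain Python. Since $h(X) = -E[\log f(X)]$, averaging $-\log_2 f(x)$ over samples $x \sim \mathcal{N}(0, \sigma^2)$ should reproduce $\frac{1}{2}\log_2(2\pi e\sigma^2)$. The sample size and seed are arbitrary choices for this sketch:

```python
import math
import random

def gaussian_entropy_bits(sigma):
    """Closed form h = 0.5 * log2(2*pi*e*sigma^2), in bits."""
    return 0.5 * math.log2(2 * math.pi * math.e * sigma ** 2)

def monte_carlo_entropy_bits(sigma, n=100_000, seed=0):
    """Estimate h = -E[log2 f(X)] by averaging over samples X ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, sigma)
        log_f = -x * x / (2 * sigma ** 2) - 0.5 * math.log(2 * math.pi * sigma ** 2)
        total += -log_f / math.log(2)  # convert nats to bits
    return total / n

for sigma in (0.5, 1.0, 2.0):
    print(f"sigma={sigma}: closed form {gaussian_entropy_bits(sigma):.4f} bits, "
          f"Monte Carlo {monte_carlo_entropy_bits(sigma):.4f} bits")
```

The formula also reveals one difference from the discrete case: differential entropy can be negative. Halving $\sigma$ subtracts exactly 1 bit, and for $\sigma < 1/\sqrt{2\pi e} \approx 0.242$ the value drops below zero.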
-
Mutual information measures how much knowing one variable tells you about another. It is the reduction in uncertainty about $X$ when you observe $Y$:
$$I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)$$
- Equivalently:
$$I(X; Y) = \sum_{x,y} p(x,y) \log_2 \frac{p(x,y)}{p(x) p(y)}$$
-
If $X$ and $Y$ are independent, $p(x,y) = p(x)p(y)$ and mutual information is zero. The more dependent they are, the higher the mutual information.
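A small plain-Python sketch computing mutual information from a joint probability table. The two tables are made-up examples. The code also cross-checks the result against the $H(X) - H(X|Y)$ form:

```python
import math

def entropy_bits(probs):
    return -sum(p * math.log2(p) for p in probs if p > 0)

def mutual_information_bits(joint):
    """I(X;Y) = sum_{x,y} p(x,y) log2( p(x,y) / (p(x) p(y)) ) for a table joint[x][y]."""
    px = [sum(row) for row in joint]
    py = [sum(col) for col in zip(*joint)]
    return sum(pxy * math.log2(pxy / (px[i] * py[j]))
               for i, row in enumerate(joint)
               for j, pxy in enumerate(row) if pxy > 0)

dependent = [[0.4, 0.1],
             [0.1, 0.4]]      # X and Y usually agree
independent = [[0.25, 0.25],
               [0.25, 0.25]]  # p(x, y) = p(x) p(y)

mi_dep = mutual_information_bits(dependent)
mi_indep = mutual_information_bits(independent)
print(f"dependent:   I(X;Y) = {mi_dep:.4f} bits")
print(f"independent: I(X;Y) = {mi_indep:.4f} bits")

# Cross-check: H(X) - H(X|Y), where H(X|Y) = sum_y p(y) H(X | Y = y)
h_x = entropy_bits([sum(row) for row in dependent])
h_x_given_y = sum(sum(col) * entropy_bits([p / sum(col) for p in col]) for col in zip(*dependent))
print(f"H(X) - H(X|Y) = {h_x - h_x_given_y:.4f} bits")
```

The dependent table gives about 0.278 bits from both formulas, and the independent table gives exactly 0.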
-
In ML, mutual information is used in feature selection (pick features with high MI with the target), in information bottleneck methods, and in evaluating clustering quality.
-
Cross-entropy measures the average number of bits needed to encode events from distribution $p$ using a code optimised for distribution $q$:
$$H(p, q) = -\sum_{x} p(x) \log_2 q(x)$$
-
If $q$ matches $p$ perfectly, cross-entropy equals entropy: $H(p, p) = H(p)$. If $q$ is a bad approximation, cross-entropy is higher. The "extra" bits come from the mismatch.
-
This is exactly why cross-entropy is the standard loss function for classification in ML. The true labels define $p$ (a one-hot distribution), and the model's predicted probabilities define $q$. Minimising cross-entropy pushes $q$ toward $p$:
$$\mathcal{L} = -\sum_{c} y_c \log \hat{y}_c$$
-
For a single sample with true class $c$, this simplifies to $\mathcal{L} = -\log \hat{y}_c$. The loss is the surprisal of the true class under the model's predictions. If the model assigns high probability to the correct class, the loss is low.
-
KL divergence (Kullback-Leibler divergence, also called relative entropy) measures how much one distribution differs from another:
$$D_{\text{KL}}(p \| q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} = H(p, q) - H(p)$$
- KL divergence is the "extra cost" of using distribution $q$ instead of the true distribution $p$. It is always non-negative ($D_{\text{KL}} \ge 0$) and equals zero only when $p = q$.
-
KL divergence is not symmetric: in general $D_{\text{KL}}(p \| q) \ne D_{\text{KL}}(q \| p)$. This asymmetry matters. $D_{\text{KL}}(p \| q)$ penalises $q$ for placing low probability where $p$ has high probability (because $\log(p/q)$ blows up). $D_{\text{KL}}(q \| p)$ penalises the reverse.
-
This asymmetry leads to two styles of approximation:
- Minimising $D_{\text{KL}}(p \| q)$ produces mass-covering behaviour: $q$ covers all modes of $p$ but may be too spread out. When $q$ is Gaussian, this amounts to moment matching, meaning $q$ takes on the mean and covariance of $p$.
- Minimising $D_{\text{KL}}(q \| p)$ produces mode-seeking behaviour: $q$ concentrates on one mode of $p$ but may miss others. This is the direction standard variational inference uses.
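The two behaviours are easy to reproduce numerically. This plain-Python sketch fits a single Gaussian $q = \mathcal{N}(\mu, \sigma^2)$ to a bimodal target $p$, an equal mixture of $\mathcal{N}(-2, 0.5^2)$ and $\mathcal{N}(2, 0.5^2)$ chosen for illustration. It uses brute-force grid search and approximates both KL integrals with Riemann sums on a grid; real variational methods use gradient-based optimisation instead:

```python
import math

def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def log_p(x):
    """Bimodal target: equal mixture of N(-2, 0.5^2) and N(2, 0.5^2), in log space."""
    a = math.log(0.5) + log_normal(x, -2.0, 0.5)
    b = math.log(0.5) + log_normal(x, 2.0, 0.5)
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))  # log-sum-exp

dx = 0.05
xs = [-6.0 + i * dx for i in range(241)]  # grid on [-6, 6]
log_ps = [log_p(x) for x in xs]

def kl(log_a, log_b):
    """Riemann-sum approximation of KL(a || b) = integral a(x) (log a(x) - log b(x)) dx."""
    return sum(math.exp(la) * (la - lb) for la, lb in zip(log_a, log_b)) * dx

best_forward = best_reverse = None
for i in range(61):                # mu in -3.0 .. 3.0
    mu = -3.0 + 0.1 * i
    for j in range(29):            # sigma in 0.2 .. 3.0
        sigma = 0.2 + 0.1 * j
        log_qs = [log_normal(x, mu, sigma) for x in xs]
        fwd, rev = kl(log_ps, log_qs), kl(log_qs, log_ps)
        if best_forward is None or fwd < best_forward[0]:
            best_forward = (fwd, mu, sigma)
        if best_reverse is None or rev < best_reverse[0]:
            best_reverse = (rev, mu, sigma)

print("argmin KL(p || q): mu=%.1f sigma=%.1f" % best_forward[1:])
print("argmin KL(q || p): mu=%.1f sigma=%.1f" % best_reverse[1:])
```

The forward-KL fit should land near $\mu = 0$ with a wide $\sigma$, straddling both modes. For a Gaussian $q$ the optimum matches $p$'s mean and standard deviation, $\sqrt{0.25 + 4} \approx 2.06$. The reverse-KL fit should sit on one mode ($\mu \approx \pm 2$, $\sigma \approx 0.5$). Which mode it picks here is just a tie broken by the search order.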
-
Since $H(p)$ is constant with respect to the model, minimising cross-entropy $H(p, q)$ is equivalent to minimising $D_{\text{KL}}(p \| q)$. This is why we can use cross-entropy loss and know that we are also minimising the KL divergence between the true and predicted distributions.
-
KL divergence plays a central role in Bayesian updating. The posterior $P(\theta | D)$ can be characterised as the distribution $q(\theta)$ that minimises $D_{\text{KL}}(q \| P(\theta)) - E_q[\log P(D | \theta)]$. In words, it stays as close to the prior as possible (in KL terms) while explaining the observed data well. Each new observation updates the posterior, typically reducing uncertainty about $\theta$.
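One way to make this precise is an identity that holds for any distribution $q(\theta)$. It follows from expanding the definitions and using $P(\theta)\,P(D \mid \theta) = P(\theta, D) = P(\theta \mid D)\,P(D)$:

$$
\begin{aligned}
D_{\text{KL}}\big(q \,\|\, P(\theta)\big) - E_q\big[\log P(D \mid \theta)\big]
&= E_q\big[\log q(\theta) - \log P(\theta) - \log P(D \mid \theta)\big] \\
&= E_q\big[\log q(\theta) - \log P(\theta \mid D) - \log P(D)\big] \\
&= D_{\text{KL}}\big(q \,\|\, P(\theta \mid D)\big) - \log P(D)
\end{aligned}
$$

The term $\log P(D)$ does not depend on $q$, and a KL divergence is zero only when its two arguments match. The left-hand side is therefore minimised exactly at $q = P(\theta \mid D)$. With the sign flipped, the left-hand side is the evidence lower bound (ELBO) that variational inference maximises over a restricted family of $q$.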
-
In variational autoencoders (VAEs), the loss function has two terms. The first is a reconstruction loss: the negative log-likelihood of the input, which is cross-entropy for binary data. The second is a KL divergence term that keeps the encoder's distribution over latent codes close to a standard normal prior.
-
To tie everything together: entropy tells you the intrinsic uncertainty in a distribution, cross-entropy tells you how well your model approximates reality, and KL divergence tells you the gap between the two. These three quantities form the backbone of modern ML optimisation.
Coding Tasks (use CoLab or notebook)
- Compute the entropy of various distributions and verify that the uniform distribution has maximum entropy for a given number of outcomes.
import jax.numpy as jnp
def entropy(p):
    """Compute entropy in bits. Filter out zero-probability events."""
    p = p[p > 0]
    return -jnp.sum(p * jnp.log2(p))
# Fair die
fair = jnp.ones(6) / 6
print(f"Fair die entropy: {entropy(fair):.4f} bits (max = log2(6) = {jnp.log2(6.):.4f})")
# Loaded die
loaded = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.5])
print(f"Loaded die entropy: {entropy(loaded):.4f} bits")
# Deterministic
det = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
print(f"Deterministic: {entropy(det):.4f} bits")
# Fair coin
coin = jnp.array([0.5, 0.5])
print(f"Fair coin entropy: {entropy(coin):.4f} bits")
- Compute cross-entropy and KL divergence between a true distribution and several approximations. Verify that $D_{\text{KL}}(p \| q) = H(p, q) - H(p)$.
import jax.numpy as jnp
def cross_entropy(p, q):
    return -jnp.sum(p * jnp.log2(jnp.clip(q, 1e-10, 1.0)))
def kl_divergence(p, q):
    # Terms with p(x) = 0 contribute 0 by convention
    mask = p > 0
    return jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0))
def entropy(p):
    p = p[p > 0]
    return -jnp.sum(p * jnp.log2(p))
p = jnp.array([0.4, 0.3, 0.2, 0.1]) # true distribution
for name, q in [("perfect match", p),
                ("slight mismatch", jnp.array([0.35, 0.30, 0.25, 0.10])),
                ("big mismatch", jnp.array([0.1, 0.1, 0.1, 0.7]))]:
    h_p = entropy(p)
    h_pq = cross_entropy(p, q)
    kl = kl_divergence(p, q)
    print(f"{name:20s}: H(p)={h_p:.4f}, H(p,q)={h_pq:.4f}, "
          f"KL={kl:.4f}, H(p,q)-H(p)={h_pq-h_p:.4f}")
- Show that KL divergence is not symmetric by computing $D_{\text{KL}}(p \| q)$ and $D_{\text{KL}}(q \| p)$ for two different distributions.
import jax.numpy as jnp
def kl_div(p, q):
    mask = p > 0
    return float(jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0)))
p = jnp.array([0.9, 0.1])
q = jnp.array([0.5, 0.5])
print(f"D_KL(p || q) = {kl_div(p, q):.4f}")
print(f"D_KL(q || p) = {kl_div(q, p):.4f}")
print("Not the same! KL divergence is asymmetric.")
- Simulate cross-entropy loss during training. Create a "true" one-hot label and show how the loss decreases as the model's predicted probabilities improve.
import jax.numpy as jnp
import matplotlib.pyplot as plt
# True label: class 2 out of 4
true_label = jnp.array([0, 0, 1, 0])
# Simulate improving predictions
steps = []
losses = []
for confidence in jnp.linspace(0.25, 0.99, 50):
    # Model becomes more confident in class 2
    remaining = (1 - confidence) / 3
    pred = jnp.array([remaining, remaining, confidence, remaining])
    loss = -jnp.sum(true_label * jnp.log(jnp.clip(pred, 1e-10, 1.0)))
    steps.append(float(confidence))
    losses.append(float(loss))
plt.figure(figsize=(8, 4))
plt.plot(steps, losses, color="#e74c3c", linewidth=2)
plt.xlabel("Model confidence in true class")
plt.ylabel("Cross-entropy loss")
plt.title("Cross-entropy loss decreases as predictions improve")
plt.grid(alpha=0.3)
plt.show()
Classical Machine Learning
-
Machine learning is the study of algorithms that improve their performance on some task by learning from data, rather than being explicitly programmed with rules. Instead of writing "if income > 50k and age < 30 then approve loan," you hand the algorithm thousands of past loan decisions and let it figure out the pattern.
-
There are three broad paradigms. Supervised learning uses labelled data, meaning each input comes with a known correct output. The algorithm learns a mapping from inputs to outputs. Unsupervised learning works with unlabelled data and tries to discover hidden structure, like clusters or compressed representations. Reinforcement learning learns through trial and error, receiving rewards or penalties for actions taken in an environment (covered in file 04).
-
Within supervised learning, classification predicts discrete categories (spam or not spam, cat or dog) while regression predicts continuous values (house price, temperature tomorrow). The boundary is not always sharp: logistic regression is named "regression" but performs classification.
-
A key distinction in probabilistic models is generative vs discriminative. A generative model learns the joint distribution $P(x, y)$, which means it models how the data itself is generated and can produce new samples. A discriminative model learns $P(y \mid x)$ directly, focusing only on the boundary between classes. Naive Bayes is generative; logistic regression (file 02) is discriminative. Generative models can do more (sample new data, cope with missing features) and often reach their best accuracy with fewer examples, but their modelling assumptions can hurt. Discriminative models often give better classification accuracy when you have enough data.
-
Naive Bayes is one of the simplest and most effective classifiers. It applies Bayes' theorem (from chapter 05) directly:
$$P(C_k \mid x) = \frac{P(x \mid C_k) \, P(C_k)}{P(x)}$$
-
The "naive" part is a strong independence assumption: it treats every feature as independent given the class. If you are classifying emails as spam, Naive Bayes assumes the presence of the word "free" tells you nothing about the presence of the word "winner," once you know the email is spam. This is almost never true in reality, but the classifier works surprisingly well anyway.
-
Since $P(x)$ is the same for all classes, classification simplifies to picking the class that maximises the numerator:
$$\hat{y} = \arg\max_{k} \; P(C_k) \prod_{i=1}^{n} P(x_i \mid C_k)$$
-
The prior $P(C_k)$ is just the fraction of training examples in each class. The likelihoods $P(x_i \mid C_k)$ depend on what kind of features you have, which gives rise to three common variants.
-
Multinomial Naive Bayes is designed for count data, like word frequencies in documents. Each feature $x_i$ represents how many times word $i$ appears, and the likelihood follows a multinomial distribution. This is the standard choice for text classification, sentiment analysis, and spam filtering.
-
Gaussian Naive Bayes assumes each feature follows a normal distribution within each class. You estimate the mean $\mu_{ik}$ and variance $\sigma_{ik}^2$ of feature $i$ for class $k$ from the training data, then compute:
$$P(x_i \mid C_k) = \frac{1}{\sqrt{2\pi\sigma_{ik}^2}} \exp\!\left(-\frac{(x_i - \mu_{ik})^2}{2\sigma_{ik}^2}\right)$$
- This is the natural choice when your features are continuous measurements, like height, weight, or sensor readings.
-
Bernoulli Naive Bayes models binary features: each feature is either present (1) or absent (0). Instead of counting how many times a word appears, you only track whether it appears at all. This works well for short texts or binary feature vectors.
-
A practical problem arises when a feature value never appears with a certain class in training data. The likelihood becomes zero, and because everything is multiplied together, the entire posterior collapses to zero. Laplace smoothing fixes this by adding a small count (usually 1) to every feature-class combination:
$$P(x_i \mid C_k) = \frac{\text{count}(x_i, C_k) + \alpha}{\text{count}(C_k) + \alpha \cdot V}$$
-
Here $\alpha$ is the smoothing parameter (typically 1) and $V$ is the number of possible values for that feature. This ensures no probability is ever exactly zero.
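A minimal plain-Python sketch of multinomial Naive Bayes with Laplace smoothing, trained on a made-up six-document corpus. In the multinomial variant, $\text{count}(C_k)$ in the smoothing formula is the total number of word occurrences in class $k$, and $V$ is the vocabulary size. Scores are summed in log space so that long documents do not underflow to zero:

```python
import math
from collections import Counter

def train_multinomial_nb(docs, labels, alpha=1.0):
    """Log priors and Laplace-smoothed log likelihoods from whitespace-tokenised documents."""
    vocab = sorted({w for d in docs for w in d.split()})
    classes = sorted(set(labels))
    log_prior, log_lik = {}, {}
    for c in classes:
        class_docs = [d for d, y in zip(docs, labels) if y == c]
        log_prior[c] = math.log(len(class_docs) / len(docs))
        counts = Counter(w for d in class_docs for w in d.split())
        total = sum(counts.values())  # count(C_k): all word occurrences in class c
        log_lik[c] = {w: math.log((counts[w] + alpha) / (total + alpha * len(vocab)))
                      for w in vocab}
    return classes, log_prior, log_lik

def predict_nb(model, doc):
    """Pick the class maximising log P(C_k) + sum_i log P(x_i | C_k)."""
    classes, log_prior, log_lik = model
    scores = {}
    for c in classes:
        # Words never seen in training are skipped (one common convention)
        scores[c] = log_prior[c] + sum(log_lik[c][w] for w in doc.split() if w in log_lik[c])
    return max(scores, key=scores.get)

docs = ["win free money", "free prize win", "claim free prize",
        "meeting at noon", "project meeting notes", "lunch at noon"]
labels = ["spam", "spam", "spam", "ham", "ham", "ham"]
model = train_multinomial_nb(docs, labels)
for text in ["free money", "noon meeting", "free lottery"]:
    print(f"{text!r:>16} -> {predict_nb(model, text)}")
```

Here $P(\text{free} \mid \text{spam}) = (3 + 1)/(9 + 11) = 0.2$. The word "lottery" never appears in training, so this sketch simply skips it. Without smoothing, any training word that never occurred in one class would force that class's score to zero.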
-
Decision trees take a completely different approach. Instead of computing probabilities, they partition the feature space through a sequence of yes/no questions. Think of the game Twenty Questions: at each step, you ask the question that narrows down the possibilities the most.
-
A tree starts at the root with all training examples. At each internal node, it picks a feature and a threshold to split on (e.g., "is age < 30?"). Examples flow left or right based on the answer. This continues recursively until the leaves, which hold predictions: the majority class for classification, or the mean value for regression.
-
The critical question is: which feature should you split on? You want splits that produce the "purest" child nodes, where most examples belong to the same class. Two common measures of impurity are Gini impurity and entropy.
-
Gini impurity measures the probability that a randomly chosen sample would be misclassified if labelled according to the distribution in that node:
$$\text{Gini}(S) = 1 - \sum_{k=1}^{K} p_k^2$$
-
If a node is perfectly pure (all one class), Gini is 0. For two classes, Gini reaches its maximum of 0.5 when the classes are balanced 50/50. With $K$ balanced classes the maximum is $1 - 1/K$.
-
Entropy (from chapter 05's information theory section) measures the average surprise:
$$H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k$$
-
A pure node has entropy 0. A perfectly balanced binary node has entropy 1 bit. In practice, Gini and entropy give very similar trees; Gini is slightly faster to compute since it avoids the logarithm.
-
Information gain is the reduction in impurity achieved by a split. For a split that divides set $S$ into subsets $S_L$ and $S_R$:
$$\text{IG}(S, \text{split}) = H(S) - \frac{|S_L|}{|S|} H(S_L) - \frac{|S_R|}{|S|} H(S_R)$$
-
The algorithm greedily picks the split with the highest information gain at each node. This is a locally optimal strategy, not globally optimal, but it works well in practice.
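The split search for a single numeric feature fits in a few lines. In this plain-Python sketch the data (ages and loan decisions) is hypothetical. The code tries every midpoint between sorted feature values as a threshold and keeps the one with the highest gain, under either entropy or Gini:

```python
import math
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def entropy(labels):
    n = len(labels)
    return -sum(c / n * math.log2(c / n) for c in Counter(labels).values())

def best_split(xs, ys, impurity=entropy):
    """Try every midpoint between sorted feature values; return (gain, threshold)."""
    parent = impurity(ys)
    vals = sorted(set(xs))
    best = (-1.0, None)
    for a, b in zip(vals, vals[1:]):
        thr = (a + b) / 2
        left = [y for x, y in zip(xs, ys) if x < thr]
        right = [y for x, y in zip(xs, ys) if x >= thr]
        gain = (parent - len(left) / len(ys) * impurity(left)
                - len(right) / len(ys) * impurity(right))
        if gain > best[0]:
            best = (gain, thr)
    return best

# Hypothetical toy data: age and whether a loan was approved (1) or not (0)
ages = [22, 25, 28, 35, 40, 45, 50, 60]
approved = [1, 1, 1, 0, 0, 1, 0, 0]
gain, thr = best_split(ages, approved)
g_gain, g_thr = best_split(ages, approved, impurity=gini)
print(f"entropy: best split age < {thr}, information gain = {gain:.3f} bits")
print(f"gini:    best split age < {g_thr}, impurity reduction = {g_gain:.3f}")
```

Both criteria pick the same threshold here (age < 31.5), which isolates a pure left node. This matches the remark that Gini and entropy usually give very similar trees.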
-
Regression trees work the same way, but leaves predict a continuous value (the mean of the examples that reach that leaf) and the splitting criterion uses variance reduction instead of Gini or entropy.
-
Left unchecked, a decision tree will keep splitting until every leaf is pure, essentially memorising the training data. This is severe overfitting. Pruning combats this. Pre-pruning sets limits before growing the tree: maximum depth, minimum samples per leaf, or minimum information gain to make a split. Post-pruning grows the full tree first, then removes branches that do not improve performance on a validation set.
-
A single decision tree is easy to interpret but tends to be unstable: small changes in the data can produce a very different tree. Ensemble methods combine many models to get better predictions than any single model could achieve.
-
The core idea is the "wisdom of crowds." If you ask 100 mediocre classifiers and take a majority vote, the ensemble can be excellent, as long as the individual classifiers make somewhat independent errors.
-
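Bagging (bootstrap aggregating) trains multiple models on different random subsets of the data, sampled with replacement (bootstrap samples). Each bootstrap sample has as many draws as the original dataset. Because some examples are drawn more than once, it contains only about 63% of the distinct original examples: the chance that a given example is never drawn in $n$ draws is $(1 - 1/n)^n \approx 1/e \approx 0.37$. At prediction time, you average the outputs (regression) or take a majority vote (classification). Because each model sees different data, they make different mistakes, and averaging cancels out much of the variance.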
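The 63% figure is easy to check by simulation. This plain-Python sketch draws bootstrap samples of size $n$ and measures what share of the original examples each sample contains. The sample size, number of repeats and seed are arbitrary choices:

```python
import math
import random

rng = random.Random(0)
n = 10_000
fractions = []
for _ in range(20):
    sample = [rng.randrange(n) for _ in range(n)]  # n draws with replacement
    fractions.append(len(set(sample)) / n)         # share of distinct originals present
avg_fraction = sum(fractions) / len(fractions)
print(f"average fraction of distinct examples: {avg_fraction:.4f}")
print(f"1 - 1/e = {1 - 1 / math.e:.4f}")
```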
-
Random Forests are bagging applied to decision trees with one extra twist: at each split, the tree only considers a random subset of features (typically $\sqrt{d}$ features out of $d$ total). This further decorrelates the trees, making the ensemble even more powerful. Random forests are one of the most reliable off-the-shelf classifiers in all of machine learning.
-
Boosting takes the opposite philosophy. Instead of training models independently, it trains them sequentially, with each new model focusing on the examples that previous models got wrong.
-
AdaBoost (Adaptive Boosting) maintains a weight for each training example. Initially all weights are equal. After training a weak learner (often a very shallow decision tree, called a "stump"), examples that were misclassified get higher weights, so the next learner pays more attention to them. The final prediction is a weighted vote of all learners, where better-performing learners get more say:
$$H(x) = \text{sign}\!\left(\sum_{t=1}^{T} \alpha_t \, h_t(x)\right)$$
- The weight $\alpha_t$ for learner $t$ depends on its error rate $\epsilon_t$:
$$\alpha_t = \frac{1}{2} \ln\!\left(\frac{1 - \epsilon_t}{\epsilon_t}\right)$$
-
A learner with low error gets a large positive weight; one performing at chance ($\epsilon = 0.5$) gets zero weight.
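A compact plain-Python sketch of AdaBoost with decision stumps on a 1-D toy dataset. The data is hypothetical: positives at both ends and negatives in the middle, so no single stump can classify it. Labels are $\pm 1$. After each round the weights are updated as $w_i \leftarrow w_i \exp(-\alpha_t y_i h_t(x_i))$ and renormalised, which is the standard AdaBoost rule and raises the weight of misclassified examples:

```python
import math

def fit_stump(xs, ys, w):
    """Best 1-D stump under weights w: predicts `pol` when x > thr, else -pol."""
    vals = sorted(set(xs))
    thresholds = [vals[0] - 1.0] + [(a + b) / 2 for a, b in zip(vals, vals[1:])]
    best = None
    for thr in thresholds:
        for pol in (1, -1):
            err = sum(wi for xi, yi, wi in zip(xs, ys, w)
                      if (pol if xi > thr else -pol) != yi)
            if best is None or err < best[0]:
                best = (err, thr, pol)
    return best

def stump_predict(x, thr, pol):
    return pol if x > thr else -pol

def adaboost(xs, ys, n_rounds):
    n = len(xs)
    w = [1.0 / n] * n
    ensemble = []
    for _ in range(n_rounds):
        err, thr, pol = fit_stump(xs, ys, w)
        err = min(max(err, 1e-10), 1 - 1e-10)  # guard against division by zero
        alpha = 0.5 * math.log((1 - err) / err)
        ensemble.append((alpha, thr, pol))
        # Misclassified points (y * h = -1) gain weight, correct ones lose weight
        w = [wi * math.exp(-alpha * yi * stump_predict(xi, thr, pol))
             for xi, yi, wi in zip(xs, ys, w)]
        z = sum(w)
        w = [wi / z for wi in w]
    return ensemble

def adaboost_predict(ensemble, x):
    score = sum(alpha * stump_predict(x, thr, pol) for alpha, thr, pol in ensemble)
    return 1 if score >= 0 else -1

xs = list(range(1, 11))
ys = [1, 1, 1, -1, -1, -1, 1, 1, 1, 1]
ensemble = adaboost(xs, ys, n_rounds=3)
for alpha, thr, pol in ensemble:
    print(f"stump: x > {thr} -> {pol:+d}, alpha = {alpha:.3f}")
print("predictions:", [adaboost_predict(ensemble, x) for x in xs])
```

With these 10 points, the first stump is the constant "always +1" (error 0.3). Three rounds are enough for the weighted vote to classify every training point correctly, even though each of the three stumps alone gets at least 3 points wrong.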
-
Gradient Boosting generalises this idea. Instead of reweighting examples, each new model is trained to predict the residual errors of the combined ensemble so far, meaning the negative gradient of the loss function with respect to the current predictions. For squared error loss $\frac{1}{2}(y - F(x))^2$, the negative gradient is exactly $y - F(x)$: the targets minus the current predictions. Gradient boosting with decision trees (GBDT) is behind many winning solutions in structured data competitions (XGBoost, LightGBM, CatBoost are popular implementations).
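A plain-Python sketch of gradient boosting for squared error on a 1-D toy regression problem. Each round fits a least-squares stump to the current residuals and adds it scaled by a shrinkage factor. That factor, the `learning_rate`, is an extra ingredient common in practice but not described above. The data, number of rounds and rate are arbitrary choices:

```python
def fit_regression_stump(xs, residuals):
    """Least-squares stump: one threshold, and the mean residual in each leaf."""
    vals = sorted(set(xs))
    best = None
    for a, b in zip(vals, vals[1:]):
        thr = (a + b) / 2
        left = [r for x, r in zip(xs, residuals) if x <= thr]
        right = [r for x, r in zip(xs, residuals) if x > thr]
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        sse = sum((r - lm) ** 2 for r in left) + sum((r - rm) ** 2 for r in right)
        if best is None or sse < best[0]:
            best = (sse, thr, lm, rm)
    _, thr, lm, rm = best
    return lambda x: lm if x <= thr else rm

def gradient_boost(xs, ys, n_rounds=20, learning_rate=0.5):
    """Squared-error gradient boosting: each stump fits the residuals y - F(x)."""
    base = sum(ys) / len(ys)  # F_0: the best constant prediction
    preds = [base] * len(ys)
    stumps, history = [], []
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, preds)]  # negative gradient of (y - F)^2 / 2
        history.append(sum(r * r for r in residuals) / len(ys))
        h = fit_regression_stump(xs, residuals)
        stumps.append(h)
        preds = [p + learning_rate * h(x) for x, p in zip(xs, preds)]
    history.append(sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(ys))
    return base, stumps, history

def gb_predict(base, stumps, x, learning_rate=0.5):
    return base + learning_rate * sum(h(x) for h in stumps)

xs = list(range(10))
ys = [float(x) for x in xs]  # a simple ramp y = x
base, stumps, history = gradient_boost(xs, ys)
print("training MSE every 5 rounds:", [round(m, 3) for m in history[::5]])
print("prediction at x=7:", round(gb_predict(base, stumps, 7), 3))
```

With least-squares stumps and a `learning_rate` between 0 and 2, the training error can never go up from one round to the next. The returned `history` lets you check this.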
-
The key contrast: bagging reduces variance (averaging out noise) while boosting reduces bias (correcting systematic errors). Bagging works best when individual models overfit; boosting works best when they underfit.
-
Shifting to unsupervised learning, K-Means clustering is the simplest and most widely used clustering algorithm. Given $n$ data points and a target number of clusters $K$, it assigns each point to one of $K$ groups by minimising the total squared distance from each point to its cluster centre.
-
The algorithm alternates two steps. First, assign each point to the nearest centroid. Second, update each centroid to the mean of all points assigned to it. Repeat until assignments stop changing. This is guaranteed to converge, because neither step can increase the within-cluster sum of squares and there are only finitely many ways to assign points to clusters. It converges to a local minimum, though, not necessarily the global one.
- Formally, K-Means minimises the within-cluster sum of squares, called inertia:
$$J = \sum_{k=1}^{K} \sum_{x \in C_k} \|x - \mu_k\|^2$$
-
where $\mu_k$ is the centroid of cluster $C_k$.
-
K-Means is sensitive to initialisation. Bad starting centroids can lead to poor local minima. The K-Means++ initialisation strategy picks the first centroid randomly, then chooses each subsequent centroid with probability proportional to its squared distance from the nearest existing centroid. This spreads out the initial centres and almost always gives better results.
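A plain-Python sketch of K-Means with K-Means++ initialisation, using Lloyd's alternating steps from above. As an extra safeguard against bad local minima, it keeps the best of several random restarts, as many libraries do. The three-blob dataset is made up for illustration:

```python
import random

def sq_dist(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def kmeans_pp_init(points, k, rng):
    """K-Means++: first centre uniform, later ones with probability proportional to D(x)^2."""
    centres = [rng.choice(points)]
    while len(centres) < k:
        d2 = [min(sq_dist(p, c) for c in centres) for p in points]
        centres.append(rng.choices(points, weights=d2, k=1)[0])
    return centres

def kmeans(points, k, rng, max_iter=100):
    centres = kmeans_pp_init(points, k, rng)
    history = []
    for _ in range(max_iter):
        # Assignment step: nearest centre for every point
        labels = [min(range(k), key=lambda j: sq_dist(p, centres[j])) for p in points]
        history.append(sum(sq_dist(p, centres[lab]) for p, lab in zip(points, labels)))
        # Update step: move each centre to the mean of its points (keep it if the cluster is empty)
        new_centres = []
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            new_centres.append(tuple(sum(c) / len(members) for c in zip(*members))
                               if members else centres[j])
        if new_centres == centres:
            break
        centres = new_centres
    return labels, centres, history

def best_of_restarts(points, k, n_init=10, seed=0):
    rng = random.Random(seed)
    runs = [kmeans(points, k, rng) for _ in range(n_init)]
    return min(runs, key=lambda run: run[2][-1])  # lowest final inertia

blob = [(0, 0), (1, 0), (0, 1), (1, 1)]
points = blob + [(x + 10, y + 10) for x, y in blob] + [(x, y + 10) for x, y in blob]
labels, centres, history = best_of_restarts(points, k=3)
print("labels:", labels)
print("inertia per iteration:", history)
```

The returned `history` holds the inertia after each assignment step, and it never increases. The best run recovers the three blobs with inertia 6, since each point sits at squared distance 0.5 from its blob's centre.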
-
How do you choose $K$? Two common tools. The elbow method plots inertia vs $K$ and looks for the "elbow" where adding more clusters stops helping much. The silhouette score measures how similar a point is to its own cluster compared to the nearest other cluster, ranging from -1 (wrong cluster) to +1 (well clustered). Average silhouette score across all points gives an overall measure of cluster quality.
-
K-Means has limitations: it assumes spherical clusters of roughly equal size, and it makes "hard" assignments (each point belongs to exactly one cluster). Gaussian Mixture Models (GMMs) relax both restrictions.
-
A GMM models the data as a mixture of $K$ Gaussian distributions, each with its own mean $\mu_k$, covariance $\Sigma_k$, and mixing weight $\pi_k$ (where the weights sum to 1):
$$P(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$
-
Instead of hard assignments, each point gets a soft assignment: the probability (called the "responsibility") that it belongs to each cluster. A point near the boundary between two Gaussians might be 60% cluster A and 40% cluster B.
-
GMMs are fitted using the Expectation-Maximisation (EM) algorithm, which alternates two steps, much like K-Means. The E-step computes responsibilities: for each point, what is the probability it came from each Gaussian? The M-step updates parameters: given the responsibilities, what are the best means, covariances, and mixing weights? EM never decreases the data likelihood from one iteration to the next, and it typically converges to a local maximum.
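A plain-Python sketch of EM for a two-component GMM in one dimension. The initialisation (means at the data's minimum and maximum, unit variances, equal weights) and the small variance floor are choices made for this sketch. The data is a made-up pair of clusters centred at -2 and 3:

```python
import math

def normal_pdf(x, mu, var):
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def gmm_em_1d(xs, k=2, n_iter=50):
    """EM for a 1-D Gaussian mixture with k >= 2 components."""
    lo, hi = min(xs), max(xs)
    mus = [lo + (hi - lo) * j / (k - 1) for j in range(k)]
    vars_ = [1.0] * k
    pis = [1.0 / k] * k
    log_liks = []
    for _ in range(n_iter):
        # E-step: responsibilities r[j] = P(component j | x), plus the data log-likelihood
        resp, ll = [], 0.0
        for x in xs:
            joint = [pis[j] * normal_pdf(x, mus[j], vars_[j]) for j in range(k)]
            total = sum(joint)
            ll += math.log(total)
            resp.append([p / total for p in joint])
        log_liks.append(ll)
        # M-step: responsibility-weighted mixing weights, means and variances
        for j in range(k):
            nj = sum(r[j] for r in resp)
            pis[j] = nj / len(xs)
            mus[j] = sum(r[j] * x for r, x in zip(resp, xs)) / nj
            vars_[j] = max(sum(r[j] * (x - mus[j]) ** 2 for r, x in zip(resp, xs)) / nj, 1e-6)
    return pis, mus, vars_, log_liks

data = [-2.5, -2.0, -1.5, -2.2, -1.8, 2.5, 3.0, 3.5, 2.8, 3.2]
pis, mus, vars_, log_liks = gmm_em_1d(data)
print("weights:", [round(p, 3) for p in pis])
print("means:", [round(m, 3) for m in mus])
print("variances:", [round(v, 3) for v in vars_])
```

The recorded log-likelihoods should never decrease, and the fitted means should end up close to -2 and 3.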
-
K-Means can be seen as a limiting case of EM with GMMs. Use spherical Gaussians that share the covariance $\sigma^2 I$ and let $\sigma^2 \to 0$: the responsibilities become hard (0/1) assignments to the nearest mean.
-
Support Vector Machines (SVMs) approach classification from a geometric perspective. Given two linearly separable classes, there are infinitely many hyperplanes that separate them. SVM finds the one with the maximum margin, the largest possible gap between the hyperplane and the nearest data points from each class.
-
The nearest points, the ones sitting right on the edge of the margin, are called support vectors. They are the only points that matter for defining the boundary; you could remove all other training points and get the same hyperplane.
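-
For a linear classifier $f(x) = w \cdot x + b$, the margin width is $2 / \|w\|$, so maximising the margin amounts to minimising $\|w\|$ while requiring every point to be classified correctly with a functional margin of at least 1:
$$\min_{w, b} \; \frac{1}{2}\|w\|^2 \quad \text{subject to} \quad y_i(w \cdot x_i + b) \geq 1 \; \text{for all } i$$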
-
This is a convex quadratic program, so it has a unique global solution (no local minima to worry about).
-
Real data is rarely perfectly separable. Soft-margin SVM allows some points to violate the margin by introducing slack variables $\xi_i \geq 0$:
$$\min_{w, b, \xi} \; \frac{1}{2}\|w\|^2 + C \sum_{i=1}^{n} \xi_i \quad \text{subject to} \quad y_i(w \cdot x_i + b) \geq 1 - \xi_i, \;\; \xi_i \geq 0$$
-
The hyperparameter $C$ controls the tradeoff: large $C$ penalises misclassifications heavily (tighter fit, risk of overfitting), small $C$ allows more violations (wider margin, more regularised).
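-
Because the optimal slack for each point is $\xi_i = \max(0,\, 1 - y_i(w \cdot x_i + b))$, the soft-margin problem can be rewritten without constraints as $\frac{1}{2}\|w\|^2 + C \sum_i \max(0,\, 1 - y_i(w \cdot x_i + b))$. The sketch below minimises that form with plain subgradient descent via `jax.grad`; this is a teaching substitute for the quadratic-programming solvers that real SVM libraries use, and the step size, step count and toy blobs are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def svm_objective(params, X, y, C):
    """Soft-margin primal with the slacks eliminated: 0.5||w||^2 + C * sum of hinge terms."""
    w, b = params
    margins = y * (X @ w + b)
    slack = jnp.maximum(0.0, 1.0 - margins)   # smallest xi_i that satisfies each constraint
    return 0.5 * jnp.sum(w ** 2) + C * jnp.sum(slack)

def train_linear_svm(X, y, C=1.0, lr=1e-3, steps=2000):
    params = (jnp.zeros(X.shape[1]), jnp.array(0.0))
    grad_fn = jax.jit(jax.grad(svm_objective))   # returns a subgradient at the hinge's kink
    for _ in range(steps):
        gw, gb = grad_fn(params, X, y, C)
        params = (params[0] - lr * gw, params[1] - lr * gb)
    return params

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jnp.concatenate([jax.random.normal(k1, (50, 2)) + 2.0,
                     jax.random.normal(k2, (50, 2)) - 2.0])
y = jnp.concatenate([jnp.ones(50), -jnp.ones(50)])   # labels must be +1 / -1
w, b = train_linear_svm(X, y, C=1.0)
scores = X @ w + b
print("accuracy:", jnp.mean(jnp.sign(scores) == y))
print("points on or inside the margin:", int(jnp.sum(y * scores <= 1.0)))
```

Points with $y_i f(x_i) \le 1$ are the ones on or inside the margin. With a fixed step size the iterates jitter around the optimum, so treat that count as approximate.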
-
The most powerful feature of SVMs is the kernel trick. Many datasets that are not linearly separable in the original feature space become separable when mapped to a higher-dimensional space. The kernel trick lets you compute dot products in that high-dimensional space without ever explicitly computing the transformation.
-
A kernel function $K(x_i, x_j) = \phi(x_i) \cdot \phi(x_j)$ replaces every dot product in the SVM optimisation. The most popular kernel is the Radial Basis Function (RBF) kernel:
$$K(x_i, x_j) = \exp\!\left(-\gamma \|x_i - x_j\|^2\right)$$
-
The RBF kernel implicitly maps data to an infinite-dimensional space. The parameter $\gamma$ controls how far the influence of a single training point reaches: large $\gamma$ means each point only influences its immediate neighbourhood (risk of overfitting), small $\gamma$ gives smoother boundaries.
-
Other common kernels include the polynomial kernel $K(x_i, x_j) = (x_i \cdot x_j + c)^d$ and the linear kernel $K(x_i, x_j) = x_i \cdot x_j$ (which is just the standard SVM without any transformation).
-
In practice, SVMs with RBF kernels were among the most popular classifiers before deep learning took over. They still work well on small-to-medium datasets, especially when the number of features is large relative to the number of samples.
-
The SVM's connection to chapter 02 (matrices) runs deep. The optimisation is typically solved in its dual form, where the solution depends only on dot products between training examples, which is exactly what makes the kernel trick possible. The entire algorithm operates in the language of inner products and linear algebra.
-
To summarise the classical ML toolkit:
| Algorithm | Type | Key Strength | Key Weakness |
|---|---|---|---|
| Naive Bayes | Supervised (generative) | Fast, works with little data | Independence assumption |
| Decision Tree | Supervised | Interpretable | Overfits easily |
| Random Forest | Supervised (ensemble) | Robust, few hyperparameters | Less interpretable |
| Gradient Boosting | Supervised (ensemble) | State-of-the-art on tabular data | Slower, more tuning |
| K-Means | Unsupervised (clustering) | Simple, scalable | Assumes spherical clusters |
| GMM | Unsupervised (clustering) | Soft assignments, flexible shapes | Sensitive to initialisation |
| SVM | Supervised | Effective in high dimensions | Slow on large datasets |
Coding Tasks (use Colab or a notebook)
- Implement Gaussian Naive Bayes from scratch. Train on synthetic 2D data with two classes and visualise the decision boundary. Compare with scikit-learn's implementation.
```python
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.naive_bayes import GaussianNB

# Generate synthetic data
X, y = make_classification(n_samples=300, n_features=2, n_redundant=0,
                           n_informative=2, n_clusters_per_class=1, random_state=42)
X, y = jnp.array(X), jnp.array(y)

# Fit Gaussian Naive Bayes from scratch: one mean, variance and prior per class
classes = jnp.unique(y)
means = jnp.stack([jnp.mean(X[y == c], axis=0) for c in classes])     # (C, d)
variances = jnp.stack([jnp.var(X[y == c], axis=0) for c in classes])  # (C, d)
priors = jnp.array([jnp.mean(y == c) for c in classes])               # (C,)

def predict(X):
    """Vectorised prediction: argmax over classes of log prior + log likelihood."""
    # Naive assumption: features are independent Gaussians, so log-likelihoods add up
    diff = X[:, None, :] - means[None, :, :]                          # (n, C, d)
    log_lik = -0.5 * jnp.sum(jnp.log(2 * jnp.pi * variances) + diff ** 2 / variances, axis=-1)
    log_post = jnp.log(priors) + log_lik                              # (n, C)
    return classes[jnp.argmax(log_post, axis=1)]

# Decision boundary visualisation
xx, yy = jnp.meshgrid(jnp.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 200),
                      jnp.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 200))
grid = jnp.column_stack([xx.ravel(), yy.ravel()])
zz = predict(grid).reshape(xx.shape)
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, zz, alpha=0.3, cmap='coolwarm')
plt.scatter(X[y == 0][:, 0], X[y == 0][:, 1], c='#3498db', label='Class 0', edgecolors='k', s=20)
plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], c='#e74c3c', label='Class 1', edgecolors='k', s=20)
plt.title("Gaussian Naive Bayes Decision Boundary")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

accuracy = jnp.mean(predict(X) == y)
print(f"Training accuracy (from scratch): {float(accuracy):.2%}")

# Compare with scikit-learn's GaussianNB (which also adds a small variance-smoothing term)
sk_pred = GaussianNB().fit(np.asarray(X), np.asarray(y)).predict(np.asarray(X))
print(f"scikit-learn accuracy: {np.mean(sk_pred == np.asarray(y)):.2%}")
print(f"Agreement with scikit-learn: {np.mean(sk_pred == np.asarray(predict(X))):.2%}")
```
- Build a decision tree that splits using Gini impurity. Implement the splitting logic for a single node and show how information gain selects the best feature and threshold.
```python
import jax.numpy as jnp
from sklearn.datasets import make_classification

def gini_impurity(y):
    """Gini impurity of a label array: 1 - sum_k p_k^2."""
    _, counts = jnp.unique(y, return_counts=True)
    probs = counts / len(y)
    return 1.0 - jnp.sum(probs ** 2)

def information_gain(y, left_mask):
    """Impurity decrease (Gini gain) from splitting y into left/right by a boolean mask."""
    left_y, right_y = y[left_mask], y[~left_mask]
    n = len(y)
    if len(left_y) == 0 or len(right_y) == 0:
        return 0.0  # a split that sends every sample one way gains nothing
    child_gini = (len(left_y) / n) * gini_impurity(left_y) + \
                 (len(right_y) / n) * gini_impurity(right_y)
    return float(gini_impurity(y) - child_gini)

def best_split(X, y):
    """Try every (feature, threshold) pair and keep the one with the largest gain."""
    best_ig, best_feat, best_thresh = -1.0, None, None
    for feat in range(X.shape[1]):
        for thresh in jnp.unique(X[:, feat]):
            mask = X[:, feat] <= thresh
            ig = information_gain(y, mask)
            if ig > best_ig:
                best_ig, best_feat, best_thresh = ig, feat, float(thresh)
    return best_feat, best_thresh, best_ig

# Example: synthetic data
X, y = make_classification(n_samples=100, n_features=4, n_redundant=0, random_state=0)
X, y = jnp.array(X), jnp.array(y)
feat, thresh, ig = best_split(X, y)
print(f"Best split: feature {feat}, threshold {thresh:.3f}, info gain {ig:.4f}")
print(f"Parent Gini: {float(gini_impurity(y)):.4f}")
mask = X[:, feat] <= thresh
print(f"Left Gini: {float(gini_impurity(y[mask])):.4f} ({int(jnp.sum(mask))} samples)")
print(f"Right Gini: {float(gini_impurity(y[~mask])):.4f} ({int(jnp.sum(~mask))} samples)")
```
- Implement K-Means from scratch with K-Means++ initialisation. Cluster a synthetic dataset and visualise the clusters at each iteration.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# Generate synthetic clusters
X, y_true = make_blobs(n_samples=300, centers=4, cluster_std=0.8, random_state=42)
X = jnp.array(X)

def kmeans_plus_plus_init(X, K, key):
    """K-Means++: first centroid uniform, later ones with probability proportional to D(x)^2."""
    n = X.shape[0]
    key, subkey = jax.random.split(key)   # never reuse a key that has already been consumed
    centroids = [X[jax.random.randint(subkey, (), 0, n)]]
    for _ in range(1, K):
        # Squared distance from each point to its nearest already-chosen centroid
        dists = jnp.min(jnp.stack([jnp.sum((X - c) ** 2, axis=1) for c in centroids]), axis=0)
        probs = dists / jnp.sum(dists)
        key, subkey = jax.random.split(key)
        idx = jax.random.choice(subkey, n, p=probs)
        centroids.append(X[idx])
    return jnp.stack(centroids)

def kmeans(X, K, max_iters=20, key=jax.random.PRNGKey(0)):
    centroids = kmeans_plus_plus_init(X, K, key)
    history = [centroids]
    for _ in range(max_iters):
        # Assign step: each point goes to its nearest centroid
        dists = jnp.stack([jnp.sum((X - c) ** 2, axis=1) for c in centroids])
        labels = jnp.argmin(dists, axis=0)
        # Update step: each centroid moves to the mean of its points
        # (an empty cluster keeps its previous centroid instead of becoming NaN)
        new_centroids = jnp.stack([
            jnp.mean(X[labels == k], axis=0) if jnp.any(labels == k) else centroids[k]
            for k in range(K)
        ])
        history.append(new_centroids)
        if jnp.allclose(centroids, new_centroids):
            break
        centroids = new_centroids
    return labels, centroids, history

K = 4
labels, centroids, history = kmeans(X, K)

# Plot the final clusters plus the path each centroid took across iterations
colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6']
trajectory = jnp.stack(history)  # (iterations + 1, K, 2); row 0 is the K-Means++ init
plt.figure(figsize=(8, 6))
for k in range(K):
    mask = labels == k
    plt.scatter(X[mask][:, 0], X[mask][:, 1], c=colors[k], s=20, alpha=0.6)
    plt.plot(trajectory[:, k, 0], trajectory[:, k, 1], '--', color='k', linewidth=1)
    plt.scatter(centroids[k, 0], centroids[k, 1], c=colors[k], marker='X',
                s=200, edgecolors='k', linewidths=1.5)
plt.title(f"K-Means Clustering (K={K}, {len(history)-1} iterations)")
plt.grid(alpha=0.3)
plt.show()

# Inertia: total squared distance from each point to its assigned centroid
inertia = sum(float(jnp.sum((X[labels == k] - centroids[k]) ** 2)) for k in range(K))
print(f"Final inertia: {inertia:.2f}")
```
- Demonstrate the kernel trick. Show that a polynomial kernel computes dot products in a higher-dimensional space by comparing its kernel matrix with the one built from an explicit feature mapping, then compute an RBF kernel matrix, whose feature space is infinite-dimensional.
```python
import jax.numpy as jnp

# Simple 2D data
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Polynomial kernel: K(x, y) = (x·y + c)^degree
def poly_kernel(X, degree=2, c=1.0):
    return (X @ X.T + c) ** degree

# Explicit degree-2 feature map for 2D inputs (matches c = 1):
# phi(x) = (1, sqrt(2)*x1, sqrt(2)*x2, x1^2, x2^2, sqrt(2)*x1*x2)
def poly_features(X):
    x1, x2 = X[:, 0], X[:, 1]
    return jnp.column_stack([
        jnp.ones(len(X)),
        jnp.sqrt(2) * x1,
        jnp.sqrt(2) * x2,
        x1 ** 2,
        x2 ** 2,
        jnp.sqrt(2) * x1 * x2,
    ])

K_trick = poly_kernel(X)      # computed entirely in the original 2D space
phi = poly_features(X)        # explicit 6D features
K_explicit = phi @ phi.T      # dot products in 6D
print("Kernel trick (polynomial degree 2):")
print(K_trick)
print("\nExplicit feature map dot products:")
print(K_explicit)
print(f"\nMatrices match: {jnp.allclose(K_trick, K_explicit)}")

# RBF kernel: its feature space is infinite-dimensional, so no finite explicit map exists
def rbf_kernel(X, gamma=0.5):
    # ||xi - xj||^2 = ||xi||^2 + ||xj||^2 - 2 xi·xj; clip tiny negatives caused by round-off
    sq_dists = jnp.sum(X ** 2, axis=1, keepdims=True) + jnp.sum(X ** 2, axis=1) - 2 * X @ X.T
    return jnp.exp(-gamma * jnp.maximum(sq_dists, 0.0))

K_rbf = rbf_kernel(X)
print("\nRBF kernel matrix:")
print(K_rbf)
print("Diagonal is 1 (each point is at distance 0 from itself)")
print("Off-diagonal entries decay with distance")
```
Gradient Machine Learning
-
The classical methods in file 01 use clever heuristics or closed-form solutions. This file covers algorithms that learn by following gradients, taking small steps downhill on a loss surface until they find good parameters. Gradient-based learning is the engine behind everything from linear regression to the largest neural networks.
-
Linear regression is the simplest gradient-based model, and it also has a closed-form solution, which makes it a perfect starting point. The model is a line (or hyperplane in higher dimensions):
$$\hat{y} = w \cdot x + b = \sum_{i=1}^{d} w_i x_i + b$$
-
In matrix notation (from chapter 02), if we stack all training inputs as rows of a matrix $X$ and absorb the bias into $w$ by appending a column of ones, this becomes $\hat{y} = Xw$.
-
The goal is to minimise the mean squared error (MSE), the average squared difference between predictions and actual values:
$$\mathcal{L}(w) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 = \frac{1}{n} \|y - Xw\|^2$$
-
Why squared error? It has a probabilistic justification: if you assume the targets are generated as $y = Xw + \epsilon$ where $\epsilon \sim \mathcal{N}(0, \sigma^2)$, then maximising the Gaussian likelihood of the data (chapter 05) is equivalent to minimising MSE. Squared error also penalises large mistakes more than small ones, which is often desirable.
-
Because MSE is a convex quadratic function of $w$, we can minimise it analytically. Its gradient is $\nabla_w \mathcal{L} = -\frac{2}{n} X^T (y - Xw)$; setting this to zero gives $X^T X w = X^T y$, and when $X^T X$ is invertible the unique solution is the normal equation:
$$w^{*} = (X^T X)^{-1} X^T y$$
-
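This directly uses the matrix inverse from chapter 02. The expression $X^T X$ is a $d \times d$ matrix (where $d$ is the number of features, counting the appended column of ones), and $X^T y$ is a $d$-dimensional vector. The normal equation gives the exact optimal weights in one shot. In practice you solve the linear system $X^T X w = X^T y$ (for example with `jnp.linalg.solve`) or use a QR-based least-squares routine rather than forming the inverse explicitly, because that is cheaper and numerically more stable.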
-
When does the normal equation fail? When $X^T X$ is singular (not invertible), which happens if features are linearly dependent or if you have more features than samples ($d > n$). In these cases you need regularisation (covered later) or gradient descent.
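-
A quick way to see this failure mode: duplicate a feature so that $X^T X$ loses rank, then compare two standard fixes. The matrix and targets below are made up for illustration; `jnp.linalg.lstsq` returns the minimum-norm least-squares solution (computed via the SVD), and adding $\lambda I$ before solving is the ridge fix.

```python
import jax.numpy as jnp

# Hypothetical design matrix: column 3 duplicates column 1 (perfectly collinear features)
X = jnp.array([[1.0, 2.0, 1.0],
               [2.0, 0.0, 2.0],
               [3.0, 1.0, 3.0],
               [4.0, 3.0, 4.0]])
y = jnp.array([3.0, 2.0, 5.0, 8.0])
print("rank of X:", int(jnp.linalg.matrix_rank(X)), "of", X.shape[1])

# Fix 1: minimum-norm least-squares solution from an SVD-based solver
w_lstsq, *_ = jnp.linalg.lstsq(X, y)
# Fix 2: ridge, since X^T X + lam * I is invertible for any lam > 0
lam = 1e-2
w_ridge = jnp.linalg.solve(X.T @ X + lam * jnp.eye(3), X.T @ y)
print("lstsq:", w_lstsq)
print("ridge:", w_ridge)
```

Both give essentially the same predictions. The minimum-norm solution splits the weight equally between the two identical columns, since every split fits the data equally well and the equal split has the smallest norm.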
-
Logistic regression adapts the linear model for binary classification. Instead of predicting a continuous value, we want a probability between 0 and 1. The sigmoid function squashes any real number into this range:
$$\sigma(z) = \frac{1}{1 + e^{-z}}$$
-
The model computes $z = w \cdot x + b$ (a linear score, just like linear regression) and then passes it through the sigmoid: $\hat{y} = \sigma(w \cdot x + b)$. The output $\hat{y}$ is interpreted as $P(y = 1 \mid x)$.
-
The sigmoid has nice properties: $\sigma(0) = 0.5$, $\sigma(z) \to 1$ as $z \to \infty$, $\sigma(z) \to 0$ as $z \to -\infty$, and its derivative has the elegant form $\sigma'(z) = \sigma(z)(1 - \sigma(z))$.
-
The loss function for logistic regression is binary cross-entropy (BCE), which comes directly from the Bernoulli likelihood (chapter 05):
$$\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$
-
When the true label is 1, only the first term is active and it penalises low predictions. When the true label is 0, only the second term is active and it penalises high predictions. The logarithm makes the penalty extremely steep for confident wrong predictions: predicting 0.01 when the true label is 1 costs much more than predicting 0.4.
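-
The asymmetry above, and the sigmoid's derivative identity, are easy to check numerically. This sketch uses `jax.grad` to confirm that $\sigma'(z) = \sigma(z)(1 - \sigma(z))$ and that the derivative of BCE with respect to the score $z$ simplifies to $\sigma(z) - y$; the values $z = 0.7$ and $y = 1$ are arbitrary.

```python
import jax
import jax.numpy as jnp

def bce(y, y_hat):
    """Binary cross-entropy for a single prediction."""
    return -(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))

# True label is 1: a confident wrong prediction costs far more than a hesitant one
print(bce(1.0, 0.01))   # -log(0.01) ≈ 4.61
print(bce(1.0, 0.4))    # -log(0.4)  ≈ 0.92

# Check sigma'(z) = sigma(z)(1 - sigma(z)) and dBCE/dz = sigma(z) - y with autodiff
z, y = 0.7, 1.0
s = jax.nn.sigmoid(z)
dsigma = jax.grad(jax.nn.sigmoid)(z)
dbce_dz = jax.grad(lambda z: bce(y, jax.nn.sigmoid(z)))(z)
print(dsigma, s * (1 - s))
print(dbce_dz, s - y)
```

That simplification is why the logistic-regression gradient with respect to $w$ is just $X^T(\hat{y} - y)/n$.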
-
Unlike MSE for linear regression, there is no closed-form solution for the BCE-minimising weights. We need an iterative approach: gradient descent.
-
The intuition behind gradient descent is simple: imagine you are standing on a hilly landscape (the loss surface) in fog. You cannot see the global minimum, but you can feel the slope under your feet. You take a step downhill, feel the slope again, and repeat. Eventually you reach a valley.
$$w \leftarrow w - \eta \frac{\partial \mathcal{L}}{\partial w}$$
-
The learning rate $\eta$ controls your step size. Too large and you overshoot valleys, bouncing around without converging. Too small and you inch along painfully slowly, possibly getting stuck in a local minimum.
-
The gradient $\frac{\partial \mathcal{L}}{\partial w}$ is a vector pointing in the direction of steepest ascent. We subtract it because we want to go downhill. For a loss built from nested functions, this gradient is computed with the chain rule from chapter 03.
-
Batch gradient descent computes the gradient using the entire training set at every step. This gives an exact gradient but is expensive when $n$ is large.
-
Stochastic gradient descent (SGD) uses a single random example per step. The gradient is noisy (it estimates the true gradient from one sample) but each step is extremely fast. The noise can actually help escape shallow local minima.
-
Mini-batch gradient descent splits the difference: use a batch of $B$ examples (typically 32, 64, or 256) per step. This balances computational efficiency (vectorised operations on the batch) with gradient quality. Almost all deep learning uses mini-batch SGD.
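-
The three variants differ only in how many examples feed each gradient estimate. A minimal sketch, assuming a linear model with MSE loss and made-up data: `batch_size=1` gives SGD and `batch_size=n` gives batch gradient descent. Reshuffling every epoch is a common convention rather than a requirement.

```python
import jax
import jax.numpy as jnp

def mse(w, Xb, yb):
    return jnp.mean((Xb @ w - yb) ** 2)

grad_mse = jax.jit(jax.grad(mse))

def minibatch_sgd(X, y, batch_size=32, lr=0.05, epochs=20, seed=0):
    key = jax.random.PRNGKey(seed)
    w = jnp.zeros(X.shape[1])
    n = X.shape[0]
    for _ in range(epochs):
        key, sub = jax.random.split(key)
        perm = jax.random.permutation(sub, n)        # new random order every epoch
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]     # indices of one mini-batch
            w = w - lr * grad_mse(w, X[idx], y[idx])  # noisy estimate of the full gradient
    return w

kx, kn = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(kx, (256, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * jax.random.normal(kn, (256,))
w_hat = minibatch_sgd(X, y)
print(w_hat)
```

If `n` is not a multiple of `batch_size`, the last batch of each epoch is smaller and `jax.jit` compiles a second version for that shape.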
-
Backpropagation is how we actually compute gradients in models with many parameters, like neural networks. It is the chain rule from chapter 03 applied systematically through a computational graph.
-
Any model can be represented as a directed acyclic graph of operations: inputs flow in, get multiplied by weights, added together, passed through nonlinear functions, and eventually produce a loss value. The forward pass computes the output (and loss) by flowing data through this graph from input to output.
-
The backward pass (backpropagation) flows gradients in reverse. Starting from the loss, you compute how the loss changes with respect to each intermediate value, using the chain rule at every node. If $L$ depends on $z$ which depends on $w$, then:
$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w}$$
-
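Each node only needs to know its own local derivative and the gradient flowing in from above. This makes backpropagation modular and efficient: the backward pass costs about as much as the forward pass (typically up to about twice as much), so the gradient with respect to every parameter costs only a small constant multiple of evaluating the model once.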
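-
Here is the chain rule worked through by hand for a three-node graph (linear score, sigmoid, squared error), checked against JAX's automatic differentiation. The graph and input values are illustrative; each line of `backward` multiplies the incoming gradient by one node's local derivative.

```python
import jax
import jax.numpy as jnp

def forward(w, b, x, t):
    z = jnp.dot(w, x) + b      # node 1: linear score
    a = jax.nn.sigmoid(z)      # node 2: nonlinearity
    L = (a - t) ** 2           # node 3: squared-error loss
    return L, (z, a)

def backward(w, b, x, t):
    _, (z, a) = forward(w, b, x, t)   # forward pass stores the intermediates
    dL_da = 2 * (a - t)               # local derivative of node 3
    dL_dz = dL_da * a * (1 - a)       # chain rule through the sigmoid
    dL_dw = dL_dz * x                 # dz/dw = x
    dL_db = dL_dz                     # dz/db = 1
    return dL_dw, dL_db

w = jnp.array([0.5, -1.0])
b = 0.1
x = jnp.array([2.0, 1.0])
t = 1.0
manual = backward(w, b, x, t)
auto = jax.grad(lambda w, b: forward(w, b, x, t)[0], argnums=(0, 1))(w, b)
print(manual)
print(auto)
```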
-
Vanilla SGD has a problem: it oscillates in directions with steep curvature while making slow progress in flat directions. Optimisers improve on this by adapting the step based on gradient history.
-
SGD with momentum keeps a running average of past gradients (an exponential moving average, from chapter 04). This smooths out oscillations and accelerates progress along consistent directions:
$$v_t = \beta v_{t-1} + (1 - \beta) \nabla \mathcal{L}$$ $$w \leftarrow w - \eta \, v_t$$
-
Think of a ball rolling downhill: momentum lets it build up speed in a consistent direction and dampens the side-to-side jitter. The typical value is $\beta = 0.9$.
-
Nesterov Accelerated Gradient (NAG) is a small but clever tweak: instead of computing the gradient at the current position, compute it at the "look-ahead" position $w - \eta \beta v_{t-1}$. This corrective step reduces overshooting:
$$v_t = \beta \, v_{t-1} + \nabla \mathcal{L}(w - \eta \beta \, v_{t-1})$$ $$w \leftarrow w - \eta \, v_t$$
-
Adagrad adapts the learning rate per parameter. Parameters that receive large gradients get smaller learning rates, and vice versa. Writing $g_t$ for the gradient at step $t$, it accumulates the squared gradients element-wise:
$$G_t = G_{t-1} + g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{G_t + \epsilon}} g_t$$
-
The problem: $G_t$ only grows, so the effective learning rate monotonically decreases and eventually becomes too small to learn anything.
-
RMSprop fixes this by using an exponential moving average of squared gradients instead of a sum, so recent gradients matter more than ancient ones:
$$s_t = \beta \, s_{t-1} + (1 - \beta) g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{s_t + \epsilon}} g_t$$
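-
The difference between Adagrad's running sum and RMSprop's moving average shows up clearly if you feed both the same constant gradient and track the effective learning rate, $\eta$ divided by the square-root term. The constant gradient and hyperparameters are arbitrary choices for illustration.

```python
import jax.numpy as jnp

eta, beta, eps = 0.1, 0.9, 1e-8
g = 1.0                      # pretend every step sees the same gradient
G, s = 0.0, 0.0
adagrad_lr, rmsprop_lr = [], []
for t in range(1, 1001):
    G = G + g ** 2                         # Adagrad: sum of all squared gradients
    s = beta * s + (1 - beta) * g ** 2     # RMSprop: moving average of squared gradients
    adagrad_lr.append(float(eta / jnp.sqrt(G + eps)))
    rmsprop_lr.append(float(eta / jnp.sqrt(s + eps)))
print(f"after 1000 steps: Adagrad {adagrad_lr[-1]:.4f}, RMSprop {rmsprop_lr[-1]:.4f}")
```

After 1000 steps Adagrad's effective rate has shrunk by a factor of $\sqrt{1000} \approx 32$, while RMSprop's has settled at $\eta / |g| = 0.1$.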
-
Adam (Adaptive Moment Estimation) combines momentum and RMSprop. It maintains both a first-moment estimate (mean of gradients, like momentum) and a second-moment estimate (mean of squared gradients, like RMSprop):
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
-
Since $m_t$ and $v_t$ are initialised at zero, they are biased toward zero in early steps. Bias correction fixes this:
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$w \leftarrow w - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$
-
Default hyperparameters ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$) work well across a wide range of problems, which is why Adam is the default optimiser in most deep learning work.
-
AdamW decouples weight decay from the gradient update. Standard L2 regularisation and weight decay are equivalent for SGD but not for Adam. AdamW applies weight decay directly to the parameters rather than adding $\lambda w$ to the gradient. This gives better generalisation and is now the standard in transformer training:
$$w \leftarrow w - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, w \right)$$
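-
To see why decoupling matters, consider a weight that has no data gradient at all, only the penalty. With the L2 term routed through Adam, the gradient is $\lambda w$, and $\lambda$ cancels inside $\hat{m}_t / \sqrt{\hat{v}_t}$; with AdamW the decay $\eta \lambda w$ is applied directly. The `adam_step` helper below is a hypothetical single-tensor sketch, not a library API.

```python
import jax.numpy as jnp

def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, l2=0.0, wd=0.0):
    """One Adam update. l2 > 0: classic L2 penalty added to the gradient.
    wd > 0: decoupled (AdamW-style) weight decay applied outside the adaptive scaling."""
    g = g + l2 * w
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    w = w - lr * (m_hat / (jnp.sqrt(v_hat) + eps) + wd * w)
    return w, m, v

def run(lam, decoupled, steps=200):
    """Shrink a single weight starting at 1.0 with zero data gradient, only the penalty."""
    w, m, v = jnp.array(1.0), 0.0, 0.0
    for t in range(1, steps + 1):
        if decoupled:
            w, m, v = adam_step(w, 0.0, m, v, t, wd=lam)
        else:
            w, m, v = adam_step(w, 0.0, m, v, t, l2=lam)
    return float(w)

for lam in [0.01, 0.1]:
    print(f"lambda={lam}: Adam+L2 -> {run(lam, False):.4f}, AdamW -> {run(lam, True):.4f}")
```

Under Adam with the L2 penalty, both strengths shrink the weight by almost exactly the same amount (roughly $\eta$ per step), while under AdamW the decay scales with $\lambda$ as intended.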
-
Lion (EvoLved Sign Momentum) is a newer optimiser discovered through automated program search. It uses only the sign of an interpolation between the momentum and the current gradient (not its magnitude), so every coordinate moves by exactly $\eta$ per step. Lion uses less memory than Adam (no second-moment buffer) and has been reported to match or beat Adam on many tasks, typically with a smaller learning rate:
$$w \leftarrow w - \eta \cdot \text{sign}(\beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t)$$ $$m_t = \beta_2 \, m_{t-1} + (1 - \beta_2) \, g_t$$
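-
A minimal sketch of one Lion step, following the two equations above with weight decay omitted. The defaults $\beta_1 = 0.9$, $\beta_2 = 0.99$ are commonly used values and the learning rate is illustrative. The toy gradient has a 100x gap between its coordinates to show that the sign makes every coordinate move by the same amount.

```python
import jax.numpy as jnp

def lion_step(w, g, m, lr=1e-4, beta1=0.9, beta2=0.99):
    """One Lion update (weight decay omitted). Returns new weights and momentum."""
    # Direction: sign of an interpolation between the old momentum and the new gradient
    w = w - lr * jnp.sign(beta1 * m + (1 - beta1) * g)
    # Momentum is refreshed afterwards, with a different coefficient beta2
    m = beta2 * m + (1 - beta2) * g
    return w, m

w = jnp.array([1.0, -0.01])
m = jnp.zeros(2)
g = 2 * w                        # gradient of f(w) = sum(w**2): entries 2.0 and -0.02
w_new, m_new = lion_step(w, g, m, lr=1e-3)
print(w_new)
```

Here `w_new` is `[0.999, -0.009]`: both coordinates moved by exactly `lr`, even though their gradients differ by a factor of 100.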
-
Muon (MomentUm Orthogonalized by Newton-Schulz) applies Nesterov momentum to the gradient of each weight matrix and then orthogonalises the resulting update matrix using Newton-Schulz iterations, which approximate the orthogonal factor of the polar decomposition. The resulting update is approximately a semi-orthogonal matrix (a point on the Stiefel manifold), so it has roughly equal magnitude along all singular directions and no single direction dominates. This removes the need for adaptive second-moment estimates (no $v_t$ buffer like Adam), reducing memory. Muon has been reported to train transformers faster than AdamW at matching quality, particularly when applied to the attention and MLP weight matrices. Embedding and output layers (and scalar or vector parameters such as biases and norm gains) are typically still handled by AdamW.
$$G_t = \text{NesterovMomentum}(\nabla \mathcal{L})$$ $$U_t = \text{NewtonSchulz}(G_t) \approx G_t (G_t^T G_t)^{-1/2}$$ $$W \leftarrow W - \eta \, U_t$$
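-
The Newton-Schulz iteration first scales the matrix so that all its singular values are at most 1, $X_0 = G_t / \|G_t\|_F$, and then repeats $X_{k+1} = \frac{1}{2} X_k (3I - X_k^T X_k)$. Each step pushes every singular value toward 1 while leaving the singular vectors unchanged. Only a few steps (typically 5-10) are run, because an approximate result is good enough, and this avoids the cost of a full SVD. Muon's reference implementation uses a quintic variant of this iteration with tuned coefficients and 5 steps.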
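-
The iteration is short enough to write out directly. This sketch uses the cubic update given above (not Muon's tuned quintic), runs it on a random tall matrix, and compares the result with the exact orthogonal polar factor $UV^T$ from the SVD. Small singular values only grow by about a factor of 1.5 per cubic step, so the sketch runs more steps than a practical optimiser would.

```python
import jax
import jax.numpy as jnp

def newton_schulz_orthogonalise(G, steps=30):
    """Approximate the orthogonal polar factor of a tall matrix G with the cubic iteration."""
    X = G / jnp.linalg.norm(G)     # Frobenius scaling: all singular values now <= 1
    I = jnp.eye(G.shape[1])
    for _ in range(steps):
        X = 0.5 * X @ (3 * I - X.T @ X)   # maps each singular value s to 1.5 s - 0.5 s^3
    return X

G = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
U_ns = newton_schulz_orthogonalise(G)

# Reference: G = U S V^T, so the polar factor is U V^T
U, S, Vt = jnp.linalg.svd(G, full_matrices=False)
U_svd = U @ Vt
print("max deviation from SVD polar factor:", float(jnp.max(jnp.abs(U_ns - U_svd))))
```

For a wide matrix, the equivalent form $\frac{1}{2}(3I - X_k X_k^T) X_k$ uses the smaller Gram matrix.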
-
Beyond MSE and BCE, several other loss functions are commonly used.
-
Mean Absolute Error (MAE), or L1 loss, takes the average of absolute differences: $\frac{1}{n}\sum|y_i - \hat{y}_i|$. It is more robust to outliers than MSE because it does not square large errors.
-
Huber loss combines the best of both: it behaves like MSE for small errors (smooth, easy to optimise) and like MAE for large errors (robust to outliers). It has a threshold $\delta$ that controls the transition.
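-
Here are the three regression losses side by side on the same residuals, one of which is an outlier. The Huber form uses the common convention $\frac{1}{2}r^2$ for $|r| \le \delta$ and $\delta(|r| - \frac{1}{2}\delta)$ otherwise, so the two pieces and their slopes meet at $|r| = \delta$; the residual values are made up.

```python
import jax.numpy as jnp

def huber(r, delta=1.0):
    """Quadratic for |r| <= delta, linear beyond it (continuous value and slope at delta)."""
    abs_r = jnp.abs(r)
    return jnp.where(abs_r <= delta, 0.5 * r ** 2, delta * (abs_r - 0.5 * delta))

residuals = jnp.array([0.1, -0.5, 2.0, 10.0])   # the last one is an outlier
print("MSE  :", jnp.mean(residuals ** 2))
print("MAE  :", jnp.mean(jnp.abs(residuals)))
print("Huber:", jnp.mean(huber(residuals)))
```

The outlier dominates MSE (about 26.1, of which 25 comes from the residual of 10), while MAE (3.15) and Huber (about 2.78) grow only linearly with it.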
-
Categorical cross-entropy (CCE) generalises BCE to multiple classes. If $\hat{y}_k$ is the predicted probability for class $k$ and the true class is $c$:
$$\mathcal{L} = -\log(\hat{y}_c)$$
-
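This is just the negative log-probability of the correct class. Minimising cross-entropy is equivalent to maximising the likelihood, which connects back to the information theory in chapter 05: the cross-entropy $H(p, q)$ is the average number of bits (or nats, with natural logarithms) needed to encode outcomes from the true distribution $p$ using a code optimised for your predicted distribution $q$. It equals the entropy $H(p)$ plus the KL divergence $D_{\mathrm{KL}}(p \,\|\, q)$, and that KL term is the number of extra bits you pay for using $q$ instead of $p$.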
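-
A sketch of categorical cross-entropy computed from raw scores, plus a numeric check of the identity $H(p, q) = H(p) + D_{\mathrm{KL}}(p \,\|\, q)$. Working from logits with `jax.nn.log_softmax` avoids taking the log of a probability that has underflowed to zero; the logits and distributions here are arbitrary.

```python
import jax
import jax.numpy as jnp

def cross_entropy(logits, labels):
    """Mean categorical cross-entropy from raw scores (logits) and integer class labels."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)   # numerically stable log-probabilities
    # -log(probability of the true class) for each example, averaged over the batch
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])
print(cross_entropy(logits, labels))

# Information-theory view with a soft "true" distribution p and a prediction q, in bits
p = jnp.array([0.7, 0.2, 0.1])
q = jax.nn.softmax(jnp.array([1.0, 0.5, 0.0]))
cross_ent = -jnp.sum(p * jnp.log2(q))
entropy = -jnp.sum(p * jnp.log2(p))
kl = jnp.sum(p * jnp.log2(p / q))
print(cross_ent, entropy + kl)   # the two numbers agree
```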
-
Hinge loss is used by SVMs: $\mathcal{L} = \max(0, 1 - y \cdot f(x))$. It only penalises predictions that are on the wrong side of the margin or within the margin. Once a point is correctly classified with sufficient confidence, the loss is zero.
-
Regularisation prevents overfitting by adding a penalty for complex models. The regularised loss is:
$$\mathcal{L}_{\text{reg}} = \mathcal{L}_{\text{data}} + \lambda \, R(w)$$
-
L2 regularisation (Ridge, weight decay) penalises the sum of squared weights: $R(w) = \|w\|_2^2 = \sum w_i^2$. It discourages any single weight from becoming too large, effectively shrinking all weights toward zero but rarely making them exactly zero.
-
L1 regularisation (Lasso) penalises the sum of absolute weights: $R(w) = \|w\|_1 = \sum |w_i|$. It encourages sparsity, driving many weights to exactly zero, which performs automatic feature selection.
-
Elastic Net combines both: $R(w) = \alpha \|w\|_1 + (1 - \alpha) \|w\|_2^2$, blending sparsity and shrinkage.
-
There is a beautiful Bayesian interpretation (from chapter 05). L2 regularisation is equivalent to placing a Gaussian prior on the weights and finding the MAP estimate. L1 regularisation corresponds to a Laplace prior. The regularisation strength $\lambda$ controls how much you trust the prior relative to the data.
-
Evaluation metrics tell you whether your model is actually working. For regression, MSE and MAE are standard. For classification, things are more nuanced.
-
A confusion matrix is a table of four counts for binary classification:
- True Positive (TP): predicted positive, actually positive
- False Positive (FP): predicted positive, actually negative
- True Negative (TN): predicted negative, actually negative
- False Negative (FN): predicted negative, actually positive
-
Accuracy = $\frac{TP + TN}{TP + TN + FP + FN}$ can be misleading when classes are imbalanced. If 99% of emails are not spam, a model that always predicts "not spam" has 99% accuracy but is useless.
-
Precision = $\frac{TP}{TP + FP}$ answers: of all predicted positives, how many are actually positive? High precision means few false alarms.
-
Recall (sensitivity) = $\frac{TP}{TP + FN}$ answers: of all actual positives, how many did you catch? High recall means few missed cases.
-
F1 score = $\frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}$ is the harmonic mean of precision and recall, balancing both.
-
The ROC curve plots the true positive rate (recall) against the false positive rate ($\frac{FP}{FP + TN}$) as you vary the classification threshold from 0 to 1. A perfect classifier hugs the top-left corner. The AUC (area under the ROC curve) summarises performance in a single number: 1.0 is perfect, 0.5 is random guessing.
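-
All of these metrics fall out of the four confusion-matrix counts, so they are easy to compute directly. This sketch uses made-up labels and scores. For AUC it uses the equivalent ranking definition (the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as half) instead of tracing the curve; that is exact but needs memory proportional to positives times negatives.

```python
import jax.numpy as jnp

def binary_metrics(y_true, scores, threshold=0.5):
    y_pred = scores >= threshold
    y_true = y_true.astype(bool)
    tp = jnp.sum(y_pred & y_true)
    fp = jnp.sum(y_pred & ~y_true)
    tn = jnp.sum(~y_pred & ~y_true)
    fn = jnp.sum(~y_pred & y_true)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / y_true.size
    return dict(tp=tp, fp=fp, tn=tn, fn=fn, accuracy=accuracy,
                precision=precision, recall=recall, f1=f1)

def roc_auc(y_true, scores):
    """AUC as P(score of random positive > score of random negative), ties count 1/2."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]          # every positive-negative pair
    return jnp.mean((diff > 0) + 0.5 * (diff == 0))

y_true = jnp.array([1, 1, 1, 0, 0, 0, 0, 1])
scores = jnp.array([0.9, 0.8, 0.4, 0.6, 0.2, 0.1, 0.3, 0.7])
print(binary_metrics(y_true, scores))
print("AUC:", roc_auc(y_true, scores))
```

Here TP = 3, FP = 1, TN = 3 and FN = 1, so accuracy, precision, recall and F1 are all 0.75, and 15 of the 16 positive-negative pairs are ranked correctly, giving AUC = 0.9375. Precision is undefined (0/0) if nothing is predicted positive, so production code should guard that case.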
-
Cross-validation provides a more reliable estimate of generalisation performance. In $k$-fold cross-validation, you split the data into $k$ folds, train on $k-1$ of them, test on the remaining fold, and rotate. The average test performance across all $k$ folds is your estimate. This uses all data for both training and testing (just never at the same time), which is especially valuable when data is scarce.
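-
A minimal $k$-fold loop, using closed-form ridge regression as the model so the focus stays on the splitting. The helper names, the data, and the candidate $\lambda$ values are assumptions for illustration; shuffling before splitting guards against data that arrive in some sorted order.

```python
import jax
import jax.numpy as jnp

def fit_ridge(X, y, lam):
    d = X.shape[1]
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(d), X.T @ y)

def k_fold_mse(X, y, lam, k=5, seed=0):
    n = X.shape[0]
    perm = jax.random.permutation(jax.random.PRNGKey(seed), n)
    folds = jnp.array_split(perm, k)            # k disjoint index sets covering all n rows
    errors = []
    for i in range(k):
        test_idx = folds[i]
        train_idx = jnp.concatenate([folds[j] for j in range(k) if j != i])
        w = fit_ridge(X[train_idx], y[train_idx], lam)
        errors.append(jnp.mean((X[test_idx] @ w - y[test_idx]) ** 2))
    return jnp.mean(jnp.array(errors))          # average held-out error over the k folds

kx, kn = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (100, 5))
y = X @ jnp.array([1.0, 0.0, -1.0, 2.0, 0.0]) + 0.3 * jax.random.normal(kn, (100,))
for lam in [0.01, 1.0, 100.0]:
    print(f"lambda={lam}: CV MSE = {float(k_fold_mse(X, y, lam)):.4f}")
```

Each example lands in exactly one test fold. Picking the $\lambda$ with the lowest cross-validated error is the usual way these estimates are put to work.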
-
The bias-variance tradeoff (from chapter 04) is the fundamental tension in ML. For squared-error loss, a model's expected prediction error (averaged over possible training sets) decomposes into:
$$\text{Error} = \text{Bias}^2 + \text{Variance} + \text{Irreducible Noise}$$
-
Bias is systematic error from wrong assumptions (e.g., fitting a line to curved data). Variance is sensitivity to training data fluctuations (e.g., a degree-20 polynomial fitting noise). Simple models have high bias and low variance; complex models have low bias and high variance. The sweet spot minimises total error.
-
Learning rate scheduling adjusts $\eta$ during training. Common strategies:
- Step decay: multiply $\eta$ by a factor (e.g., 0.1) every $N$ epochs
- Cosine annealing: smoothly decrease $\eta$ following a cosine curve from the initial value to near zero
- Warmup: start with a very small $\eta$ and linearly increase it for the first few thousand steps, then decay. This prevents large initial gradients from destabilising training
- 1cycle: a single cycle that raises $\eta$ from a low value to a peak and then anneals it back down (often well below the starting value), which can give faster convergence
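-
Schedules are just functions from the step count to a learning rate. Here is a sketch of step decay and of linear warmup followed by cosine annealing, with illustrative default values; libraries such as Optax ship ready-made equivalents (for example `optax.warmup_cosine_decay_schedule`).

```python
import math

def warmup_cosine(step, base_lr=3e-4, warmup_steps=1000, total_steps=10000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)  # 0 -> 1
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

def step_decay(step, base_lr=0.1, drop=0.1, every=3000):
    """Multiply the learning rate by `drop` every `every` steps."""
    return base_lr * drop ** (step // every)

for s in [0, 500, 999, 1000, 5500, 10000]:
    print(s, f"{warmup_cosine(s):.2e}", f"{step_decay(s):.2e}")
```

The warmup reaches `base_lr` exactly at the end of warmup, the cosine phase passes through half of it at the midpoint, and it ends at `min_lr`.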
-
Hyperparameter tuning is the process of finding good values for learning rate, batch size, regularisation strength, and other settings that are not learned by gradient descent. Common approaches:
- Grid search: try every combination on a predefined grid (exhaustive but expensive)
- Random search: sample combinations randomly, which is often more efficient because not all hyperparameters matter equally
- Bayesian optimisation: build a model of the objective function and intelligently choose the next hyperparameters to try
- ASHA (Asynchronous Successive Halving Algorithm): runs many trials in parallel with small budgets, then promotes the most promising ones to larger budgets while stopping the rest early. It combines the efficiency of early stopping with massive parallelism: instead of running 100 full training runs, start all 100 cheaply, keep the top quarter at each rung (a reduction factor of 4), and let only a handful run to completion. Because promotions happen asynchronously, workers never sit idle waiting for a whole rung to finish. ASHA is widely used in large-scale tuning frameworks such as Ray Tune.
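-
The core idea behind ASHA can be sketched with plain, synchronous successive halving: evaluate every configuration on a small budget, keep the top $1/\eta$, multiply the budget by $\eta$, and repeat. ASHA itself makes the promotions asynchronous, which this sketch does not attempt. `toy_val_loss` is a hypothetical stand-in for a real training run whose loss depends only on the learning rate and the budget.

```python
import math

def toy_val_loss(lr, budget):
    """Hypothetical stand-in for 'train with this lr for `budget` epochs, return val loss'.
    The best learning rate is 1e-2, and more budget always helps a little."""
    return (math.log10(lr) + 2) ** 2 + 1.0 / budget

def successive_halving(configs, min_budget=1, eta=4):
    survivors, budget, spent = list(configs), min_budget, 0
    while len(survivors) > 1:
        scores = sorted((toy_val_loss(lr, budget), lr) for lr in survivors)
        spent += len(survivors) * budget
        keep = max(1, len(survivors) // eta)       # promote the top 1/eta
        survivors = [lr for _, lr in scores[:keep]]
        budget *= eta                              # promoted trials get eta times more budget
    return survivors[0], spent

configs = [10 ** (-4 + i / 16) for i in range(65)]   # 65 log-spaced learning rates, 1e-4 to 1
best_lr, spent = successive_halving(configs)
print(f"best lr: {best_lr:.4g}, epochs spent: {spent} (vs {len(configs) * 16} for full runs)")
```

With 65 configurations and $\eta = 4$, the search picks the best learning rate (0.01) after spending 193 epochs in total, compared with 1040 to train every configuration for 16 epochs. Real learning curves are noisy and can cross, so an early leader is not guaranteed to win; that is the main risk of aggressive early stopping.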
-
Schedule-free learning eliminates the need for a learning rate schedule. Instead of decaying $\eta$ on a fixed curve, it keeps three sequences: a base iterate $z_t$ that takes plain gradient steps with a constant learning rate, a running average $x_t$ of the $z_t$ (this averaged sequence is the model you evaluate and return), and an interpolation $y_t$ between the two, which is where gradients are evaluated. For convex problems the authors prove optimal worst-case convergence rates without any schedule, and in their experiments schedule-free variants of both SGD and AdamW matched or exceeded their tuned-schedule counterparts. This removes the schedule (and the need to fix the training length in advance) as a hyperparameter: you set the base learning rate, though a short warmup is still commonly used.
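-
A minimal sketch of schedule-free SGD on a toy quadratic, with the three sequences named explicitly: `z` takes plain gradient steps with a constant learning rate, `x` is the uniform running average of the `z` iterates (the returned model), and `y`, an interpolation between them, is where gradients are evaluated. The update rules follow the schedule-free method in simplified form (no warmup, uniform averaging weights $c_{t+1} = 1/(t+1)$); the objective and hyperparameters are illustrative.

```python
import jax
import jax.numpy as jnp

def loss(w):
    return 0.5 * jnp.sum((w - 3.0) ** 2)   # toy convex objective, minimum at w = 3

grad = jax.jit(jax.grad(loss))

def schedule_free_sgd(w0, lr=0.5, beta=0.9, steps=2000):
    z = w0      # base sequence: takes the gradient steps
    x = w0      # uniform average of the z iterates: this is the output
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x      # gradients are evaluated at this interpolation
        z = z - lr * grad(y)               # constant learning rate, no decay
        c = 1.0 / (t + 1)
        x = (1 - c) * x + c * z            # keeps x equal to the mean of z_1 .. z_{t+1}
    return x

w = schedule_free_sgd(jnp.array([0.0, 10.0]))
print(w)
```

No learning-rate decay appears anywhere, yet the averaged iterate `x` converges to the minimum at 3.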
Coding Tasks (use Colab or a notebook)
- Implement linear regression with both the normal equation and gradient descent. Compare the solutions and plot the convergence of the GD loss over iterations.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Generate synthetic data: y = 3x + 2 + noise (separate PRNG keys for x and the noise)
key_x, key_noise = jax.random.split(jax.random.PRNGKey(42))
n = 100
X = jax.random.uniform(key_x, (n, 1), minval=0, maxval=10)
y = 3 * X[:, 0] + 2 + jax.random.normal(key_noise, (n,)) * 1.5

# Add bias column
X_b = jnp.column_stack([X, jnp.ones(n)])

# Normal equation: solve (X^T X) w = X^T y rather than inverting X^T X
w_exact = jnp.linalg.solve(X_b.T @ X_b, X_b.T @ y)
print(f"Normal equation: w={w_exact[0]:.4f}, b={w_exact[1]:.4f}")

# Gradient descent on the MSE
@jax.jit
def gd_step(w, lr):
    error = X_b @ w - y
    grad = (2 / n) * X_b.T @ error
    return w - lr * grad, jnp.mean(error ** 2)

# Unscaled x in [0, 10] makes the loss surface elongated: the slope converges fast,
# the intercept slowly, so GD needs thousands of steps to match the exact solution
w_gd = jnp.zeros(2)
lr = 0.01
losses = []
for step in range(5000):
    w_gd, loss = gd_step(w_gd, lr)
    losses.append(float(loss))
print(f"Gradient descent: w={w_gd[0]:.4f}, b={w_gd[1]:.4f}")

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].scatter(X[:, 0], y, s=15, alpha=0.5, color='#3498db')
axes[0].plot([0, 10], [w_exact[1], w_exact[0] * 10 + w_exact[1]], color='#e74c3c', linewidth=2)
axes[0].set_title("Linear Regression Fit")
axes[0].set_xlabel("x"); axes[0].set_ylabel("y")
axes[1].plot(losses, color='#27ae60', linewidth=1.5)
axes[1].set_title("GD Loss Convergence")
axes[1].set_xlabel("Step"); axes[1].set_ylabel("MSE")
axes[1].set_yscale('log')
plt.tight_layout()
plt.show()
```
- Implement logistic regression from scratch with gradient descent. Train on a 2D dataset and visualise the learned decision boundary.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Generate data
X, y = make_moons(n_samples=300, noise=0.2, random_state=42)
X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)

def sigmoid(z):
    return 1 / (1 + jnp.exp(-z))

# Add bias column
X_b = jnp.column_stack([X, jnp.ones(len(X))])
w = jnp.zeros(3)
lr = 0.5
losses = []
for step in range(2000):
    z = X_b @ w
    pred = sigmoid(z)
    # BCE loss (the small epsilon guards against log(0))
    loss = -jnp.mean(y * jnp.log(pred + 1e-8) + (1 - y) * jnp.log(1 - pred + 1e-8))
    losses.append(float(loss))
    # Gradient of BCE w.r.t. w simplifies to X^T (y_hat - y) / n
    grad = X_b.T @ (pred - y) / len(y)
    w = w - lr * grad

# A straight line cannot fully separate the interleaving moons, so expect some errors
accuracy = jnp.mean((sigmoid(X_b @ w) >= 0.5) == (y == 1))
print(f"Final BCE: {losses[-1]:.4f}, training accuracy: {float(accuracy):.2%}")

# Decision boundary
xx, yy = jnp.meshgrid(jnp.linspace(-2, 3, 200), jnp.linspace(-1.5, 2, 200))
grid = jnp.column_stack([xx.ravel(), yy.ravel(), jnp.ones(xx.size)])
zz = sigmoid(grid @ w).reshape(xx.shape)
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])
plt.contour(xx, yy, zz, levels=[0.5], colors='#9b59b6', linewidths=2)
plt.scatter(X[y == 0][:, 0], X[y == 0][:, 1], c='#e74c3c', s=15, label='Class 0')
plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], c='#3498db', s=15, label='Class 1')
plt.title("Logistic Regression Decision Boundary")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
```
- Compare optimiser trajectories on a 2D quadratic surface. Run SGD, SGD+Momentum, and Adam from the same starting point and plot their paths.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Elongated quadratic: L(w1, w2) = 0.5*w1^2 + 10*w2^2
def loss_fn(w):
    return 0.5 * w[0] ** 2 + 10 * w[1] ** 2

grad_fn = jax.grad(loss_fn)

# JAX arrays are immutable, so appending w to the path needs no explicit copy
def run_sgd(w0, lr=0.05, steps=80):
    w = w0
    path = [w]
    for _ in range(steps):
        g = grad_fn(w)
        w = w - lr * g
        path.append(w)
    return jnp.stack(path)

def run_momentum(w0, lr=0.05, beta=0.9, steps=80):
    w, v = w0, jnp.zeros(2)
    path = [w]
    for _ in range(steps):
        g = grad_fn(w)
        v = beta * v + (1 - beta) * g      # exponential moving average of gradients
        w = w - lr * v
        path.append(w)
    return jnp.stack(path)

def run_adam(w0, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, steps=80):
    w, m, v = w0, jnp.zeros(2), jnp.zeros(2)
    path = [w]
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = b1 * m + (1 - b1) * g          # first moment
        v = b2 * v + (1 - b2) * g ** 2     # second moment
        m_hat = m / (1 - b1 ** t)          # bias correction
        v_hat = v / (1 - b2 ** t)
        w = w - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        path.append(w)
    return jnp.stack(path)

w0 = jnp.array([8.0, 3.0])
sgd_path = run_sgd(w0)
mom_path = run_momentum(w0)
adam_path = run_adam(w0)

# Plot
fig, ax = plt.subplots(figsize=(8, 6))
w1 = jnp.linspace(-10, 10, 100)
w2 = jnp.linspace(-4, 4, 100)
W1, W2 = jnp.meshgrid(w1, w2)
L = 0.5 * W1 ** 2 + 10 * W2 ** 2
ax.contour(W1, W2, L, levels=20, cmap='Greys', alpha=0.4)
ax.plot(sgd_path[:, 0], sgd_path[:, 1], 'o-', color='#3498db', markersize=2, linewidth=1, label='SGD')
ax.plot(mom_path[:, 0], mom_path[:, 1], 'o-', color='#27ae60', markersize=2, linewidth=1, label='Momentum')
ax.plot(adam_path[:, 0], adam_path[:, 1], 'o-', color='#e74c3c', markersize=2, linewidth=1, label='Adam')
ax.plot(0, 0, 'k*', markersize=15, label='Minimum')
ax.set_xlabel('w₁'); ax.set_ylabel('w₂')
ax.set_title("Optimizer Trajectories on Elongated Quadratic")
ax.legend()
plt.grid(alpha=0.3)
plt.show()
```
- Show the effect of L1 vs L2 regularisation on weight sparsity. Train linear regression with both penalties and compare the resulting weight vectors.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Synthetic data: only first 3 of 20 features are relevant
key = jax.random.PRNGKey(0)
# Use separate keys: draws from the same PRNG key are not independent in JAX
key_X, key_noise = jax.random.split(key)
n, d = 200, 20
w_true = jnp.zeros(d).at[:3].set(jnp.array([3.0, -2.0, 1.5]))
X = jax.random.normal(key_X, (n, d))
y = X @ w_true + 0.5 * jax.random.normal(key_noise, (n,))
def train_ridge(X, y, lam=1.0, lr=0.01, steps=2000):
    """L2 regularised linear regression via GD."""
    w = jnp.zeros(X.shape[1])
    for _ in range(steps):
        pred = X @ w
        grad = (2/len(y)) * X.T @ (pred - y) + 2 * lam * w
        w = w - lr * grad
    return w
def train_lasso(X, y, lam=1.0, lr=0.01, steps=2000):
    """L1 regularised linear regression via proximal GD."""
    w = jnp.zeros(X.shape[1])
    for _ in range(steps):
        pred = X @ w
        grad = (2/len(y)) * X.T @ (pred - y)
        w = w - lr * grad
        # Soft thresholding (proximal operator for L1)
        w = jnp.sign(w) * jnp.maximum(jnp.abs(w) - lr * lam, 0)
    return w
w_l2 = train_ridge(X, y, lam=0.1)
w_l1 = train_lasso(X, y, lam=0.1)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].bar(range(d), w_true, color='#333', alpha=0.7)
axes[0].set_title("True Weights"); axes[0].set_xlabel("Feature")
axes[1].bar(range(d), w_l2, color='#3498db', alpha=0.7)
axes[1].set_title("L2 (Ridge): shrinks all"); axes[1].set_xlabel("Feature")
axes[2].bar(range(d), w_l1, color='#e74c3c', alpha=0.7)
axes[2].set_title("L1 (Lasso): zeros out irrelevant"); axes[2].set_xlabel("Feature")
plt.tight_layout()
plt.show()
print(f"L2 non-zero weights: {int(jnp.sum(jnp.abs(w_l2) > 0.01))}/{d}")
print(f"L1 non-zero weights: {int(jnp.sum(jnp.abs(w_l1) > 0.01))}/{d}")
Deep Learning
-
What makes a network "deep"? A shallow network has one hidden layer; a deep network has many. Depth lets the network build hierarchical representations, with early layers learning simple features (edges, tones) and later layers composing them into complex concepts (faces, sentences). This compositionality is what gives deep learning its power.
-
The simplest deep network is the multi-layer perceptron (MLP), also called a fully connected or dense network. Each layer computes:
$$h = \sigma(Wx + b)$$
-
Here $W$ is a weight matrix (chapter 02), $b$ is a bias vector, and $\sigma$ is a nonlinear activation function. The output of one layer becomes the input to the next. Without the nonlinearity, stacking layers would be pointless: $W_2(W_1 x + b_1) + b_2 = (W_2 W_1)x + (W_2 b_1 + b_2)$, which is just another affine transformation with weight matrix $W_2 W_1$. This is exactly the matrix multiplication collapse from chapter 02.
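Here is a quick numerical check of the collapse argument. It is a sketch with random weights whose shapes are arbitrary. Two stacked layers with no activation match a single affine layer with weight $W_2 W_1$ and bias $W_2 b_1 + b_2$. Inserting a ReLU between them breaks this equivalence in general.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
W1 = jax.random.normal(k1, (16, 4))   # layer 1: 4 -> 16 (illustrative sizes)
b1 = jax.random.normal(k2, (16,))
W2 = jax.random.normal(k3, (3, 16))   # layer 2: 16 -> 3
b2 = jax.random.normal(k4, (3,))
x = jax.random.normal(k5, (4,))

# Two linear layers with no activation in between
two_linear = W2 @ (W1 @ x + b1) + b2

# One equivalent affine layer
W, b = W2 @ W1, W2 @ b1 + b2
one_linear = W @ x + b
print(jnp.allclose(two_linear, one_linear, atol=1e-4))   # True

# With a ReLU in between, no single affine map reproduces this for all x
with_relu = W2 @ jax.nn.relu(W1 @ x + b1) + b2
```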
-
Activation functions introduce the nonlinearity that makes depth meaningful.
-
ReLU (Rectified Linear Unit): $\text{ReLU}(x) = \max(0, x)$. It is the most widely used activation. It is fast to compute, does not saturate for positive inputs, and produces sparse activations (many neurons output exactly zero). The downside is that a neuron with negative input outputs zero and also receives zero gradient. If its input becomes negative for every training example (for instance after a large update), it stops learning permanently. Such neurons are said to "die".
-
Sigmoid: $\sigma(x) = \frac{1}{1+e^{-x}}$, squashing inputs to $(0, 1)$. Useful for output layers in binary classification, but problematic in hidden layers because gradients vanish when the input is far from zero (the curve is nearly flat).
-
Tanh: $\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$, squashing to $(-1, 1)$. Zero-centred (unlike sigmoid), which helps gradient flow, but still suffers from vanishing gradients at extremes.
-
GELU (Gaussian Error Linear Unit): $\text{GELU}(x) = x \cdot \Phi(x)$, where $\Phi$ is the standard normal CDF. It is a smooth approximation to ReLU that allows small negative values through. GELU is the default in GPT and BERT.
-
Swish: $\text{Swish}(x) = x \cdot \sigma(x)$, another smooth gate. Similar to GELU in practice.
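The next sketch evaluates each activation and its derivative (via `jax.grad`) at a few arbitrary points, using the built-ins in `jax.nn`. Two details matter here:
- `jax.nn.gelu` defaults to a tanh-based approximation, so the sketch passes `approximate=False` to match the exact $x \cdot \Phi(x)$ definition.
- `jax.nn.silu` is Swish with its gate parameter fixed at 1.

The output shows the saturation problem directly: at $|x| = 5$ the sigmoid and tanh derivatives are nearly zero, while ReLU's derivative is exactly 1 for positive inputs.

```python
import jax
import jax.numpy as jnp

activations = {
    "relu": jax.nn.relu,
    "sigmoid": jax.nn.sigmoid,
    "tanh": jnp.tanh,
    # approximate=False gives the exact x * Phi(x); JAX's default is a tanh approximation
    "gelu": lambda x: jax.nn.gelu(x, approximate=False),
    "swish": jax.nn.silu,  # Swish with gate parameter 1, also called SiLU
}

xs = jnp.array([-5.0, -1.0, 0.5, 5.0])
for name, f in activations.items():
    # Derivative at each point: vmap a scalar gradient over the inputs
    grads = jax.vmap(jax.grad(f))(xs)
    print(f"{name:8s} f(x)={f(xs).round(3)}  f'(x)={grads.round(3)}")
```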
-
A dense layer with $d_{\text{in}}$ inputs and $d_{\text{out}}$ outputs has $d_{\text{in}} \times d_{\text{out}} + d_{\text{out}}$ parameters (weights plus biases). The matrix multiply $Wx$ is just matrix-vector multiplication from chapter 02. In a batch setting, the input is a matrix $X$ of shape $(B, d_{\text{in}})$ and the output is $XW^T + b$ of shape $(B, d_{\text{out}})$.
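The parameter count and the batched shapes can be checked directly. This sketch assumes illustrative sizes ($d_{\text{in}} = 784$, $d_{\text{out}} = 128$, $B = 32$) and a small random initialisation. It stores $W$ with shape $(d_{\text{out}}, d_{\text{in}})$ so that one input computes $Wx + b$ and a batch computes $XW^T + b$.

```python
import jax
import jax.numpy as jnp

def init_dense(key, d_in, d_out):
    # W has shape (d_out, d_in) so that a single input x gives W @ x + b
    W = jax.random.normal(key, (d_out, d_in)) * 0.01
    b = jnp.zeros(d_out)
    return W, b

def dense_batch(W, b, X):
    # X: (B, d_in) -> X @ W.T + b: (B, d_out); b broadcasts across the batch
    return X @ W.T + b

d_in, d_out, B = 784, 128, 32
W, b = init_dense(jax.random.PRNGKey(0), d_in, d_out)
X = jnp.ones((B, d_in))
print(W.size + b.size)              # 784*128 + 128 = 100480
print(dense_batch(W, b, X).shape)   # (32, 128)
```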
-
The universal approximation theorem states that a single hidden layer with enough neurons can approximate any continuous function on a compact domain to arbitrary accuracy. This sounds like depth should not matter, but the catch is "enough neurons." Depth-separation results show that some functions can be represented compactly by a deep network but need exponentially more neurons when approximated by a shallow one. Depth gives you efficiency, not just expressiveness.
-
As networks get deeper, two gradient pathologies emerge. Vanishing gradients: when gradients pass through many layers (via the chain rule, chapter 03), they get multiplied by many factors. If these factors are consistently less than 1 (as happens with sigmoid and tanh saturating), the gradient shrinks exponentially toward zero. Early layers barely learn. Exploding gradients: if factors are consistently greater than 1, gradients grow exponentially, causing numerical overflow and unstable training.
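A scalar toy model makes the exponential behaviour visible. This sketch stacks `depth` layers of $h \leftarrow \text{act}(w h)$ with a fixed weight $w$ (a simplifying assumption, not a real network) and differentiates the output with respect to the input:
- Since the sigmoid derivative is at most 0.25, the sigmoid chain's gradient is bounded by $0.25^{\text{depth}}$.
- The ReLU chain with positive inputs passes a gradient of exactly 1.
- An identity chain with $w = 1.5$ explodes as $1.5^{\text{depth}}$.

```python
import jax

def deep_chain(x, act, depth, w=1.0):
    # Scalar "network": depth layers of h <- act(w * h)
    for _ in range(depth):
        x = act(w * x)
    return x

identity = lambda z: z
for depth in [1, 5, 10, 20]:
    g_sig = jax.grad(deep_chain)(0.5, jax.nn.sigmoid, depth)        # vanishes
    g_relu = jax.grad(deep_chain)(0.5, jax.nn.relu, depth)          # stays 1
    g_big = jax.grad(deep_chain)(0.5, identity, depth, 1.5)         # explodes
    print(f"depth={depth:2d}  sigmoid={float(g_sig):.2e}  "
          f"relu={float(g_relu):.1f}  linear(w=1.5)={float(g_big):.1f}")
```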
-
Solutions to vanishing/exploding gradients:
- Use ReLU or GELU activations (gradient is 1 for positive inputs, no saturation)
- Careful weight initialisation
- Normalisation layers
- Residual connections (skip connections)
- Gradient clipping (for exploding gradients): cap the gradient norm at a maximum value
-
Weight initialisation matters because it determines the scale of activations and gradients at the start of training. If weights are too large, activations explode; too small, they vanish.
-
Xavier (Glorot) initialisation sets weights from a distribution with variance $\frac{2}{d_{\text{in}} + d_{\text{out}}}$. This keeps the variance of activations roughly constant across layers, assuming linear or tanh activations.
-
He (Kaiming) initialisation uses variance $\frac{2}{d_{\text{in}}}$, which is calibrated for ReLU activations (since ReLU zeros out half the activations, you need double the variance to compensate).
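The effect of the initialisation scale can be measured empirically. This sketch pushes a batch of unit-variance inputs through 10 square ReLU layers (width 256; depth, width and the 0.01 baseline are illustrative choices) and reports the final activation spread:
- He initialisation keeps it of order 1.
- Xavier, derived for linear/tanh units, loses about half the variance per ReLU layer.
- A fixed small scale collapses towards zero.

```python
import jax
import jax.numpy as jnp

def activation_std_after(init_std, depth=10, width=256, seed=0):
    """Push unit-variance inputs through `depth` ReLU layers; return the final activation std."""
    keys = jax.random.split(jax.random.PRNGKey(seed), depth + 1)
    h = jax.random.normal(keys[0], (1000, width))            # batch of 1000 inputs
    for k in keys[1:]:
        W = jax.random.normal(k, (width, width)) * init_std(width, width)
        h = jax.nn.relu(h @ W)
    return float(h.std())

inits = {
    "small (0.01)": lambda d_in, d_out: 0.01,
    "Xavier": lambda d_in, d_out: jnp.sqrt(2.0 / (d_in + d_out)),
    "He": lambda d_in, d_out: jnp.sqrt(2.0 / d_in),
}
for name, init_std in inits.items():
    print(f"{name:13s} std after 10 ReLU layers: {activation_std_after(init_std):.2e}")
```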
-
Normalisation layers stabilise training by ensuring that the inputs to each layer have consistent statistics (roughly zero mean, unit variance).
-
Batch Normalisation (BatchNorm) normalises across the batch dimension: for each channel/feature, compute the mean and variance across all samples in the mini-batch, then normalise. It adds learnable scale ($\gamma$) and shift ($\beta$) parameters so the network can undo the normalisation if needed:
$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \quad y = \gamma \hat{x} + \beta$$
-
BatchNorm has a problem: it depends on the batch size. With very small batches, the statistics are noisy. At inference time, you use running averages instead of batch statistics, which creates a train/test discrepancy.
-
Layer Normalisation (LayerNorm) normalises across the feature dimension for each individual sample. It does not depend on other samples in the batch, making it the standard choice for transformers and recurrent networks.
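The only difference between the two normalisations is the axis over which the statistics are computed. This minimal sketch works on a (batch, features) array with illustrative sizes:
- BatchNorm normalises each feature column across the batch.
- LayerNorm normalises each sample row across its features.

A LayerNorm output does not change when other samples are removed. BatchNorm on a batch of one degenerates to $\beta$, because each sample equals its own batch mean. Running averages for inference are omitted here.

```python
import jax
import jax.numpy as jnp

def batch_norm(x, gamma, beta, eps=1e-5):
    # x: (B, D). Per-feature statistics, computed across the batch (axis 0)
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta

def layer_norm(x, gamma, beta, eps=1e-5):
    # x: (B, D). Per-sample statistics, computed across features (axis -1)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta

B, D = 8, 4
x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (B, D))
gamma, beta = jnp.ones(D), jnp.zeros(D)

bn, ln = batch_norm(x, gamma, beta), layer_norm(x, gamma, beta)
print(bn.mean(axis=0).round(4), bn.std(axis=0).round(3))  # ~0 and ~1 per feature (column)
print(ln.mean(axis=1).round(4))                           # ~0 per sample (row)
print(batch_norm(x[:1], gamma, beta))                     # all zeros: batch of one
```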
-
Instance Normalisation normalises across spatial dimensions for each sample and each channel independently. It is popular in style transfer.
-
Group Normalisation splits channels into groups and normalises within each group. It is a compromise between LayerNorm and InstanceNorm.
-
Dropout is a regularisation technique that randomly zeroes out a fraction $p$ of neurons during training. This discourages the network from relying on any single neuron and encourages redundant representations. At test time, all neurons are active. Inverted dropout scales the surviving activations by $\frac{1}{1-p}$ during training, so that no scaling is needed at test time. This is the standard implementation.
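Inverted dropout takes only a few lines. In this sketch, `dropout` is a hypothetical helper name and the sizes are illustrative. It samples a keep-mask with `jax.random.bernoulli`, rescales the survivors by $1/(1-p)$ so the expected activation is unchanged, and returns the input untouched at test time.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, p, train=True):
    """Inverted dropout: zero each unit with probability p, rescale survivors by 1/(1-p)."""
    if not train or p == 0.0:
        return x                                   # test time: identity, no rescaling
    keep = jax.random.bernoulli(key, 1.0 - p, x.shape)
    return jnp.where(keep, x / (1.0 - p), 0.0)

x = jnp.ones(100_000)
y = dropout(jax.random.PRNGKey(0), x, p=0.3)
print(float((y == 0).mean()))   # close to 0.3: fraction of dropped units
print(float(y.mean()))          # close to 1.0: expected activation preserved
print(bool(jnp.array_equal(dropout(None, x, 0.3, train=False), x)))   # True
```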
-
Convolutional Neural Networks (CNNs) exploit spatial structure. A dense layer connects every input to every output. A convolutional layer instead slides a small filter (kernel) across the input, computing a dot product at each position. The same filter weights are shared across all positions. This drastically reduces the number of parameters and builds in translation equivariance: a pattern shifted in the input produces the same response, shifted, in the output.
-
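The convolution operation for a 2D input with filter $K$ of size $k \times k$ is given below. Strictly, it is a cross-correlation, because the filter is not flipped. Deep learning libraries still call it convolution, since a learned filter can simply absorb the flip.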
$$(\text{input} * K)[i,j] = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \text{input}[i+m, j+n] \cdot K[m, n]$$
-
The output size depends on three hyperparameters. Stride controls how many pixels the filter moves between positions (stride 2 halves the spatial dimensions). Padding adds zeros around the input border ("same" padding preserves spatial size, "valid" padding does not). The output size formula: $\text{out} = \lfloor (\text{in} - k + 2p) / s \rfloor + 1$.
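The next sketch implements a single-channel 2D convolution with stride and zero padding using explicit loops, without flipping the filter, as in the formula above. It checks the output size against $\lfloor (\text{in} - k + 2p)/s \rfloor + 1$. The results are compared with XLA's `jax.lax.conv_general_dilated` under its default NCHW/OIHW layouts; that function also performs cross-correlation. The 7x7 input and random filter are arbitrary.

```python
import jax
import jax.numpy as jnp

def conv_output_size(n_in, k, p=0, s=1):
    # out = floor((in - k + 2p) / s) + 1
    return (n_in - k + 2 * p) // s + 1

def conv2d(x, K, stride=1, pad=0):
    """Single-channel 2D convolution (cross-correlation) with stride and zero padding."""
    x = jnp.pad(x, pad)                                   # zeros on all four borders
    k = K.shape[0]
    n_out = conv_output_size(x.shape[0], k, 0, stride)    # x is already padded
    m_out = conv_output_size(x.shape[1], k, 0, stride)
    out = jnp.zeros((n_out, m_out))
    for i in range(n_out):
        for j in range(m_out):
            patch = x[i * stride:i * stride + k, j * stride:j * stride + k]
            out = out.at[i, j].set(jnp.sum(patch * K))
    return out

img = jax.random.normal(jax.random.PRNGKey(0), (7, 7))
K = jax.random.normal(jax.random.PRNGKey(1), (3, 3))

for stride, pad in [(1, 0), (1, 1), (2, 1)]:
    ours = conv2d(img, K, stride, pad)
    # Reference: input as (N, C, H, W), kernel as (O, I, H, W)
    ref = jax.lax.conv_general_dilated(
        img[None, None], K[None, None],
        window_strides=(stride, stride), padding=[(pad, pad), (pad, pad)],
    )[0, 0]
    print(f"stride={stride} pad={pad} shape={ours.shape} "
          f"formula={conv_output_size(7, 3, pad, stride)} match={bool(jnp.allclose(ours, ref, atol=1e-4))}")
```

With `pad=1` and stride 1, the 3x3 filter preserves the 7x7 size ("same" padding). Stride 2 roughly halves it, to 4x4.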
-
Pooling layers downsample feature maps. Max pooling takes the maximum value in each window; average pooling takes the mean. Pooling reduces spatial dimensions while keeping the most important information.
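For non-overlapping windows, where the window size equals the stride, pooling reduces to a reshape followed by a reduction. This sketch assumes a single 2D feature map and crops any remainder rows or columns.

```python
import jax.numpy as jnp

def pool2d(x, size=2, mode="max"):
    """Non-overlapping pooling (window = stride = size) on an (H, W) feature map."""
    H, W = x.shape
    H2, W2 = H // size, W // size
    # Crop to a multiple of size, then give each window its own pair of axes
    blocks = x[:H2 * size, :W2 * size].reshape(H2, size, W2, size)
    return blocks.max(axis=(1, 3)) if mode == "max" else blocks.mean(axis=(1, 3))

fmap = jnp.arange(16.0).reshape(4, 4)
print(pool2d(fmap, 2, "max"))    # max of each 2x2 window: 5, 7, 13, 15
print(pool2d(fmap, 2, "mean"))   # mean of each 2x2 window: 2.5, 4.5, 10.5, 12.5
```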
-
Dilated convolutions insert gaps between filter elements, increasing the receptive field without increasing parameters. A dilation rate of 2 means the 3x3 filter covers a 5x5 area.
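As a worked equation: a $k \times k$ filter with dilation rate $d$ places its taps $d$ pixels apart, so along each axis it spans

$$k_{\text{eff}} = k + (k - 1)(d - 1)$$

For $k = 3$, a dilation of $d = 2$ gives $3 + 2 \cdot 1 = 5$, the 5x5 area mentioned above. A dilation of $d = 4$ gives $3 + 2 \cdot 3 = 9$. In both cases the filter still has only $k^2 = 9$ weights. Each stride-1 layer adds $k_{\text{eff}} - 1$ to the receptive field. Three 3x3 layers with dilations 1, 2, 4 therefore see $1 + 2 + 4 + 8 = 15$ input positions per axis.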
-
1x1 convolutions are convolutions with a 1x1 filter. They do not look at spatial neighbours; instead, they mix information across channels. Think of them as applying a dense layer at every spatial position. They are used to change the number of channels cheaply.
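The "dense layer at every position" view can be checked in a few lines. This sketch assumes a channels-last feature map with illustrative sizes: 8x8 spatial, 64 input channels, 16 output channels, and no bias. It computes the 1x1 convolution with `jnp.einsum` and compares it with a dense layer applied to each flattened pixel.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 64))   # (H, W, C_in)
W = jax.random.normal(jax.random.PRNGKey(1), (64, 16))     # 1x1 filter: C_in -> C_out

# 1x1 convolution: at every (h, w), mix the 64 channel values into 16 outputs
y_conv = jnp.einsum("hwc,cd->hwd", x, W)

# The same computation as a dense layer applied independently to each spatial position
y_dense = (x.reshape(-1, 64) @ W).reshape(8, 8, 16)

print(y_conv.shape, bool(jnp.allclose(y_conv, y_dense, atol=1e-4)))   # (8, 8, 16) True
print("parameters (no bias):", W.size)                                # 1024
```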
-
Skip connections (residual connections) let the input bypass one or more layers: $\text{output} = F(x) + x$. The layer only needs to learn the residual $F(x) = \text{output} - x$, which is easier when the optimal transformation is close to identity. ResNets (Residual Networks) stacked over 100 layers using this trick, solving the degradation problem where deeper networks performed worse than shallower ones.
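Here is a minimal residual stack in JAX. It uses a hypothetical two-layer residual branch $F(x) = \text{ReLU}(x W_1) W_2$, with 50 blocks of width 32 chosen for illustration. The second weight matrix is given a tiny initial scale, an assumption for the demo, so each block starts close to the identity. Differentiating through all 50 blocks still gives input gradients near 1, because the skip path carries the gradient unchanged.

```python
import jax
import jax.numpy as jnp

def init_block(key, d):
    k1, k2 = jax.random.split(key)
    # Tiny W2 makes F(x) ~ 0 at the start, so each block begins near the identity
    return {"W1": jax.random.normal(k1, (d, d)) * jnp.sqrt(2.0 / d),
            "W2": jax.random.normal(k2, (d, d)) * 1e-3}

def residual_block(p, x):
    F = jax.nn.relu(x @ p["W1"]) @ p["W2"]   # residual branch F(x)
    return F + x                              # skip connection: output = F(x) + x

d, depth = 32, 50
blocks = [init_block(k, d) for k in jax.random.split(jax.random.PRNGKey(0), depth)]

def deep_net(x):
    for p in blocks:
        x = residual_block(p, x)
    return x.sum()

x = jax.random.normal(jax.random.PRNGKey(1), (d,))
g = jax.grad(deep_net)(x)
print(float(jnp.abs(g).mean()))   # close to 1: the identity path carries the gradient
```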
-
CNNs build a feature hierarchy. Early layers detect edges and textures. Middle layers combine these into parts (eyes, wheels). Late layers recognise whole objects. Each layer's receptive field (the region of the input it can "see") grows with depth.
-
Embeddings map discrete tokens (words, characters, item IDs) to dense vectors. An embedding layer is just a lookup table: a matrix $E$ of shape (vocabulary size, embedding dimension). Looking up token $i$ means selecting row $i$ of $E$. This is equivalent to multiplying by a one-hot vector, which is just a special case of matrix-vector multiplication (chapter 02). Embeddings are learned during training, so similar tokens end up with similar vectors.
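The lookup-equals-matmul claim is easy to verify. This sketch uses a toy vocabulary of 10 tokens with 4-dimensional embeddings. It also shows that the gradient of a lookup only touches the rows that were used, with repeated tokens accumulating gradient.

```python
import jax
import jax.numpy as jnp

vocab_size, embed_dim = 10, 4                       # toy sizes
E = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, embed_dim))

token_ids = jnp.array([3, 7, 3])                   # a short token sequence
lookup = E[token_ids]                               # row selection: (3, 4)
one_hot = jax.nn.one_hot(token_ids, vocab_size)     # (3, 10)
via_matmul = one_hot @ E                            # same rows via matrix multiply
print(bool(jnp.allclose(lookup, via_matmul)))       # True

# Only rows 3 (used twice) and 7 receive gradient
g = jax.grad(lambda E: E[token_ids].sum())(E)
print(g.sum(axis=1))
```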
-
Tokenisation is the process of converting raw text into a sequence of tokens. Word-level tokenisation splits on spaces but cannot handle unseen words. Subword tokenisation (BPE, WordPiece, SentencePiece) breaks text into frequent subword units, balancing vocabulary size and coverage. The word "unhappiness" might become ["un", "happiness"] or ["un", "happ", "iness"].
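The core of BPE training can be shown with a toy learner. It starts from characters and repeatedly merges the most frequent adjacent pair. This is a simplified sketch: production tokenisers add pre-tokenisation, end-of-word or byte-level handling, and their own tie-breaking rules. The tiny word list is made up for illustration.

```python
from collections import Counter

def learn_bpe(words, num_merges):
    """Toy BPE learner: repeatedly merge the most frequent adjacent symbol pair."""
    corpus = Counter(tuple(w) for w in words)    # each word starts as a tuple of characters
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency
        pairs = Counter()
        for symbols, freq in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)          # ties go to the pair seen first
        merges.append(best)
        # Rewrite every word with the chosen pair fused into one symbol
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if symbols[i:i + 2] == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus[tuple(out)] += freq
        corpus = new_corpus
    return merges, corpus

words = ["low"] * 5 + ["lower"] * 2 + ["newest"] * 6 + ["widest"] * 3
merges, corpus = learn_bpe(words, num_merges=4)
print(merges)        # [('e', 's'), ('es', 't'), ('l', 'o'), ('lo', 'w')]
print(list(corpus))  # words now split into learned subword units
```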
-
Recurrent Neural Networks (RNNs) process sequences one element at a time, maintaining a hidden state that carries information forward:
$$h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$$
-
The hidden state $h_t$ is a compressed summary of everything the network has seen up to time $t$. The same weights $W_h$ and $W_x$ are shared across all time steps (weight sharing, like CNNs share spatial weights).
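The recurrence maps naturally onto `jax.lax.scan`, which threads the hidden state through time while reusing the same weights at every step. This is a minimal sketch with illustrative sizes and a 0.5 weight scale.

```python
import jax
import jax.numpy as jnp

def rnn_forward(params, xs, h0):
    """Vanilla RNN: h_t = tanh(W_h h_{t-1} + W_x x_t + b), same weights at every step."""
    def step(h_prev, x_t):
        h_t = jnp.tanh(params["W_h"] @ h_prev + params["W_x"] @ x_t + params["b"])
        return h_t, h_t            # (carry to the next step, output for this step)
    h_last, hs = jax.lax.scan(step, h0, xs)
    return h_last, hs

d_in, d_h, T = 3, 5, 7
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"W_h": jax.random.normal(k1, (d_h, d_h)) * 0.5,
          "W_x": jax.random.normal(k2, (d_h, d_in)) * 0.5,
          "b": jnp.zeros(d_h)}
xs = jax.random.normal(k3, (T, d_in))        # one sequence of T input vectors
h_last, hs = rnn_forward(params, xs, jnp.zeros(d_h))
print(hs.shape)   # (7, 5): one hidden state per time step
```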
-
Vanilla RNNs struggle with long sequences because of vanishing gradients. The gradient signal from step $t$ back to step $t - k$ passes through $k$ multiplications by $W_h$, each also scaled by the tanh derivative. As a result it shrinks (or explodes) exponentially in $k$.
-
LSTM (Long Short-Term Memory) solves this by introducing a separate cell state $c_t$ that flows through time with minimal interference. Three gates control what information enters, leaves, and persists:
-
The forget gate decides what to erase from the cell state: $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$
-
The input gate decides what new information to write: $i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$, with candidate values $\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$
-
The cell state updates: $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$
-
The output gate decides what to expose: $o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$, and $h_t = o_t \odot \tanh(c_t)$
-
The cell state acts like a conveyor belt. When the forget gate stays close to 1, information can flow nearly unchanged across many time steps, and so can the gradient along $c_t$. This greatly mitigates the vanishing gradient problem for long-range dependencies.
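The four LSTM equations translate line by line into a step function that can be scanned over a sequence. This is a sketch, not a library implementation. The sizes and the $1/\sqrt{\text{fan-in}}$ weight scale are illustrative. Initialising the forget-gate bias to 1 is a common heuristic (not part of the equations above) that starts the gate mostly open.

```python
import jax
import jax.numpy as jnp

def lstm_step(p, carry, x_t):
    h_prev, c_prev = carry
    z = jnp.concatenate([h_prev, x_t])            # [h_{t-1}, x_t]
    f = jax.nn.sigmoid(p["W_f"] @ z + p["b_f"])   # forget gate
    i = jax.nn.sigmoid(p["W_i"] @ z + p["b_i"])   # input gate
    c_tilde = jnp.tanh(p["W_c"] @ z + p["b_c"])   # candidate values
    o = jax.nn.sigmoid(p["W_o"] @ z + p["b_o"])   # output gate
    c = f * c_prev + i * c_tilde                  # cell state update
    h = o * jnp.tanh(c)                           # exposed hidden state
    return (h, c), h

def init_lstm(key, d_in, d_h, forget_bias=1.0):
    p = {}
    for k, g in zip(jax.random.split(key, 4), "fico"):
        p[f"W_{g}"] = jax.random.normal(k, (d_h, d_h + d_in)) / jnp.sqrt(d_h + d_in)
        p[f"b_{g}"] = jnp.zeros(d_h)
    p["b_f"] = p["b_f"] + forget_bias             # heuristic: forget gate starts mostly open
    return p

d_in, d_h, T = 3, 8, 20
p = init_lstm(jax.random.PRNGKey(0), d_in, d_h)
xs = jax.random.normal(jax.random.PRNGKey(1), (T, d_in))
carry0 = (jnp.zeros(d_h), jnp.zeros(d_h))
(h_T, c_T), hs = jax.lax.scan(lambda carry, x: lstm_step(p, carry, x), carry0, xs)
print(hs.shape, c_T.shape)   # (20, 8) (8,)
```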
-
GRU (Gated Recurrent Unit) simplifies the LSTM by merging the cell state and hidden state into one, and using two gates instead of three: an update gate (combines forget and input) and a reset gate. GRUs have fewer parameters and often perform comparably to LSTMs.
-
The fundamental limitation of RNNs (including LSTMs) is sequential processing: you must process token 1 before token 2 before token 3. This prevents parallelisation and creates an information bottleneck, as all context must squeeze through the fixed-size hidden state.
-
Attention solves both problems. Instead of compressing the entire input into a fixed vector, attention lets the model look back at all input positions and decide which ones are relevant for the current output.
-
The modern formulation uses queries, keys, and values (Q, K, V). Think of it like a library search: you have a query (what you are looking for), keys (labels on each book), and values (the actual book contents). You compare your query against all keys to figure out which values to retrieve.
-
Scaled dot-product attention:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
-
$QK^T$ computes the similarity between every query and every key. This is a matrix multiply (chapter 02), and each entry is a dot product. A dot product behaves like the cosine similarity from chapter 01 but is not normalised by the vectors' lengths. Dividing by $\sqrt{d_k}$ keeps the scores in a reasonable range. If query and key entries are independent with zero mean and unit variance, their dot product has variance $d_k$, so the division restores unit variance. Without it, large scores would make the softmax saturate and produce near-one-hot distributions with vanishing gradients. The softmax converts each row of similarities to a probability distribution. Multiplying by $V$ produces a weighted combination of values.
-
Multi-head attention runs $h$ parallel attention operations, each with different learned projections of Q, K, and V. This lets the model attend to information from different representation subspaces simultaneously. One head might attend to syntactic relationships while another attends to semantic ones. The outputs are concatenated and projected:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$
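Here is a compact multi-head self-attention in JAX. As a common implementation choice, it uses one $d_{\text{model}} \times d_{\text{model}}$ projection each for Q, K and V and reshapes the result into $h$ heads of size $d_k = d_{\text{model}}/h$. This is equivalent to giving each head its own projection matrix. The sizes and the $1/\sqrt{d_{\text{model}}}$ weight scale are illustrative.

```python
import jax
import jax.numpy as jnp

def multi_head_attention(params, X, n_heads):
    """Self-attention with n_heads heads. X: (seq_len, d_model)."""
    seq_len, d_model = X.shape
    d_k = d_model // n_heads
    def split(M):
        # (seq_len, d_model) -> (n_heads, seq_len, d_k)
        return M.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)
    Q, K, V = split(X @ params["W_q"]), split(X @ params["W_k"]), split(X @ params["W_v"])
    scores = Q @ K.transpose(0, 2, 1) / jnp.sqrt(d_k)          # (n_heads, seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)
    heads = weights @ V                                         # (n_heads, seq_len, d_k)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model) # Concat(head_1, ..., head_h)
    return concat @ params["W_o"], weights

seq_len, d_model, n_heads = 5, 16, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {name: jax.random.normal(k, (d_model, d_model)) / jnp.sqrt(d_model)
          for name, k in zip(["W_q", "W_k", "W_v", "W_o"], keys[:4])}
X = jax.random.normal(keys[4], (seq_len, d_model))
out, weights = multi_head_attention(params, X, n_heads)
print(out.shape, weights.shape)   # (5, 16) (4, 5, 5)
```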
- The Transformer architecture (Vaswani et al., 2017) is built entirely from attention and feed-forward layers, with no recurrence. The encoder block repeats: multi-head self-attention, add and layer-norm, feed-forward network, add and layer-norm. The decoder block adds a masked self-attention (preventing the model from seeing future tokens) and a cross-attention layer that attends to the encoder output.
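The decoder's masked self-attention can be sketched by setting the scores of future positions to $-\infty$ before the softmax, so those positions get exactly zero weight. This minimal single-head version uses random Q, K, V of illustrative size. Every row keeps at least its diagonal entry, so no row becomes all $-\infty$.

```python
import jax
import jax.numpy as jnp

def causal_attention(Q, K, V):
    """Masked self-attention: position i may only attend to positions j <= i."""
    n, d_k = Q.shape
    scores = Q @ K.T / jnp.sqrt(d_k)
    mask = jnp.tril(jnp.ones((n, n), dtype=bool))    # lower-triangular: True = allowed
    scores = jnp.where(mask, scores, -jnp.inf)        # future positions get -inf,
    weights = jax.nn.softmax(scores, axis=-1)         # so the softmax gives them weight 0
    return weights @ V, weights

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
Q, K, V = (jax.random.normal(k, (4, 8)) for k in (k1, k2, k3))
out, w = causal_attention(Q, K, V)
print(w.round(2))   # upper triangle is 0; the first row is [1, 0, 0, 0]
```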
- Positional encoding is necessary because attention is permutation-equivariant, meaning it treats the input as a set, not a sequence. Without position information, "the cat sat on the mat" and "the mat sat on the cat" would be identical. The original Transformer uses sinusoidal positional encodings:
$$PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$
-
Each position gets a unique vector that the model can use to distinguish positions. Modern models often use learned positional embeddings or relative positional encodings (RoPE, ALiBi) instead.
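The sinusoidal table can be built in a vectorised way directly from the formula. This sketch assumes an even model dimension $d$ and illustrative sizes. Even columns hold the sines and odd columns the cosines, and each frequency pair shares the angle $pos / 10000^{2i/d}$.

```python
import jax.numpy as jnp

def sinusoidal_pe(max_len, d):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle). Assumes d is even."""
    pos = jnp.arange(max_len)[:, None]              # (max_len, 1)
    two_i = jnp.arange(0, d, 2)[None, :]            # (1, d/2): the values 2i
    angles = pos / (10000.0 ** (two_i / d))         # (max_len, d/2)
    pe = jnp.zeros((max_len, d))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))        # even dimensions
    pe = pe.at[:, 1::2].set(jnp.cos(angles))        # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=50, d=16)
print(pe.shape)    # (50, 16)
print(pe[0, :4])   # position 0: sin(0) = 0 and cos(0) = 1 alternate
```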
-
Transformers process all tokens in parallel (the self-attention matrix $QK^T$ is computed in one matrix multiply), which makes them much faster to train than RNNs on modern hardware. The tradeoff is that self-attention is $O(n^2)$ in sequence length, because every token attends to every other, while RNNs are $O(n)$. Long-context models therefore rely on approximate attention variants (sparse attention, linear attention). They may also use memory-efficient exact kernels such as FlashAttention, which compute the same $O(n^2)$ result without storing the full $n \times n$ matrix.
-
Vision Transformers (ViT) apply the Transformer to images by splitting the image into fixed-size patches (e.g., 16x16), flattening each patch into a vector, and treating the patches as a sequence of tokens. A learnable [CLS] token is prepended, and its final representation is used for classification. Despite having no convolutional inductive biases, ViTs match or surpass CNNs when trained on enough data.
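The patch-to-token step is a pair of reshapes. This sketch assumes a 224x224 RGB image and 16x16 patches, which gives 196 tokens of $16 \cdot 16 \cdot 3 = 768$ values each. It then applies a random linear patch embedding and prepends a zero-initialised [CLS] token; $d_{\text{model}} = 64$ is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def patchify(img, patch):
    """Split an (H, W, C) image into non-overlapping patch x patch tiles, each flattened."""
    H, W, C = img.shape
    assert H % patch == 0 and W % patch == 0, "image size must be divisible by patch size"
    x = img.reshape(H // patch, patch, W // patch, patch, C)
    x = x.transpose(0, 2, 1, 3, 4)                   # (nH, nW, patch, patch, C)
    return x.reshape(-1, patch * patch * C)          # (num_patches, patch*patch*C)

img = jax.random.normal(jax.random.PRNGKey(0), (224, 224, 3))
tokens = patchify(img, 16)
print(tokens.shape)   # (196, 768): 14 x 14 patches, each 16*16*3 values

# Linear patch embedding plus a prepended learnable [CLS] token
d_model = 64
W_embed = jax.random.normal(jax.random.PRNGKey(1), (768, d_model)) * 0.02
cls = jnp.zeros((1, d_model))
seq = jnp.concatenate([cls, tokens @ W_embed], axis=0)
print(seq.shape)      # (197, 64)
```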
-
MLP-Mixer is an even simpler architecture that replaces both attention and convolution with MLPs. It alternates between "token-mixing" MLPs (applied across spatial positions) and "channel-mixing" MLPs (applied across features). It performs competitively, suggesting that the key insight of modern architectures is not attention itself, but rather efficient mixing of information across tokens and features.
-
Autoencoders learn compressed representations by training a network to reconstruct its own input. The encoder maps the input to a lower-dimensional bottleneck (the latent code), and the decoder maps it back:
$$z = f_{\text{enc}}(x), \quad \hat{x} = f_{\text{dec}}(z), \quad \mathcal{L} = \|x - \hat{x}\|^2$$
-
The bottleneck forces the network to learn the most important features. Autoencoders are used for dimensionality reduction, denoising (train on noisy input, reconstruct clean output), and anomaly detection (high reconstruction error signals an unusual input).
-
Variational Autoencoders (VAEs) add a probabilistic twist. Instead of encoding to a single point $z$, the encoder outputs the parameters of a distribution (mean $\mu$ and variance $\sigma^2$ of a Gaussian). The latent code is sampled from this distribution: $z = \mu + \sigma \odot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$. This reparameterisation trick makes the sampling differentiable so gradients can flow through.
-
The VAE loss has two terms:
$$\mathcal{L} = \underbrace{\|x - \hat{x}\|^2}_{\text{reconstruction}} + \underbrace{D_{\text{KL}}\big(q(z \mid x) \,\|\, p(z)\big)}_{\text{regularisation}}$$
- The KL divergence term (from chapter 05) pushes the learned posterior $q(z|x)$ toward the prior $p(z) = \mathcal{N}(0, I)$, ensuring the latent space is smooth and well-structured. You can then sample from the prior and decode to generate new data. This is what makes VAEs generative models.
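The two VAE-specific pieces are the reparameterised sample and the KL term. The KL term has a closed form when $q(z \mid x)$ is a diagonal Gaussian and $p(z) = \mathcal{N}(0, I)$: $D_{\text{KL}} = \tfrac{1}{2}\sum_j (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)$. This sketch assumes the encoder outputs $\mu$ and $\log \sigma^2$ (a common parameterisation that keeps $\sigma$ positive). The numbers are placeholders.

```python
import jax
import jax.numpy as jnp

def reparameterise(key, mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I): the randomness lives in eps,
    # so z is a differentiable function of mu and log_var
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    # Closed form of D_KL( N(mu, diag(sigma^2)) || N(0, I) )
    return 0.5 * jnp.sum(jnp.exp(log_var) + mu**2 - 1.0 - log_var)

mu = jnp.array([0.5, -1.0])
log_var = jnp.array([0.0, -0.5])
print(float(kl_to_standard_normal(jnp.zeros(2), jnp.zeros(2))))   # 0.0: posterior equals prior
print(float(kl_to_standard_normal(mu, log_var)))                  # about 0.678

# Gradients flow through the sample back to the encoder outputs
g_mu = jax.grad(lambda m: reparameterise(jax.random.PRNGKey(0), m, log_var).sum())(mu)
print(g_mu)   # dz/dmu = 1 in every dimension
```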
Coding Tasks (use Colab or a notebook)
- Build a simple MLP from scratch in JAX. Train it on a 2D classification problem (e.g., concentric circles) and visualise the decision boundary.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_circles
# Data
X, y = make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=42)
X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)
# Initialise an MLP with two hidden layers: 2 -> 16 -> 16 -> 1
def init_params(key):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        'W1': jax.random.normal(k1, (2, 16)) * 0.5,
        'b1': jnp.zeros(16),
        'W2': jax.random.normal(k2, (16, 16)) * 0.5,
        'b2': jnp.zeros(16),
        'W3': jax.random.normal(k3, (16, 1)) * 0.5,
        'b3': jnp.zeros(1),
    }
def forward(params, x):
    h = jnp.maximum(0, x @ params['W1'] + params['b1'])  # ReLU
    h = jnp.maximum(0, h @ params['W2'] + params['b2'])  # ReLU
    logit = (h @ params['W3'] + params['b3']).squeeze()
    return jax.nn.sigmoid(logit)
def loss_fn(params, X, y):
    pred = forward(params, X)
    return -jnp.mean(y * jnp.log(pred + 1e-7) + (1 - y) * jnp.log(1 - pred + 1e-7))
grad_fn = jax.jit(jax.grad(loss_fn))
params = init_params(jax.random.PRNGKey(0))
lr = 0.1
for step in range(2000):
    grads = grad_fn(params, X, y)
    params = {k: params[k] - lr * grads[k] for k in params}
# Plot decision boundary
xx, yy = jnp.meshgrid(jnp.linspace(-2, 2, 200), jnp.linspace(-2, 2, 200))
grid = jnp.column_stack([xx.ravel(), yy.ravel()])
zz = forward(params, grid).reshape(xx.shape)
plt.figure(figsize=(7, 6))
plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])
plt.scatter(X[y==0,0], X[y==0,1], c='#e74c3c', s=10, label='Class 0')
plt.scatter(X[y==1,0], X[y==1,1], c='#3498db', s=10, label='Class 1')
plt.title("MLP Decision Boundary on Concentric Circles")
plt.legend(); plt.grid(alpha=0.3); plt.show()
acc = jnp.mean((forward(params, X) > 0.5) == y)
print(f"Accuracy: {acc:.2%}")
- Implement 1D convolution from scratch. Apply a simple edge-detection filter to a signal and compare with the built-in jnp.convolve.
import jax.numpy as jnp
import matplotlib.pyplot as plt
def conv1d(signal, kernel):
    """1D convolution (valid mode) from scratch; like deep-learning layers, it does not flip the kernel."""
    n, k = len(signal), len(kernel)
    output = jnp.zeros(n - k + 1)
    for i in range(n - k + 1):
        output = output.at[i].set(jnp.sum(signal[i:i+k] * kernel))
    return output
# Create a signal with a step function
t = jnp.linspace(0, 4, 200)
signal = jnp.where(t < 1, 0.0, jnp.where(t < 2, 1.0, jnp.where(t < 3, 0.5, 1.5)))
# Edge detection kernel
edge_kernel = jnp.array([-1.0, 0.0, 1.0])
# Our implementation vs built-in
our_output = conv1d(signal, edge_kernel)
# jnp.convolve computes a true convolution, which flips the kernel; flip it back so both
# compute the same sliding dot product (equivalently, use jnp.correlate)
jnp_output = jnp.convolve(signal, edge_kernel[::-1], mode='valid')
fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
axes[0].plot(t, signal, color='#3498db', linewidth=1.5)
axes[0].set_title("Original Signal"); axes[0].set_ylabel("Value")
axes[1].plot(t[:len(our_output)], our_output, color='#e74c3c', linewidth=1.5)
axes[1].set_title("After Edge Detection (our conv1d)"); axes[1].set_ylabel("Value")
axes[2].plot(t[:len(jnp_output)], jnp_output, color='#27ae60', linewidth=1.5, linestyle='--')
axes[2].set_title("After Edge Detection (jnp.convolve)"); axes[2].set_ylabel("Value")
axes[2].set_xlabel("t")
plt.tight_layout(); plt.show()
print(f"Outputs match: {jnp.allclose(our_output, jnp_output)}")
- Implement scaled dot-product attention from scratch. Compute attention weights for a small example and visualise the attention matrix as a heatmap.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def scaled_dot_product_attention(Q, K, V):
    """Scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    weights = jax.nn.softmax(scores, axis=-1)
    output = weights @ V
    return output, weights
# Example: 4 tokens, embedding dim 8
key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)
seq_len, d_model = 4, 8
Q = jax.random.normal(k1, (seq_len, d_model))
K = jax.random.normal(k2, (seq_len, d_model))
V = jax.random.normal(k3, (seq_len, d_model))
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Q shape: {Q.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Output shape: {output.shape}")
print(f"\nAttention weights (rows sum to 1):")
print(weights)
print(f"Row sums: {weights.sum(axis=-1)}")
# Visualise attention
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)
ax.set_xlabel("Key position"); ax.set_ylabel("Query position")
ax.set_title("Attention Weights")
tokens = ['tok 0', 'tok 1', 'tok 2', 'tok 3']
ax.set_xticks(range(4)); ax.set_xticklabels(tokens)
ax.set_yticks(range(4)); ax.set_yticklabels(tokens)
for i in range(4):
    for j in range(4):
        ax.text(j, i, f"{weights[i,j]:.2f}", ha='center', va='center', fontsize=10)
plt.colorbar(im); plt.tight_layout(); plt.show()
- Build a simple autoencoder that compresses 2D data through a 1D bottleneck and reconstructs it. Visualise the latent space and reconstructions.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# Data
X, _ = make_moons(n_samples=500, noise=0.05, random_state=42)
X = jnp.array(X)
# Autoencoder: 2 -> 8 -> 1 -> 8 -> 2
def init_ae(key):
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        'enc_W1': jax.random.normal(k1, (2, 8)) * 0.5, 'enc_b1': jnp.zeros(8),
        'enc_W2': jax.random.normal(k2, (8, 1)) * 0.5, 'enc_b2': jnp.zeros(1),
        'dec_W1': jax.random.normal(k3, (1, 8)) * 0.5, 'dec_b1': jnp.zeros(8),
        'dec_W2': jax.random.normal(k4, (8, 2)) * 0.5, 'dec_b2': jnp.zeros(2),
    }
def encode(p, x):
    h = jnp.tanh(x @ p['enc_W1'] + p['enc_b1'])
    return h @ p['enc_W2'] + p['enc_b2']
def decode(p, z):
    h = jnp.tanh(z @ p['dec_W1'] + p['dec_b1'])
    return h @ p['dec_W2'] + p['dec_b2']
def ae_loss(p, X):
    z = encode(p, X)
    X_hat = decode(p, z)
    return jnp.mean((X - X_hat) ** 2)
grad_fn = jax.jit(jax.grad(ae_loss))
params = init_ae(jax.random.PRNGKey(0))
lr = 0.01
for step in range(3000):
    grads = grad_fn(params, X)
    params = {k: params[k] - lr * grads[k] for k in params}
z = encode(params, X)
X_hat = decode(params, z)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].scatter(X[:,0], X[:,1], c=z.squeeze(), cmap='viridis', s=10)
axes[0].set_title("Original Data (coloured by latent code)")
axes[1].scatter(X_hat[:,0], X_hat[:,1], c=z.squeeze(), cmap='viridis', s=10)
axes[1].set_title("Reconstruction from 1D bottleneck")
for ax in axes:
    ax.set_aspect('equal'); ax.grid(alpha=0.3)
plt.tight_layout(); plt.show()
print(f"Reconstruction MSE: {ae_loss(params, X):.4f}")
Reinforcement Learning
-
Supervised learning needs labelled data. Unsupervised learning finds patterns in unlabelled data. Reinforcement learning (RL) is different from both: an agent learns by interacting with an environment, taking actions, and receiving rewards. There are no correct labels; the agent must discover good behaviour through trial and error.
-
Think of teaching a dog a new trick. You do not show it a dataset of correct behaviours. Instead, it tries things, you give treats for good actions, and over time it figures out what you want. RL formalises this process.
-
The RL setup has five core components. The agent is the learner and decision-maker. The environment is everything outside the agent that it interacts with. At each time step, the agent observes a state $s_t$, chooses an action $a_t$, receives a reward $r_t$, and transitions to a new state $s_{t+1}$. The agent's goal is to maximise the total reward it collects over time.
-
A policy $\pi$ is the agent's strategy: a mapping from states to actions. A deterministic policy gives one action per state: $a = \pi(s)$. A stochastic policy gives a probability distribution over actions: $\pi(a \mid s)$. The goal of RL is to find the optimal policy, the one that maximises expected cumulative reward.
-
The mathematical framework for RL is the Markov Decision Process (MDP), defined by a tuple $(S, A, P, R, \gamma)$: a set of states $S$, a set of actions $A$, transition probabilities $P(s' \mid s, a)$, a reward function $R(s, a)$, and a discount factor $\gamma$.
-
The Markov property (from chapter 05) says the future depends only on the current state, not on the history of how you got there: $P(s_{t+1} \mid s_t, a_t, s_{t-1}, \ldots) = P(s_{t+1} \mid s_t, a_t)$. This means the state contains all the information needed to make a decision.
-
The discount factor $\gamma \in [0, 1)$ determines how much the agent cares about future rewards versus immediate ones. The discounted return from time $t$ is:
$$G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots = \sum_{k=0}^{\infty} \gamma^k r_{t+k}$$
-
With $\gamma = 0$, the agent is completely myopic, caring only about the immediate reward $r_t$. With $\gamma$ close to 1, the agent is far-sighted. The discount factor also ensures the infinite sum converges when rewards are bounded. If $|r_t| \le R_{\max}$, the geometric series gives $|G_t| \le R_{\max} / (1 - \gamma)$, so returns are always well defined.
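Returns for a finished episode are usually computed backwards with $G_t = r_t + \gamma G_{t+1}$, which avoids re-summing the series at every step. This sketch is plain Python, since no arrays are needed. The four-step reward sequence is made up to show how $\gamma$ changes the value of the early, small reward relative to the late, large one.

```python
def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed backwards from the end of an episode."""
    G, out = 0.0, []
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return out[::-1]

rewards = [1.0, 0.0, 0.0, 10.0]
for gamma in [0.0, 0.9, 0.99]:
    print(gamma, [round(g, 3) for g in discounted_returns(rewards, gamma)])
# gamma = 0.0  -> [1.0, 0.0, 0.0, 10.0]   (myopic: each G_t is just r_t)
# gamma = 0.9  -> [8.29, 8.1, 9.0, 10.0]
# gamma = 0.99 -> [10.703, 9.801, 9.9, 10.0]
```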
-
Value functions estimate how good it is to be in a state (or to take an action in a state). The state-value function $V^\pi(s)$ is the expected return starting from state $s$ and following policy $\pi$:
$$V^\pi(s) = \mathbb{E}_\pi \left[ G_t \mid s_t = s \right]$$
- The action-value function $Q^\pi(s, a)$ is the expected return starting from state $s$, taking action $a$, and then following $\pi$:
$$Q^\pi(s, a) = \mathbb{E}_\pi \left[ G_t \mid s_t = s, a_t = a \right]$$
-
The relationship: $V^\pi(s) = \sum_a \pi(a \mid s) \, Q^\pi(s, a)$. The state value is the average of action values, weighted by the policy.
-
The Bellman equation expresses a recursive relationship: the value of a state equals the immediate reward plus the discounted value of the next state. For the state-value function:
$$V^\pi(s) = \sum_a \pi(a \mid s) \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V^\pi(s') \right]$$
- For the optimal value function $V^{*}(s)$, the agent always picks the best action:
$$V^{*}(s) = \max_a \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V^{*}(s') \right]$$
- Similarly, the Bellman optimality equation for $Q^{*}$:
$$Q^{*}(s, a) = \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \max_{a'} Q^{*}(s', a') \right]$$
-
Once you have $Q^{*}$, the optimal policy is trivial: always pick the action with the highest Q-value: $\pi^{*}(s) = \arg\max_a Q^{*}(s, a)$.
-
Dynamic programming methods solve MDPs when you know the transition probabilities and rewards (the full model). Policy evaluation computes $V^\pi$ for a given policy by iteratively applying the Bellman equation until convergence. Policy improvement takes the value function and constructs a better policy by acting greedily: $\pi'(s) = \arg\max_a \sum_{s'} P(s' \mid s, a)[R(s,a) + \gamma V^\pi(s')]$.
-
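Policy iteration alternates between evaluation and improvement until the policy stops changing. For a finite MDP, it is guaranteed to reach an optimal policy after a finite number of iterations, because each improvement step yields a policy at least as good as the last and there are only finitely many deterministic policies.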
-
Value iteration combines both steps into one: it repeatedly applies the Bellman optimality equation as an update until $V$ converges to $V^{*}$, then extracts the greedy policy.
$$V(s) \leftarrow \max_a \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V(s') \right]$$
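Value iteration fits in a few array operations once $P$ and $R$ are stored as tensors. This sketch uses a hypothetical five-state corridor:
- Action 0 moves left and action 1 moves right; both are deterministic.
- Stepping right from state 3 into the absorbing state 4 pays reward 1.
- $\gamma = 0.9$.

Because each row of $P(\cdot \mid s, a)$ sums to 1, the update simplifies to $R(s, a) + \gamma \sum_{s'} P(s' \mid s, a) V(s')$.

```python
import jax.numpy as jnp

# Hypothetical toy MDP: corridor states 0..4, action 0 = left, 1 = right
n_states, n_actions, gamma = 5, 2, 0.9
P = jnp.zeros((n_states, n_actions, n_states))          # P[s, a, s']
R = jnp.zeros((n_states, n_actions))                    # R[s, a]
for s in range(n_states - 1):
    P = P.at[s, 0, max(s - 1, 0)].set(1.0)              # left (state 0 stays put)
    P = P.at[s, 1, s + 1].set(1.0)                      # right
P = P.at[n_states - 1, :, n_states - 1].set(1.0)        # state 4 is absorbing
R = R.at[n_states - 2, 1].set(1.0)                      # reward for entering state 4

def value_iteration(P, R, gamma, tol=1e-8):
    V = jnp.zeros(P.shape[0])
    while True:
        Q = R + gamma * P @ V                           # Q[s, a]; P @ V sums over s'
        V_new = Q.max(axis=1)                           # Bellman optimality backup
        if jnp.max(jnp.abs(V_new - V)) < tol:
            return V_new, Q.argmax(axis=1)              # greedy policy from the final Q
        V = V_new

V, policy = value_iteration(P, R, gamma)
print(V)        # approximately [0.729, 0.81, 0.9, 1.0, 0.0]
print(policy)   # states 0 to 3 choose action 1 (right); state 4 is terminal
```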
-
Dynamic programming requires knowing $P(s' \mid s, a)$, which is often impractical. In most real problems, the agent does not know the environment's dynamics; it can only interact with it. This is where model-free methods come in.
-
Temporal Difference (TD) learning learns from experience without knowing the model. The key idea is bootstrapping: instead of waiting until the end of an episode to compute the actual return $G_t$, you estimate it using the current value function:
$$V(s_t) \leftarrow V(s_t) + \alpha \left[ r_t + \gamma \, V(s_{t+1}) - V(s_t) \right]$$
- The term in brackets is the TD error: the difference between the TD target ($r_t + \gamma V(s_{t+1})$) and the current estimate $V(s_t)$. If the TD error is positive, the state was better than expected, so we increase its value. If negative, we decrease it.
-
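TD learning updates after every single step rather than after complete episodes, so it can learn before an episode ends. It is often more sample-efficient than Monte Carlo methods, trading some bias (from bootstrapping) for lower variance. It also works in continuing (non-episodic) environments.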
-
SARSA (State-Action-Reward-State-Action) is TD learning applied to Q-values. The agent takes action $a$ in state $s$, observes reward $r$ and next state $s'$, then chooses next action $a'$ according to its policy:
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \, Q(s', a') - Q(s, a) \right]$$
-
SARSA is on-policy: it updates using the action the agent actually takes, which includes exploration. This makes SARSA more conservative; it learns a policy that accounts for its own exploration noise.
-
Q-learning is one of the best-known RL algorithms. It is like SARSA, but instead of using the action the agent actually takes, it uses the best possible action:
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
-
Q-learning is off-policy: it learns the optimal Q-values regardless of the policy being followed, provided that policy keeps trying every action in every state. The agent can explore randomly while still learning the optimal action values. This makes Q-learning more aggressive and often faster to converge. However, it tends to overestimate values, because taking a max over noisy estimates is biased upward.
-
Exploration vs exploitation is the fundamental dilemma: should the agent exploit what it already knows (choose the action with the highest estimated value) or explore unknown actions (which might turn out to be better)?
-
The simplest strategy is epsilon-greedy: with probability $\epsilon$, take a random action (explore); with probability $1 - \epsilon$, take the greedy action (exploit). A common schedule starts with high $\epsilon$ (lots of exploration) and decays it over time.
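A minimal sketch of epsilon-greedy selection with an exponential decay schedule. The decay rate and floor match the values used in the Q-learning coding task later in this section, but they are tuning choices, not fixed constants. The random branch can also pick the greedy action, so the greedy action is chosen with probability $1 - \epsilon + \epsilon / |A|$.

```python
import random

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

def decayed_epsilon(episode, start=1.0, end=0.01, decay=0.995):
    """Exponential decay schedule, floored at `end`."""
    return max(end, start * decay ** episode)

rng = random.Random(0)
q = [0.1, 0.7, 0.3]            # action 1 is greedy
for ep in (0, 100, 500, 1000):
    eps = decayed_epsilon(ep)
    picks = [epsilon_greedy(q, eps, rng) for _ in range(10_000)]
    print(f"episode {ep:4d}: epsilon={eps:.3f}, greedy fraction={picks.count(1) / len(picks):.2f}")
```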
-
Tabular methods (storing a value for each state-action pair in a table) work for small, discrete state spaces. For large or continuous state spaces, you need function approximation. Deep Q-Networks (DQN) use a neural network to approximate $Q(s, a; \theta)$, where $\theta$ are the network weights.
-
DQN combines two critical stabilisation techniques. Experience replay: instead of learning from consecutive transitions (which are highly correlated), store transitions in a replay buffer and sample random mini-batches for training. This breaks correlations and reuses data efficiently.
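A replay buffer can be as simple as a bounded queue with uniform sampling. The sketch below assumes transitions are plain Python tuples. The class and its methods are illustrative and do not follow any particular library's API.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity, seed=0):
        self.buffer = deque(maxlen=capacity)   # the oldest transitions drop out first
        self.rng = random.Random(seed)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample a mini-batch, breaking temporal correlation."""
        batch = self.rng.sample(list(self.buffer), batch_size)
        # Transpose a list of tuples into (states, actions, rewards, next_states, dones)
        return tuple(list(field) for field in zip(*batch))

    def __len__(self):
        return len(self.buffer)

buf = ReplayBuffer(capacity=1000)
for t in range(1500):
    buf.add(state=t, action=t % 4, reward=-1.0, next_state=t + 1, done=False)
states, actions, rewards, next_states, dones = buf.sample(32)
print(len(buf), min(states))   # capacity 1000, so transitions 0..499 were evicted
```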
-
Target network: use a separate, slowly-updated copy of the network to compute TD targets. Without this, the target moves every time you update the network, creating a "chasing your own tail" instability. The target network is updated periodically (hard update every $N$ steps) or continuously (soft update: $\theta^{-} \leftarrow \tau\theta + (1-\tau)\theta^{-}$).
-
The DQN loss is just MSE between predicted Q-values and TD targets:
$$\mathcal{L}(\theta) = \mathbb{E} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta) \right)^2 \right]$$
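The sketch below computes the pieces of a DQN update on a tiny hand-made mini-batch. It uses NumPy arrays in place of network outputs, and the Q-values are invented for illustration. It also adds a detail the loss formula leaves implicit: terminal transitions use the reward alone as the target. In a real implementation the targets are treated as constants, so no gradient flows through $\theta^{-}$.

```python
import numpy as np

def dqn_targets(rewards, next_q_target, dones, gamma):
    """r + gamma * max_a' Q(s', a'; theta^-), with no bootstrapping at terminal states."""
    return rewards + gamma * (1.0 - dones) * next_q_target.max(axis=1)

def dqn_loss(q_pred, actions, targets):
    """MSE between Q(s, a; theta) for the taken actions and the fixed TD targets."""
    q_taken = q_pred[np.arange(len(actions)), actions]
    return np.mean((targets - q_taken) ** 2)

def soft_update(target_params, online_params, tau):
    """theta^- <- tau * theta + (1 - tau) * theta^-, applied to every parameter array."""
    return {k: tau * online_params[k] + (1 - tau) * target_params[k] for k in target_params}

# A toy mini-batch of 3 transitions and 2 actions
rewards = np.array([1.0, 0.0, -1.0])
dones = np.array([0.0, 0.0, 1.0])                                # the third transition ends the episode
next_q_target = np.array([[0.5, 2.0], [1.0, -1.0], [3.0, 3.0]])  # from the target network
q_pred = np.array([[1.0, 0.0], [0.0, 0.5], [0.2, -0.4]])         # from the online network
actions = np.array([0, 1, 1])

targets = dqn_targets(rewards, next_q_target, dones, gamma=0.9)
print(targets)                               # about [2.8, 0.9, -1.0]
print(dqn_loss(q_pred, actions, targets))    # mean of (1.8^2, 0.4^2, 0.6^2)
print(soft_update({"w": np.zeros(2)}, {"w": np.ones(2)}, tau=0.1)["w"])
```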
-
All the methods so far learn value functions and derive policies from them. Policy gradient methods take a different approach: they directly parameterise the policy $\pi(a \mid s; \theta)$ and optimise it by gradient ascent on expected return.
-
The policy gradient theorem gives the gradient of expected return with respect to policy parameters:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi} \left[ \nabla_\theta \log \pi(a \mid s; \theta) \cdot G_t \right]$$
-
This says: increase the probability of actions that led to high returns, decrease the probability of actions that led to low returns. The log-probability gradient gives the direction to change the policy, and $G_t$ scales how much to change it.
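For a softmax policy over logits, the log-probability gradient has a simple closed form, $\nabla \log \pi(a) = e_a - \pi$: the one-hot vector of the chosen action minus the action probabilities. The sketch below checks this against finite differences. The logits are arbitrary.

```python
import math

def softmax(logits):
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_log_softmax(logits, a):
    """Analytic gradient of log pi(a) with respect to the logits: one_hot(a) - pi."""
    probs = softmax(logits)
    return [(1.0 if k == a else 0.0) - p for k, p in enumerate(probs)]

def numeric_grad_log_softmax(logits, a, h=1e-6):
    """Central finite differences of log pi(a), as a sanity check."""
    grads = []
    for k in range(len(logits)):
        up, down = list(logits), list(logits)
        up[k] += h
        down[k] -= h
        grads.append((math.log(softmax(up)[a]) - math.log(softmax(down)[a])) / (2 * h))
    return grads

logits, a = [0.5, -1.0, 2.0], 0
print(grad_log_softmax(logits, a))
print(numeric_grad_log_softmax(logits, a))
```

Stepping the logits along this gradient, scaled by a positive return, raises $\pi(a)$. A negative return lowers it.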
-
REINFORCE is the simplest policy gradient algorithm. Run an episode, compute returns $G_t$ for each step, and update:
$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot G_t$$
REINFORCE has high variance because $G_t$ is a noisy, single-sample estimate of the expected return. A common fix is to subtract a baseline $b$, typically the average return or a learned value function. As long as $b$ does not depend on the action, this reduces variance without introducing bias:
$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot (G_t - b)$$
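The variance claim can be checked exactly on a one-step bandit by enumerating every action instead of sampling. This sketch assumes a uniform softmax policy over three actions with deterministic rewards 1, 2 and 3 (illustrative values), and uses the expected reward as the baseline.

```python
def gradient_estimate_stats(probs, rewards, baseline):
    """Exact mean and total variance of the REINFORCE estimate (e_a - pi) * (R(a) - b)
    for a one-step softmax bandit, computed by enumerating all actions."""
    n = len(probs)
    mean = [0.0] * n
    second_moment = [0.0] * n
    for a, (p_a, r_a) in enumerate(zip(probs, rewards)):
        for k in range(n):
            g_k = ((1.0 if k == a else 0.0) - probs[k]) * (r_a - baseline)
            mean[k] += p_a * g_k
            second_moment[k] += p_a * g_k ** 2
    total_var = sum(m2 - m ** 2 for m, m2 in zip(mean, second_moment))
    return mean, total_var

probs = [1 / 3, 1 / 3, 1 / 3]       # uniform policy
rewards = [1.0, 2.0, 3.0]           # deterministic reward per action
expected_return = sum(p * r for p, r in zip(probs, rewards))

results = {}
for b in (0.0, expected_return):
    mean, var = gradient_estimate_stats(probs, rewards, b)
    results[b] = (mean, var)
    print(f"baseline={b:.1f}  mean gradient={[round(m, 3) for m in mean]}  total variance={var:.3f}")
```

Both settings give the same mean gradient, because the baseline adds no bias. The baseline cuts the total variance from about 2.89 to about 0.22 in this toy case.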
Actor-Critic methods learn two function approximators, often two networks and sometimes two heads on a shared network. The actor is the policy $\pi(a \mid s; \theta)$. The critic is a value function $V(s; \phi)$ that serves as the baseline. The one-step TD error $A_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is used as an estimate of the advantage and replaces $G_t - b$:
$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot A_t$$
The critic is updated by minimising TD error, just like value-based methods. The actor is updated using the policy gradient, with the critic's advantage estimate reducing variance. This combines direct policy optimisation with the lower-variance learning signal of value-based methods.
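A minimal one-step tabular actor-critic sketch on a toy corridor. The environment is assumed here for illustration: the agent starts at the left end and receives reward -1 per step until it reaches the right end. The TD error is used directly as the advantage estimate. The critic update reduces that error, and the actor takes a policy-gradient step scaled by it.

```python
import math
import random

def run_actor_critic(n_states=5, episodes=1000, alpha_v=0.1, alpha_pi=0.1, gamma=0.99, seed=0):
    """One-step tabular actor-critic on a corridor: start at 0, terminal goal at n_states - 1.

    Actions: 0 = left (a wall at state 0), 1 = right. Reward is -1 per step.
    """
    rng = random.Random(seed)
    theta = [[0.0, 0.0] for _ in range(n_states)]   # actor: softmax logits per state
    V = [0.0] * n_states                           # critic: state-value estimates
    goal = n_states - 1

    def policy(s):
        m = max(theta[s])
        e = [math.exp(z - m) for z in theta[s]]
        total = sum(e)
        return [x / total for x in e]

    for _ in range(episodes):
        s = 0
        for _ in range(200):
            probs = policy(s)
            a = 0 if rng.random() < probs[0] else 1
            s_next = max(0, s - 1) if a == 0 else s + 1
            r, done = -1.0, s_next == goal
            # The TD error doubles as the one-step advantage estimate A_t
            advantage = r + (0.0 if done else gamma * V[s_next]) - V[s]
            V[s] += alpha_v * advantage                     # critic: reduce the TD error
            for k in range(2):                              # actor: ascend A_t * grad log pi(a|s)
                theta[s][k] += alpha_pi * advantage * ((1.0 if k == a else 0.0) - probs[k])
            if done:
                break
            s = s_next
    return [policy(s)[1] for s in range(goal)], V

p_right, values = run_actor_critic()
print("P(right) per state:", [round(p, 2) for p in p_right])
print("V per state:", [round(v, 2) for v in values])
```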
-
PPO (Proximal Policy Optimization) is one of the most widely used policy gradient algorithms in practice. It addresses a key problem: if a single policy update is too large, performance can collapse catastrophically.
-
PPO uses a clipped surrogate objective. Let $r_t(\theta) = \frac{\pi(a_t \mid s_t; \theta)}{\pi(a_t \mid s_t; \theta_{\text{old}})}$ be the probability ratio between the new and old policies. The objective, which is maximised (in code you minimise its negative), is:
$$\mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min\!\left( r_t(\theta) A_t, \; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$
-
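The clipping (typically $\epsilon = 0.2$) removes any incentive to move the ratio far from 1, which keeps updates small and stable. If the advantage is positive (the action was good), the objective stops increasing once the ratio exceeds $1 + \epsilon$. If the advantage is negative (the action was bad), the objective stops improving once the ratio falls below $1 - \epsilon$. Taking the $\min$ makes the objective a pessimistic bound, so changes that make things worse are never clipped away. This achieves an effect similar to the earlier trust-region method TRPO with a much simpler implementation.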
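The per-sample objective is small enough to write out directly. This sketch evaluates it for a few hand-picked ratios and advantages:

```python
def ppo_clip_objective(ratio, advantage, eps=0.2):
    """Per-sample PPO-Clip surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A).
    This is maximised; implementations usually minimise its negative."""
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped_ratio * advantage)

for ratio, adv in [(1.5, 1.0), (0.5, 1.0), (0.5, -1.0), (1.5, -1.0)]:
    print(f"ratio={ratio}, A={adv:+.0f} -> objective {ppo_clip_objective(ratio, adv):+.2f}")
```

With $A > 0$ the objective is flat above $1 + \epsilon$, so a ratio of 1.5 is treated like 1.2. With $A < 0$ it is flat below $1 - \epsilon$. When the ratio has moved in the harmful direction (0.5 with $A > 0$, or 1.5 with $A < 0$), the unclipped term is the smaller one, so the full penalty and its gradient remain.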
-
PPO was the RL algorithm in the RLHF (Reinforcement Learning from Human Feedback) pipeline used to train InstructGPT and ChatGPT-style models. In RLHF, a reward model is trained on human preference data (which of two outputs do humans prefer?). PPO then optimises the language model's policy to maximise this learned reward, usually with a KL penalty that keeps the policy close to the original model.
-
DPO (Direct Preference Optimization) simplifies RLHF by eliminating the separately trained reward model and the RL loop. Instead of training a reward model and then running RL, DPO derives a closed-form loss that directly optimises the policy from preference data:
$$\mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]$$
-
Here $x$ is the prompt, $y_w$ is the preferred (winning) response and $y_l$ is the dispreferred (losing) response. $\pi_{\text{ref}}$ is a frozen reference model, usually the supervised fine-tuned starting point, and $\sigma$ is the logistic sigmoid. $\beta$ controls how far $\pi_\theta$ may drift from $\pi_{\text{ref}}$. DPO increases the probability of preferred outputs relative to dispreferred ones and is much simpler to implement than PPO-based RLHF.
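A sketch of the DPO loss for a single preference pair. It assumes you already have sequence log-probabilities (the sum of the token log-probabilities of each response) under both the policy and the reference model. The numbers and $\beta = 0.1$ are illustrative.

```python
import math

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one preference pair, from sequence log-probabilities
    log pi(y | x) under the policy and under the frozen reference model."""
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    # -log(sigmoid(m)) = log(1 + exp(-m)), computed stably for large |m|
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

print(dpo_loss(-12.0, -14.0, -12.0, -14.0))   # policy == reference: log 2, about 0.693
print(dpo_loss(-10.0, -15.0, -12.0, -14.0))   # winner up, loser down: lower loss
print(dpo_loss(-14.0, -12.0, -12.0, -14.0))   # preference reversed: higher loss
```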
-
Two important distinctions in RL algorithms. On-policy vs off-policy: on-policy methods (SARSA, PPO) learn from data generated by the current policy; off-policy methods (Q-learning, DQN) can learn from data generated by any policy. Off-policy methods can be more sample-efficient (they can reuse old data) but are often less stable.
-
Model-based vs model-free: model-free methods (every method above from TD learning onward) learn values or policies directly from experience. Model-based methods learn a model of the environment ($P(s' \mid s, a)$ and $R(s, a)$) and use it for planning, imagining future trajectories without actually taking actions. Model-based methods are often more sample-efficient, but they add the complexity of learning an accurate model, and errors in that model can compound over long imagined trajectories.
-
To summarise the RL landscape:
| Method | Type | Key Idea | Strength |
|---|---|---|---|
| Value Iteration | DP, model-based | Bellman optimality | Exact solution (small MDPs) |
| SARSA | TD, on-policy | Learn Q on-policy | Conservative, safe |
| Q-Learning | TD, off-policy | Learn Q*, greedy target | Simple, effective |
| DQN | Deep, off-policy | Neural Q + replay + target net | Scales to high-dim states |
| REINFORCE | Policy gradient | Gradient of log-prob * return | Simple policy optimisation |
| Actor-Critic | PG + value | Actor + critic for low variance | Practical and flexible |
| PPO | PG, clipped | Trust-region-like stability | Industry standard |
| DPO | Direct preference | Skip reward model | Simpler RLHF |
Coding Tasks (use Colab or notebook)
- Implement value iteration for a simple gridworld. Compute the optimal value function and extract the optimal policy. Visualise both as a heatmap and arrow plot.
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 4x4 gridworld: goal at (3,3), reward -1 per step, goal is terminal with value 0
grid_size = 4
gamma = 0.99
goal = (3, 3)

# Actions: up, down, left, right
actions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
action_names = ['up', 'down', 'left', 'right']
action_arrows = ['\u2191', '\u2193', '\u2190', '\u2192']

def step(s, a):
    """Deterministic transition; moves off the grid leave the agent in place."""
    ns = (max(0, min(grid_size - 1, s[0] + a[0])),
          max(0, min(grid_size - 1, s[1] + a[1])))
    return ns

def q_value(V, s, a):
    """One-step lookahead: reward -1 plus the discounted value of the next state."""
    ns = step(s, a)
    return -1 + gamma * float(V[ns[0], ns[1]])

# Value iteration: V(s) <- max_a [r + gamma * V(s')]
V = jnp.zeros((grid_size, grid_size))
for iteration in range(100):
    V_new = jnp.array(V)
    for i in range(grid_size):
        for j in range(grid_size):
            if (i, j) == goal:
                continue  # the terminal state keeps value 0
            V_new = V_new.at[i, j].set(max(q_value(V, (i, j), a) for a in actions))
    delta = float(jnp.max(jnp.abs(V_new - V)))
    V = V_new
    if delta < 1e-6:
        print(f"Converged in {iteration + 1} iterations")
        break

# Extract the greedy policy from the converged value function
policy = [['' for _ in range(grid_size)] for _ in range(grid_size)]
for i in range(grid_size):
    for j in range(grid_size):
        if (i, j) == goal:
            policy[i][j] = 'G'
            continue
        best_a = max(range(len(actions)), key=lambda k: q_value(V, (i, j), actions[k]))
        policy[i][j] = action_arrows[best_a]

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
im = axes[0].imshow(V, cmap='YlOrRd_r')
axes[0].set_title("Optimal Value Function")
for i in range(grid_size):
    for j in range(grid_size):
        axes[0].text(j, i, f"{float(V[i, j]):.1f}", ha='center', va='center', fontsize=10)
plt.colorbar(im, ax=axes[0])
axes[1].imshow(jnp.ones((grid_size, grid_size)), cmap='Greys', vmin=0, vmax=2)
axes[1].set_title("Optimal Policy")
for i in range(grid_size):
    for j in range(grid_size):
        axes[1].text(j, i, policy[i][j], ha='center', va='center', fontsize=18)
plt.tight_layout(); plt.show()
```
- Implement tabular Q-learning on a simple gridworld. Train the agent, plot the learning curve, and show the learned Q-values.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

grid_size = 5
goal = (4, 4)
actions = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right

# Q-table: 4 action values per state
Q = {}
for i in range(grid_size):
    for j in range(grid_size):
        Q[(i, j)] = [0.0] * 4

alpha = 0.1            # learning rate
gamma = 0.95           # discount factor
epsilon = 1.0          # initial exploration rate
epsilon_decay = 0.995  # multiplicative decay per episode
min_epsilon = 0.01

def step(s, a_idx):
    """Deterministic move; reward -1 per step, 0 on reaching the terminal goal."""
    a = actions[a_idx]
    ns = (max(0, min(grid_size - 1, s[0] + a[0])),
          max(0, min(grid_size - 1, s[1] + a[1])))
    r = 0.0 if ns == goal else -1.0
    done = ns == goal
    return ns, r, done

key = jax.random.PRNGKey(42)
rewards_per_episode = []
for ep in range(500):
    s = (0, 0)
    total_reward = 0.0
    for _ in range(100):
        # Epsilon-greedy action selection
        key, subkey = jax.random.split(key)
        if float(jax.random.uniform(subkey)) < epsilon:
            key, subkey = jax.random.split(key)
            a = int(jax.random.randint(subkey, (), 0, 4))
        else:
            a = max(range(4), key=lambda i: Q[s][i])
        ns, r, done = step(s, a)
        total_reward += r
        # Q-learning update (no bootstrapping from the terminal goal state)
        target = r if done else r + gamma * max(Q[ns])
        Q[s][a] += alpha * (target - Q[s][a])
        s = ns
        if done:
            break
    rewards_per_episode.append(total_reward)
    epsilon = max(min_epsilon, epsilon * epsilon_decay)

# Smooth the curve with a trailing moving average over the last `window` episodes
window = 20
smoothed = [sum(rewards_per_episode[max(0, i - window + 1):i + 1]) / min(i + 1, window)
            for i in range(len(rewards_per_episode))]
plt.figure(figsize=(8, 4))
plt.plot(smoothed, color='#3498db', linewidth=1.5)
plt.xlabel("Episode"); plt.ylabel("Total Reward (smoothed)")
plt.title("Q-Learning on Gridworld")
plt.grid(alpha=0.3); plt.show()

# Show the learned (greedy) policy
arrow = ['\u2191', '\u2193', '\u2190', '\u2192']
print("Learned policy:")
for i in range(grid_size):
    row = ""
    for j in range(grid_size):
        if (i, j) == goal:
            row += " G "
        else:
            row += f" {arrow[max(range(4), key=lambda a: Q[(i, j)][a])]} "
    print(row)
```
- Implement REINFORCE on a multi-armed bandit problem. Show how the policy evolves over training to favour the best arm.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 5-armed bandit: arm i pays 1 with probability true_rewards[i], else 0
true_rewards = jnp.array([0.2, 0.5, 0.8, 0.3, 0.1])
n_arms = len(true_rewards)

# Policy: softmax over logits
logits = jnp.zeros(n_arms)
lr = 0.1
key = jax.random.PRNGKey(42)
policy_history = []
reward_history = []
for t in range(2000):
    probs = jax.nn.softmax(logits)
    policy_history.append(probs)
    # Sample an action from the current policy
    key, subkey = jax.random.split(key)
    action = jax.random.choice(subkey, n_arms, p=probs)
    # Get reward (Bernoulli)
    key, subkey = jax.random.split(key)
    reward = float(jax.random.uniform(subkey) < true_rewards[action])
    reward_history.append(reward)
    # REINFORCE update; for a softmax policy, grad log pi(a) = one_hot(a) - probs
    grad_log_pi = jax.nn.one_hot(action, n_arms) - probs
    logits = logits + lr * reward * grad_log_pi
policy_history = jnp.stack(policy_history)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12']
for i in range(n_arms):
    axes[0].plot(policy_history[:, i], color=colors[i],
                 label=f'Arm {i} (true={float(true_rewards[i]):.1f})', linewidth=1.5)
axes[0].set_xlabel("Step"); axes[0].set_ylabel("P(arm)")
axes[0].set_title("Policy Evolution (REINFORCE)")
axes[0].legend(fontsize=8); axes[0].grid(alpha=0.3)

# Smoothed reward: trailing moving average over the last `window` steps
window = 50
smoothed = [sum(reward_history[max(0, i - window + 1):i + 1]) / min(i + 1, window)
            for i in range(len(reward_history))]
axes[1].plot(smoothed, color='#27ae60', linewidth=1.5)
axes[1].axhline(y=float(true_rewards.max()), color='#e74c3c', linestyle='--', alpha=0.5, label='Best arm')
axes[1].set_xlabel("Step"); axes[1].set_ylabel("Avg Reward")
axes[1].set_title("Reward Over Time"); axes[1].legend()
axes[1].grid(alpha=0.3)
plt.tight_layout(); plt.show()
```
Distributed Deep Learning
-
Training a large neural network on a single GPU eventually hits a wall. The model might not fit in memory, or training might take months. Distributed training spreads the work across multiple devices (GPUs, TPUs, or entire machines) to train faster and train bigger models. This file covers the techniques that make that possible.
-
To understand why distribution matters, start with the computational cost of training. A single forward pass through a dense layer with $d_{\text{in}}$ inputs and $d_{\text{out}}$ outputs on a batch of $B$ examples requires roughly $2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ FLOPs (floating-point operations). Each of the $B \cdot d_{\text{out}}$ output elements is a sum of $d_{\text{in}}$ products, costing one multiply and one add per term. The backward pass costs roughly twice the forward pass (computing gradients with respect to both the inputs and the weights), so one training step on a dense layer is about $6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ FLOPs.
-
For a transformer layer with hidden dimension $d$ and sequence length $n$, the self-attention block has four projections (Q, K, V, and output). Each is a $(B \cdot n) \times d$ by $d \times d$ matrix multiply costing about $2 B n d^2$ FLOPs. On top of that, computing the attention scores $QK^\top$ and the weighted sum over $V$ costs about $4 B n^2 d$ FLOPs together. The feed-forward block has two dense layers, typically expanding to $4d$ and back, and each costs about $8 B n d^2$ FLOPs. The forward pass therefore costs roughly $24 B n d^2 + 4 B n^2 d$ FLOPs per layer, and a training step about three times that. Multiply by the number of layers and the number of training steps, and you see why training large models takes anywhere from thousands to millions of GPU-hours.
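Summing the $6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ rule over every weight matrix gives the common estimate of about 6 FLOPs per parameter per training token, ignoring the attention $n^2$ term. Below is a back-of-the-envelope sketch. The model size, token count, A100 peak (312 TFLOP/s dense BF16) and 40% utilisation are assumptions chosen for illustration.

```python
def training_flops(n_params, n_tokens):
    """About 6 FLOPs per parameter per token: 2 for the forward pass, about 4 for the backward."""
    return 6 * n_params * n_tokens

def gpu_hours(total_flops, peak_flops_per_s, utilization):
    """Wall-clock GPU-hours when running at a given fraction of peak throughput."""
    return total_flops / (peak_flops_per_s * utilization) / 3600

flops = training_flops(7e9, 1e12)                      # a 7B-parameter model on 1T tokens
print(f"{flops:.2e} FLOPs")                            # 4.20e+22
# Assumed: 312 TFLOP/s dense BF16 peak (A100), 40% utilisation
print(f"{gpu_hours(flops, 312e12, 0.4):,.0f} GPU-hours")  # about 93,000
```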
-
The memory wall is often the tighter constraint. During training, GPU memory must hold four things simultaneously:
-
Parameters: the model weights. A 7-billion parameter model in FP32 (4 bytes per parameter) needs 28 GB just for weights.
-
Gradients: same size as the parameters. Another 28 GB.
-
Optimizer states: Adam maintains two additional buffers (first and second moment estimates), each the size of the parameters. These are kept in FP32 for numerical stability, even when the model uses lower precision. For our 7B model, that is $2 \times 28 = 56$ GB.
-
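Activations: intermediate values saved during the forward pass for use in the backward pass. Their size depends on batch size, sequence length, and model width. It grows linearly with batch size and, when attention matrices are stored naively, quadratically with sequence length. For large batches or long sequences this is often the largest component.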
-
For our 7B model with FP32 Adam: 28 (params) + 28 (grads) + 56 (optimizer) = 112 GB, before we even count activations. A single 80 GB A100 GPU cannot hold this. This is why distributed strategies are essential.
-
Mixed precision training is the first line of defence. Instead of storing everything in FP32 (32-bit floating point), you train using FP16 or BF16 (16-bit) for the forward and backward passes, while keeping a master copy of weights in FP32 for the optimizer update.
-
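FP16 has more precision than BF16 (10-bit mantissa) but a limited range (5-bit exponent). Large values can overflow, and, more commonly, small gradients underflow to zero. Loss scaling mitigates this: you multiply the loss by a large factor before the backward pass, then divide the gradients by the same factor before the optimizer step, which shifts gradients into FP16's representable range.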
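A quick NumPy sketch of why loss scaling helps. The gradient value and the scale factor $2^{16}$ are illustrative. Real mixed-precision tooling typically adjusts the scale dynamically during training.

```python
import numpy as np

grad = 1e-8                                  # a small but meaningful gradient value
print(np.float16(grad))                      # 0.0: below FP16's smallest subnormal (~6e-8)

scale = 2.0 ** 16                            # loss scale (illustrative value)
scaled = np.float16(grad * scale)            # ~6.55e-4 is representable in FP16
recovered = np.float32(scaled) / scale       # unscale in FP32 before the optimizer step
print(scaled, recovered)
```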
-
BF16 (brain float) has the same exponent range as FP32 (8-bit exponent) but less precision (7-bit mantissa). It almost never overflows and rarely needs loss scaling, making it simpler to use. BF16 is now the usual default for transformer training on hardware that supports it.
-
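Mixed precision roughly halves the memory for activations, and for gradients if they are stored in 16 bits. It does not shrink weight and optimizer memory. With 16-bit weights and gradients, an FP32 master copy, and FP32 Adam states, a common layout still needs about 16 bytes per parameter, the same as pure FP32. The other big win is speed, because modern GPUs run 16-bit matrix multiplies much faster than FP32 ones.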
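The byte accounting, as a sketch. The per-parameter layout below (16-bit weights and gradients, an FP32 master copy, two FP32 Adam buffers) is one common configuration assumed for illustration; frameworks differ in the details.

```python
def static_training_memory_gb(n_params, bytes_per_param):
    """Memory for weights, gradients and optimizer state, excluding activations."""
    return n_params * sum(bytes_per_param.values()) / 1e9

fp32_adam = {"fp32 weights": 4, "fp32 grads": 4, "adam m": 4, "adam v": 4}
mixed_adam = {"bf16 weights": 2, "bf16 grads": 2, "fp32 master weights": 4,
              "adam m": 4, "adam v": 4}

for name, layout in [("FP32 + Adam", fp32_adam), ("mixed precision + Adam", mixed_adam)]:
    print(f"{name}: {sum(layout.values())} bytes/param, "
          f"{static_training_memory_gb(7e9, layout):.0f} GB for 7B params")
```

Both layouts come to 112 GB for the 7B example. The savings from mixed precision show up in activations and speed. Shrinking this static memory requires splitting it across GPUs, which is what the parallelism strategies below do.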
-
Data parallelism is the simplest distributed strategy. You replicate the entire model on $N$ GPUs, split each mini-batch into $N$ equal chunks, and send one chunk to each GPU. Each GPU runs the forward and backward pass on its chunk independently. Then the gradients are averaged across all GPUs (using an all-reduce operation), and each GPU updates its local copy of the model.
-
From the model's perspective, this is equivalent to training with a mini-batch that is $N$ times larger. If each GPU processes a batch of size $B$, the effective batch size is $N \cdot B$.
-
The gradient averaging can be done synchronously or asynchronously. Synchronous SGD waits for all GPUs to finish before averaging, ensuring mathematical equivalence to single-GPU training with a larger batch. The downside is that the slowest GPU (the "straggler") holds everyone up.
-
Asynchronous SGD lets each GPU update a shared parameter server independently, without waiting. This eliminates the straggler problem but introduces "stale gradients": a GPU might compute gradients based on slightly outdated parameters. Stale gradients add noise and can slow convergence. In practice, synchronous SGD with efficient communication is preferred.
-
Gradient accumulation is a software trick for simulating larger batch sizes on limited hardware. Instead of doing one update per mini-batch, you run several forward/backward passes and accumulate the gradients, then do one update. This gives the same result as a larger batch without needing more GPU memory for activations (only one mini-batch of activations is in memory at a time).
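A NumPy sketch showing that accumulating gradients over micro-batches reproduces the full-batch gradient for a linear model with a mean-squared-error loss. The data are random and the sizes are arbitrary. Each micro-batch gradient is divided by the number of accumulation steps so that the sum equals the full-batch average. The match is exact only because the micro-batches are equal-sized and the model has no batch-dependent layers such as batch norm.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=64)
w = np.zeros(3)

def grad_mse(w, X, y):
    """Gradient of the mean squared error for a linear model y ~ X @ w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

full_batch_grad = grad_mse(w, X, y)

# Accumulate over 4 micro-batches of 16; only one micro-batch is "in memory" at a time
accum_steps = 4
accumulated = np.zeros_like(w)
for X_mb, y_mb in zip(np.split(X, accum_steps), np.split(y, accum_steps)):
    accumulated += grad_mse(w, X_mb, y_mb) / accum_steps
w_new = w - 0.1 * accumulated            # a single optimizer step after all micro-batches

print(np.allclose(accumulated, full_batch_grad))   # True
```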
-
When the model itself is too large to fit on a single GPU, you need model parallelism. There are two main flavours.
-
Tensor parallelism splits individual layers across GPUs. A large matrix multiply $Y = XW$ can be split column-wise: partition $W$ into $[W_1, W_2]$ across two GPUs, compute $Y_1 = XW_1$ and $Y_2 = XW_2$ in parallel, then concatenate. This works for attention projections and feed-forward layers. It requires fast communication between GPUs (typically NVLink within a node) because partial results must be combined at every layer.
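A NumPy sketch of the split for a two-layer feed-forward block, with the two "GPUs" simulated as array halves. Besides the column split described above, it shows a row split of the second matrix, the pairing used by Megatron-LM. Each GPU applies the nonlinearity to its own columns, so only the final partial outputs need to be summed (an all-reduce). The shapes and random weights are arbitrary.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
B, d, d_ff = 4, 8, 32
X = rng.normal(size=(B, d))
W1 = rng.normal(size=(d, d_ff))     # first FFN matrix
W2 = rng.normal(size=(d_ff, d))     # second FFN matrix
reference = relu(X @ W1) @ W2       # single-device result

# Column-parallel W1: each GPU holds half of the columns
W1_a, W1_b = np.split(W1, 2, axis=1)
H_a, H_b = relu(X @ W1_a), relu(X @ W1_b)   # ReLU is elementwise, so it runs per shard
print(np.allclose(np.concatenate([H_a, H_b], axis=1), relu(X @ W1)))

# Row-parallel W2: each GPU holds the matching rows; partial outputs are summed
W2_a, W2_b = np.split(W2, 2, axis=0)
Y = H_a @ W2_a + H_b @ W2_b                  # the sum is what the all-reduce computes
print(np.allclose(Y, reference))
```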
-
Pipeline parallelism assigns different layers to different GPUs. GPU 0 runs layers 1-4, GPU 1 runs layers 5-8, and so on. Data flows through the pipeline like an assembly line. The naive approach has a "pipeline bubble": GPU 1 cannot start until GPU 0 has finished the forward pass for the whole batch, so only one GPU works at a time while the others sit idle. Micro-batching mitigates this by splitting the mini-batch into smaller micro-batches that flow through the pipeline one after another. As soon as GPU 0 hands micro-batch 1 to GPU 1, it starts on micro-batch 2, so most GPUs are busy most of the time.
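A small simulation of the bubble for a GPipe-style forward schedule, assuming every stage takes the same time per micro-batch. Stage $s$ processes micro-batch $k$ at time step $s + k$. With $p$ stages and $m$ micro-batches this gives an idle fraction of $(p - 1)/(m + p - 1)$. Under the same assumption the backward pass has the same fraction.

```python
def pipeline_bubble_fraction(n_stages, n_micro):
    """Idle fraction of a forward schedule where stage s handles micro-batch k
    at time step s + k, and every stage takes one time step per micro-batch."""
    total_slots = n_stages * (n_micro + n_stages - 1)   # stages x time steps
    busy = {(s, s + k) for s in range(n_stages) for k in range(n_micro)}
    return 1 - len(busy) / total_slots

p = 4
for m in (1, 4, 16, 64):
    sim = pipeline_bubble_fraction(p, m)
    formula = (p - 1) / (m + p - 1)
    print(f"{p} stages, {m:2d} micro-batches: bubble {sim:.3f} (formula {formula:.3f})")
```

With a single micro-batch (the naive case), three quarters of the GPU time is idle. With 64 micro-batches, the bubble drops below 5%.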
-
Hybrid parallelism combines data, tensor, and pipeline parallelism. A typical large-model setup might use tensor parallelism within a node (8 GPUs connected by fast NVLink), pipeline parallelism across nodes, and data parallelism across groups of nodes. Training reports for large models such as Llama 3 describe combining these strategies in this way.
-
The efficiency of distributed training depends heavily on communication. The key operation is all-reduce: given a value on each of $N$ GPUs, compute the sum (or average) and distribute the result to all GPUs.
-
A naive all-reduce sends all data to one GPU, sums it, and broadcasts the result back. The root must receive and send $N - 1$ full copies of the data, so its traffic grows as $O(N)$ and it becomes a bottleneck.
-
Ring all-reduce is much more efficient. Arrange the $N$ GPUs in a ring, and have each GPU split its data into $N$ chunks. In the first $N - 1$ steps (the reduce-scatter phase), each GPU sends one chunk to its neighbour and receives a chunk from its other neighbour, adding it to its own copy. After this phase, each GPU holds the complete sum for one chunk. In another $N - 1$ steps (the all-gather phase), these completed chunks are passed around the ring until every GPU has all of them. Total data sent per GPU is $2(N-1)/N$ times the data size, which approaches $2\times$ as $N$ grows. Crucially, this per-GPU traffic does not grow with $N$, making the algorithm bandwidth-optimal.
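A pure-Python simulation of the two phases. Each "GPU" is a list, and messages are exchanged in lockstep. It checks that every GPU ends up with the element-wise sum, and that each GPU sends exactly $2(N-1)/N$ times the data size. The data values are arbitrary.

```python
def ring_all_reduce(data):
    """Simulate ring all-reduce. data[i] is GPU i's list; its length must be divisible by N.
    Returns each GPU's final buffer and the number of elements each GPU sent."""
    n = len(data)
    chunk = len(data[0]) // n
    bufs = [list(d) for d in data]
    sent = [0] * n

    def part(c):
        return slice(c * chunk, (c + 1) * chunk)

    # Reduce-scatter: at step t, GPU i sends chunk (i - t) % n to GPU i + 1, which adds it.
    # Afterwards GPU i holds the full sum of chunk (i + 1) % n.
    for t in range(n - 1):
        msgs = [((i - t) % n, bufs[i][part((i - t) % n)]) for i in range(n)]  # slices copy
        for i in range(n):
            c, payload = msgs[(i - 1) % n]               # message from the left neighbour
            bufs[i][part(c)] = [a + b for a, b in zip(bufs[i][part(c)], payload)]
            sent[(i - 1) % n] += chunk
    # All-gather: pass the completed chunks around the ring, overwriting partial sums
    for t in range(n - 1):
        msgs = [((i + 1 - t) % n, bufs[i][part((i + 1 - t) % n)]) for i in range(n)]
        for i in range(n):
            c, payload = msgs[(i - 1) % n]
            bufs[i][part(c)] = payload
            sent[(i - 1) % n] += chunk
    return bufs, sent

n_gpus, size = 4, 8
data = [[10 * g + k for k in range(size)] for g in range(n_gpus)]
expected = [sum(col) for col in zip(*data)]
bufs, sent = ring_all_reduce(data)
print(all(b == expected for b in bufs))           # every GPU holds the full sum
print(sent, 2 * (n_gpus - 1) * size // n_gpus)    # per-GPU traffic vs 2(N-1)/N * size
```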
-
Parameter servers are an alternative architecture where dedicated server nodes hold the model parameters. Workers compute gradients and send them to the server, which updates parameters and sends them back. This is simpler but can create communication bottlenecks at the server.
-
NCCL (NVIDIA Collective Communications Library) is the standard library for GPU-to-GPU communication. It provides optimised implementations of all-reduce, all-gather, broadcast, and other collective operations, automatically choosing the best algorithm for the network topology.
-
Scaling laws describe how model performance improves with compute, data, and model size. The original Kaplan et al. (2020) scaling laws found that loss decreases as a power law with each:
$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$
-
where $N$ is the number of parameters, $D$ is the dataset size, and $C$ is the compute budget.
-
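The Chinchilla scaling laws (Hoffmann et al., 2022) showed that most large models of the time were undertrained: for a given compute budget, you should train a smaller model on more data than previously thought. The compute-optimal ratio is roughly 20 tokens per parameter, so a 7B model should see about 140B tokens. By contrast, GPT-3 (175B parameters) was trained on about 300B tokens, far fewer than the roughly 3.5T tokens this rule suggests. This finding shifted the field toward "compute-optimal" training. Later models such as Llama deliberately train well beyond this ratio (Llama 1 7B saw 1T tokens), spending extra training compute to get a smaller model that is cheaper to run.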
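Combining the 20-tokens-per-parameter rule with the training-compute estimate $C \approx 6ND$ gives $C \approx 120 N^2$, so a FLOP budget can be split directly. Below is a sketch under those two approximations, using Chinchilla's own size (70B parameters, 1.4T tokens) as a sanity check.

```python
import math

def compute_optimal(compute_flops, tokens_per_param=20.0):
    """Split a FLOP budget using C ~ 6 * N * D and D ~ 20 * N, so C ~ 120 * N^2."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    return n_params, tokens_per_param * n_params

C = 6 * 70e9 * 1.4e12            # roughly Chinchilla's budget
N, D = compute_optimal(C)
print(f"N ~ {N / 1e9:.0f}B params, D ~ {D / 1e12:.1f}T tokens")   # recovers 70B and 1.4T
print(f"7B model: ~{20 * 7e9 / 1e9:.0f}B tokens")                  # 140B
```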
-
Mixture of Experts (MoE) is an architecture that scales model capacity without proportionally scaling compute. Instead of one feed-forward network per transformer layer, you have $N$ "expert" networks (each a standard FFN). A gating network (router) examines each token and sends it to the top-$K$ experts (typically $K = 1$ or $K = 2$).
-
The total parameter count is much larger (because you have $N$ experts), but the FLOPs per token stay roughly constant (because only $K$ experts activate per token). For example, Mixtral 8x7B has about 47B total parameters but uses only about 13B per token, because each token is routed to 2 of its 8 experts. Its authors report quality comparable to much larger dense models at the inference cost of a much smaller one.
-
MoE introduces challenges. Load balancing: if the router sends most tokens to the same expert, the others are wasted. An auxiliary loss encourages uniform routing. Communication: different experts may live on different GPUs, so routing tokens requires all-to-all communication, which is expensive.
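One widely used form of the auxiliary loss comes from the Switch Transformer: $\alpha \cdot N \sum_i f_i P_i$. Here $f_i$ is the fraction of tokens whose top-1 expert is $i$, and $P_i$ is the router's mean probability for expert $i$. The NumPy sketch below uses one-hot router outputs, an idealisation, to compare perfectly balanced and fully collapsed routing.

```python
import numpy as np

def load_balancing_loss(router_probs, alpha=1.0):
    """Switch-Transformer-style auxiliary loss: alpha * N * sum_i f_i * P_i.

    router_probs: (tokens, n_experts) softmax outputs of the router.
    f_i: fraction of tokens whose top-1 expert is i; P_i: mean router probability of i.
    """
    n_tokens, n_experts = router_probs.shape
    top1 = router_probs.argmax(axis=1)
    f = np.bincount(top1, minlength=n_experts) / n_tokens
    P = router_probs.mean(axis=0)
    return alpha * n_experts * np.sum(f * P)

n_experts = 4
balanced = np.eye(n_experts)[np.arange(8) % n_experts]   # each expert gets 2 of 8 tokens
collapsed = np.tile(np.eye(n_experts)[0], (8, 1))        # every token goes to expert 0
print(load_balancing_loss(balanced), load_balancing_loss(collapsed))   # 1.0 4.0
```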
-
Fault tolerance is critical when training runs last weeks or months on thousands of GPUs. If a single GPU fails, you do not want to lose all progress. Checkpointing periodically saves model weights, optimizer states, and the training state (learning rate, step count, data position) to disk. If a failure occurs, you restart from the last checkpoint.
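A minimal sketch of saving and restoring training state with the standard library. It writes to a temporary file and renames it into place, so a crash mid-write cannot corrupt the last good checkpoint. The state dictionary is a stand-in: real runs save framework tensors, often sharded across many files. A production version would also flush the file to disk before renaming.

```python
import os
import pickle
import tempfile

def save_checkpoint(path, state):
    """Write to a temporary file, then atomically rename it over the old checkpoint."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)      # atomic on the same filesystem

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)

state = {
    "params": {"w": [0.1, -0.2]},                  # stand-in for model weights
    "opt_state": {"m": [0.0, 0.0], "v": [0.0, 0.0]},
    "step": 1200, "lr": 3e-4, "data_position": 76800,
}
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
save_checkpoint(ckpt_path, state)
restored = load_checkpoint(ckpt_path)
print(restored["step"], restored == state)         # resume from step 1200
```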
-
Gradient checkpointing (also called activation recomputation) is a memory optimisation, not a fault-tolerance mechanism. During the forward pass, instead of saving all activations for the backward pass, you save activations only at certain checkpoints. During the backward pass, you recompute the missing activations from those checkpoints. This trades compute for memory. The recomputation adds roughly one extra forward pass per step, which is about 33% more total compute because the backward pass costs about twice the forward pass. In exchange, placing checkpoints every $\sqrt{L}$ layers reduces activation memory from $O(L)$ to $O(\sqrt{L})$, where $L$ is the number of layers.
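A back-of-the-envelope sketch of the tradeoff. Memory is measured in units of one layer's saved activations, and the backward pass is assumed to cost about twice the forward pass. With a checkpoint every $k$ layers, the peak is roughly $L/k$ stored checkpoints plus the $k$ layers of one segment being recomputed. That total is smallest near $k = \sqrt{L}$.

```python
import math

def checkpointing_cost(n_layers, segment):
    """Peak activation memory (in layer-activation units) and compute relative to no
    checkpointing, when only every `segment`-th layer's input is kept."""
    if segment <= 1:                               # keep everything, recompute nothing
        return n_layers, 1.0
    stored_checkpoints = math.ceil(n_layers / segment)
    peak_memory = stored_checkpoints + segment     # checkpoints + one recomputed segment
    # forward (1) + backward (~2) + one extra forward (1), relative to forward + backward
    return peak_memory, (1 + 2 + 1) / (1 + 2)

L = 64
for seg in (1, 4, int(math.sqrt(L)), 16):
    mem, compute = checkpointing_cost(L, seg)
    print(f"segment={seg:2d}: ~{mem} layer-activations stored, {compute:.2f}x compute")
```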
-
Putting it all together, training a frontier model combines all of these techniques: BF16 mixed precision, data parallelism across thousands of GPUs with ring all-reduce, tensor parallelism within nodes, pipeline parallelism across nodes, gradient checkpointing to reduce memory, MoE for parameter efficiency, and regular checkpointing for fault tolerance. The systems engineering is as challenging as the algorithm design.
-
To summarise the distributed training toolkit:
| Technique | What It Does | Tradeoff |
|---|---|---|
| Mixed precision (BF16) | Halves memory for activations/grads | Slight numerical differences |
| Data parallelism | Scales batch size across GPUs | Communication overhead for gradient sync |
| Tensor parallelism | Splits layers across GPUs | Requires fast interconnect |
| Pipeline parallelism | Splits model stages across GPUs | Pipeline bubble (wasted compute) |
| Gradient accumulation | Simulates large batches | Slower (multiple forward/backward passes) |
| Gradient checkpointing | Reduces activation memory | ~33% more compute |
| Ring all-reduce | Efficient gradient averaging | Bandwidth-limited for large models |
| MoE | More capacity, same FLOPs | Load balancing, routing complexity |
| Scaling laws | Guides compute allocation | Empirical, may not hold at all scales |
Coding Tasks (use Colab or notebook)
- Compute the FLOPs and memory requirements for a transformer layer. Given hidden dimension $d$, sequence length $n$, batch size $B$, and number of layers, estimate the total training cost.
```python
def transformer_layer_flops(d, n, B):
    """Approximate forward-pass FLOPs for one transformer layer (2 FLOPs per multiply-add)."""
    # QKV projections: 3 matmuls, each (B*n x d) @ (d x d)
    qkv_flops = 3 * 2 * B * n * d * d
    # Attention: Q @ K^T and attn @ V, each B * n * n * d multiply-adds
    attn_flops = 2 * 2 * B * n * n * d
    # Output projection: (B*n x d) @ (d x d)
    out_flops = 2 * B * n * d * d
    # FFN: two matmuls, d -> 4d and 4d -> d
    ffn_flops = 2 * 2 * B * n * d * 4 * d
    return qkv_flops + attn_flops + out_flops + ffn_flops

def transformer_layer_memory(d, n, B, n_heads=16, dtype_bytes=2):
    """Rough activation memory (bytes) for one layer. Ignores LayerNorm inputs, residuals,
    softmax inputs and dropout masks, so it underestimates the real figure."""
    # Q, K, V: 3 * B * n * d
    qkv_mem = 3 * B * n * d * dtype_bytes
    # Attention weights: one n x n matrix per head
    attn_mem = B * n_heads * n * n * dtype_bytes
    # FFN intermediate: B * n * 4d
    ffn_mem = B * n * 4 * d * dtype_bytes
    return qkv_mem + attn_mem + ffn_mem

# Example: roughly GPT-2 medium scale (d=1024, 16 heads, 24 layers)
d, n, B, L = 1024, 1024, 8, 24
fwd_flops = transformer_layer_flops(d, n, B)
total_flops = 3 * L * fwd_flops  # forward + backward (backward ~ 2x forward)
act_mem = L * transformer_layer_memory(d, n, B, n_heads=16)
param_count = L * (12 * d * d + 13 * d)  # per-layer weights and biases; embeddings excluded
print(f"Model: d={d}, n={n}, B={B}, L={L}")
print(f"Parameters (excluding embeddings): {param_count / 1e6:.0f}M")
print(f"FLOPs per step: {total_flops / 1e12:.2f} TFLOPs")
print(f"Activation memory: {act_mem / 1e9:.2f} GB (BF16)")
print(f"Parameter memory (FP32): {param_count * 4 / 1e9:.2f} GB")
print(f"Adam optimizer memory (FP32 m and v): {param_count * 8 / 1e9:.2f} GB")
# 16 bytes/param = FP32 weights (4) + FP32 gradients (4) + Adam m and v (8)
print(f"Total training memory: {(param_count * 16 + act_mem) / 1e9:.2f} GB")
```
- Simulate data-parallel training. Split a dataset across multiple "virtual GPUs," compute gradients independently, average them, and verify the result matches single-GPU training.
```python
import jax
import jax.numpy as jnp

# Simple linear model without bias: y = X @ w
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (64, 4))
w_true = jnp.array([1.0, -2.0, 3.0, 0.5])
y = X @ w_true + 0.1 * jax.random.normal(key_noise, (64,))  # separate key: independent noise

def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

grad_fn = jax.grad(loss_fn)

# Single GPU: full-batch gradient
w = jnp.zeros(4)
grad_single = grad_fn(w, X, y)

# Data parallel: split the batch evenly across 4 "GPUs"
n_gpus = 4
chunk_size = len(X) // n_gpus
grads = []
for i in range(n_gpus):
    X_chunk = X[i * chunk_size:(i + 1) * chunk_size]
    y_chunk = y[i * chunk_size:(i + 1) * chunk_size]
    grads.append(grad_fn(w, X_chunk, y_chunk))

# All-reduce: average the gradients (exact here because all chunks are the same size)
grad_parallel = jnp.mean(jnp.stack(grads), axis=0)
print("Single-GPU gradient:", grad_single)
print("Data-parallel gradient (avg):", grad_parallel)
print(f"Match: {jnp.allclose(grad_single, grad_parallel, atol=1e-5)}")

# Train both and compare
w_single, w_parallel = jnp.zeros(4), jnp.zeros(4)
lr = 0.1
for step in range(100):
    w_single = w_single - lr * grad_fn(w_single, X, y)
    grads = [grad_fn(w_parallel, X[i * chunk_size:(i + 1) * chunk_size],
                     y[i * chunk_size:(i + 1) * chunk_size]) for i in range(n_gpus)]
    avg_grad = jnp.mean(jnp.stack(grads), axis=0)
    w_parallel = w_parallel - lr * avg_grad

print("\nAfter 100 steps:")
print(f"Single-GPU weights: {w_single}")
print(f"Data-parallel weights: {w_parallel}")
print(f"Max difference: {float(jnp.max(jnp.abs(w_single - w_parallel))):.2e}")
```
- Implement a simple Mixture of Experts layer. Create a gating network that routes tokens to top-K experts and combine their outputs.
```python
import jax
import jax.numpy as jnp

def expert_fn(x, W1, b1, W2, b2):
    """Simple 2-layer FFN expert."""
    h = jnp.maximum(0, x @ W1 + b1)  # ReLU
    return h @ W2 + b2

def moe_layer(x, gate_W, experts_params, top_k=2):
    """
    MoE forward pass.
    x: (batch, d_model)
    gate_W: (d_model, n_experts)
    experts_params: list of (W1, b1, W2, b2) per expert
    Returns the combined output and the full router probabilities.
    """
    n_experts = len(experts_params)
    # Gating: compute routing scores
    gate_logits = x @ gate_W  # (batch, n_experts)
    gate_probs = jax.nn.softmax(gate_logits, axis=-1)
    # Top-K selection
    top_k_probs, top_k_indices = jax.lax.top_k(gate_probs, top_k)  # (batch, top_k)
    # Renormalise the selected weights so they sum to 1 per token
    top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)
    # Compute expert outputs (simplified: run every expert on every token, then select;
    # a real MoE sends each token only to its chosen experts)
    expert_outputs = jnp.stack([
        expert_fn(x, *experts_params[i]) for i in range(n_experts)
    ], axis=1)  # (batch, n_experts, d_model)
    # Gather the top-K expert outputs and weight them
    batch_idx = jnp.arange(x.shape[0])[:, None]
    selected_outputs = expert_outputs[batch_idx, top_k_indices]  # (batch, top_k, d_model)
    output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1)
    return output, gate_probs

# Setup
key = jax.random.PRNGKey(42)
batch, d_model, d_ff, n_experts, top_k = 8, 16, 32, 4, 2
# Initialise experts
experts_params = []
for i in range(n_experts):
    key, k1, k2 = jax.random.split(key, 3)
    experts_params.append((
        jax.random.normal(k1, (d_model, d_ff)) * 0.1,
        jnp.zeros(d_ff),
        jax.random.normal(k2, (d_ff, d_model)) * 0.1,
        jnp.zeros(d_model),
    ))
key, gate_key, x_key = jax.random.split(key, 3)
gate_W = jax.random.normal(gate_key, (d_model, n_experts)) * 0.1
x = jax.random.normal(x_key, (batch, d_model))
output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=top_k)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Gate probabilities (first sample): {gate_probs[0]}")
# Expert load: tokens that chose each expert in their top-K, and the mean router probability
_, chosen = jax.lax.top_k(gate_probs, top_k)
print("Expert load across the batch:")
for i in range(n_experts):
    n_tokens = int(jnp.sum(chosen == i))
    usage = float(jnp.mean(gate_probs[:, i]))
    print(f"  Expert {i}: {n_tokens} tokens routed, mean gate prob {usage:.3f}")
```
Linguistic Foundations
-
Before we can build systems that understand or generate language, we need to understand how language itself works.
-
Linguistics is the scientific study of language, and it provides the conceptual vocabulary that NLP borrows from constantly.
-
Even modern neural models, which learn language from raw data, implicitly rediscover many of the structures that linguists have catalogued for decades.
-
Language has structure at every level: the sounds that make up words, the parts that make up words, the rules that combine words into sentences, the meaning those sentences carry, and the way context shapes interpretation. We will work through each level from the bottom up.
-
Morphology is the study of the internal structure of words. Words are not atomic; they are built from smaller meaningful units called morphemes.
-
The word "unhappiness" contains three morphemes: "un-" (a prefix meaning "not"), "happy" (the root), and "-ness" (a suffix that turns an adjective into a noun). Each morpheme contributes to the meaning.
-
A root (or stem) is the core morpheme that carries the primary meaning. "Happy," "run," "compute" are roots.
-
An affix is a morpheme that attaches to a root to modify it.
-
English has prefixes (before the root: un-, re-, pre-) and suffixes (after: -ing, -ed, -tion). Some languages also have infixes, inserted inside the root, and circumfixes, two-part affixes that wrap around the root (German ge-...-t in gespielt, "played").
-
There are two kinds of morphological processes. Inflection changes the grammatical properties of a word without changing its core meaning or part of speech: "run" becomes "runs" (third person), "running" (progressive), "ran" (past tense). The word is still a verb meaning the same thing.
-
Derivation creates a new word, often changing the part of speech: "happy" (adjective) becomes "happiness" (noun), "compute" (verb) becomes "computation" (noun) becomes "computational" (adjective). Each derivation shifts meaning and grammatical category.
-
Languages vary enormously in morphological complexity. English is relatively analytic (few morphemes per word, relying on word order).
-
Turkish and Finnish are agglutinative (words can contain many morphemes strung together). Arabic and Hebrew use templatic morphology (roots are consonant skeletons like k-t-b for "write," and vowel patterns are inserted to create different words: kitab "book," kataba "he wrote," maktub "written").
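Templatic morphology is easy to mimic mechanically: slot the root consonants into a vowel pattern. In this sketch, "C" marks a consonant slot, a notation chosen here for illustration. The forms use a simplified transliteration that ignores vowel length (kitāb is written kitab).

```python
def interdigitate(root, template):
    """Fill each 'C' slot in a vowel template with the next root consonant."""
    consonants = iter(root)
    return "".join(next(consonants) if ch == "C" else ch for ch in template)

root = ("k", "t", "b")                       # the root k-t-b, "write"
for template, gloss in [("CiCaC", "book"), ("CaCaCa", "he wrote"), ("maCCuC", "written")]:
    print(f"{template:7s} -> {interdigitate(root, template):7s} ({gloss})")
```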
-
Morphology matters for NLP because it affects tokenisation. A word-level tokeniser treats "run," "runs," "running," and "ran" as four unrelated symbols.
-
A morphologically aware system recognises they share a root. Subword tokenisation (BPE, WordPiece), which we will cover in file 02, is a statistical approximation to morphological analysis.
-
Syntax is the study of how words combine into phrases and sentences. Every language has rules governing word order and structure; violating them produces gibberish.
-
"The cat sat on the mat" is grammatical English; "Mat the on sat cat the" is not.
-
There are two main frameworks for describing syntactic structure.
-
Phrase structure grammar (also called constituency grammar) says sentences are built by nesting phrases inside phrases. A sentence (S) consists of a noun phrase (NP) and a verb phrase (VP).
-
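A noun phrase might be a determiner (Det) followed by a noun (N). A verb phrase might be a verb (V) followed by a noun phrase or a prepositional phrase (PP). These rules build a tree; for "The cat sat on the mat" it is, in bracket notation:
[S [NP [Det The] [N cat]] [VP [V sat] [PP [P on] [NP [Det the] [N mat]]]]]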
-
This tree is called a constituency tree (or parse tree). Each internal node is a phrase type, each leaf is a word. The tree captures the hierarchical grouping: "on the mat" is a unit (prepositional phrase), "sat on the mat" is a unit (verb phrase), and the whole thing is a sentence.
-
A context-free grammar (CFG) formalises these rules. It consists of a set of production rules, each of the form $A \to \alpha$, where $A$ is a non-terminal symbol (a phrase type like NP or VP) and $\alpha$ is a sequence of terminals (words) and non-terminals. For example:
S → NP VP
NP → Det N
NP → Det N PP
VP → V NP
VP → V PP
PP → P NP
Det → "the" | "a"
N → "cat" | "mat" | "dog"
V → "sat" | "chased"
P → "on" | "under"
-
Starting from S and repeatedly applying rules, you can generate all sentences the grammar allows. Parsing is the reverse: given a sentence, find the tree (or trees) that produced it. A sentence with multiple valid parse trees is syntactically ambiguous. "I saw the man with the telescope" has two parses: I used a telescope to see the man, or I saw a man who had a telescope.
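Generation can be sketched in a few lines. This is a minimal sketch, assuming the grammar above is stored as a dictionary from each non-terminal to its alternative right-hand sides (the names `GRAMMAR` and `generate` are illustrative, not from any library); it expands symbols top-down, choosing a production at random each time, with a seeded generator so the output is reproducible.

```python
import random

# The toy grammar above: non-terminal -> list of alternative right-hand sides.
GRAMMAR = {
    "S":   [["NP", "VP"]],
    "NP":  [["Det", "N"], ["Det", "N", "PP"]],
    "VP":  [["V", "NP"], ["V", "PP"]],
    "PP":  [["P", "NP"]],
    "Det": [["the"], ["a"]],
    "N":   [["cat"], ["mat"], ["dog"]],
    "V":   [["sat"], ["chased"]],
    "P":   [["on"], ["under"]],
}

def generate(symbol="S", rng=random):
    """Expand `symbol` by picking random productions until only words remain."""
    if symbol not in GRAMMAR:            # a terminal, i.e. a word
        return [symbol]
    rhs = rng.choice(GRAMMAR[symbol])
    return [word for part in rhs for word in generate(part, rng)]

rng = random.Random(0)
for _ in range(3):
    print(" ".join(generate(rng=rng)))
```

Because NP can contain a PP which contains another NP, the grammar is recursive and allows infinitely many sentences; random expansion still terminates because the recursive option is chosen only part of the time.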
-
Dependency grammar takes a different perspective. Instead of phrase nesting, it describes direct relationships between words. Each word in a sentence depends on exactly one other word (its head), except the root of the sentence. The result is a dependency tree where edges are labelled with grammatical relations (subject, object, modifier, etc.).
-
In the dependency view of "The cat sat on the mat", "sat" is the root. "Cat" depends on "sat" as its subject (nsubj), and each "the" depends on its noun as a determiner (det). "On" depends on "sat" as a prepositional modifier, and "mat" depends on "on" as the object of the preposition. Every word except the root hangs off exactly one head, creating a tree.
-
Dependency grammar has become the dominant framework in modern NLP because dependency trees are easier to produce with statistical parsers and the relations map more directly to semantic roles (who did what to whom).
-
Valency describes how many arguments a verb requires. "Sleep" is intransitive (one argument: the sleeper). "Eat" is transitive (two: the eater and the eaten). "Give" is ditransitive (three: the giver, the thing given, and the receiver). Knowing a verb's valency constrains which parse trees are valid.
-
Semantics is the study of meaning. Syntax tells you how a sentence is structured; semantics tells you what it means.
-
Lexical semantics concerns the meaning of individual words. Words are related to each other in systematic ways:
- Synonymy: words with (nearly) the same meaning. "Big" and "large" are synonyms. True perfect synonymy is rare; there are almost always subtle differences in connotation or usage.
- Antonymy: words with opposite meanings. "Hot" and "cold," "buy" and "sell."
- Hypernymy/hyponymy: "is-a" relationships. "Dog" is a hyponym of "animal" (a dog is a kind of animal). "Animal" is a hypernym of "dog." These form taxonomic hierarchies.
- Meronymy: "part-of" relationships. "Wheel" is a meronym of "car."
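- Polysemy: a single word with multiple related meanings. "Paper" can mean the material or a newspaper printed on it. When the meanings are unrelated, as with "bank" (a financial institution or the side of a river), the phenomenon is called homonymy. In both cases, context disambiguates.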
-
Word sense disambiguation (WSD) is the task of determining which sense of an ambiguous word is intended in a given context. In "I deposited money at the bank," the financial sense is correct. In "We sat by the river bank," the geographical sense is. WSD was a central problem in early NLP; modern contextual embeddings (ELMo, BERT) largely address it by producing different vector representations for different uses of the same word.
-
Compositional semantics asks how the meanings of individual words combine to form the meaning of a phrase or sentence. The principle of compositionality (attributed to Frege) states that the meaning of a complex expression is determined by the meanings of its parts and the rules used to combine them. "The cat chased the dog" means something different from "the dog chased the cat" because the syntactic structure (who is the subject vs object) interacts with the word meanings.
-
Not all meaning is compositional. Idioms like "kick the bucket" (meaning "to die") have meanings that cannot be derived from their parts. These are a challenge for any compositional approach.
-
Distributional semantics is the computational approach to meaning that underpins modern NLP. The distributional hypothesis (Firth, 1957) states: "You shall know a word by the company it keeps." Words that appear in similar contexts tend to have similar meanings. This is the theoretical foundation for word embeddings (Word2Vec, GloVe), which we will explore in file 03.
-
Pragmatics studies how context affects meaning. The same sentence can mean different things depending on who says it, when, where, and why.
-
"Can you pass the salt?" is syntactically a yes/no question about ability. Pragmatically, it is a request. You would not answer "Yes, I can" and then sit still. Understanding this requires knowledge beyond the literal words, specifically, the conventions of speech acts.
-
Speech act theory (Austin, Searle) distinguishes between:
- Locutionary act: the literal content ("Can you pass the salt?")
- Illocutionary act: the intended function (a request)
- Perlocutionary act: the effect on the listener (they pass the salt)
-
Implicature (Grice) is meaning that is implied but not explicitly stated. If someone asks "Is John a good cook?" and you reply "He's British," you have not answered the question literally, but the listener can infer (through cultural stereotypes, fairly or not) that you mean "no." Grice's cooperative principle says speakers generally try to be informative, truthful, relevant, and clear, and listeners interpret utterances assuming these maxims hold.
-
Coreference is a pragmatic phenomenon where different expressions refer to the same entity. In "Alice went to the store. She bought milk," "she" refers to Alice. Resolving coreference is essential for understanding multi-sentence text and is a key NLP task.
-
Discourse structure describes how sentences connect to form coherent text. A narrative has a beginning, middle, and end. An argument has claims and evidence. Rhetorical Structure Theory (RST) analyses text as a tree of discourse relations (elaboration, contrast, cause, etc.) between segments.
-
Pragmatics is where NLP gets hardest. Modern language models handle much of syntax and semantics implicitly through training data, but pragmatic reasoning (understanding sarcasm, implicature and context-dependent meaning) remains a frontier challenge.
-
Phonology studies the sound systems of languages. While this chapter focuses on text, a brief overview bridges to the audio and speech chapter (Chapter 09).
-
A phoneme is the smallest unit of sound that distinguishes meaning. English has about 44 phonemes. The words "bat" and "pat" differ by one phoneme (/b/ vs /p/), which changes the meaning entirely. This is called a minimal pair.
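Once words are transcribed as phoneme sequences, finding minimal pairs is a simple search. The sketch below assumes a tiny hypothetical lexicon with hand-written broad transcriptions (one IPA symbol per phoneme) and only compares transcriptions of equal length, so pairs that differ by an inserted or deleted phoneme are not reported.

```python
from itertools import combinations

# Hypothetical broad phonemic transcriptions, one IPA symbol per phoneme.
TRANSCRIPTIONS = {
    "bat": ("b", "æ", "t"),
    "pat": ("p", "æ", "t"),
    "cat": ("k", "æ", "t"),
    "cab": ("k", "æ", "b"),
    "bit": ("b", "ɪ", "t"),
    "spin": ("s", "p", "ɪ", "n"),
}

def minimal_pairs(lexicon):
    """Word pairs whose transcriptions have equal length and differ in exactly one phoneme."""
    pairs = []
    for (w1, p1), (w2, p2) in combinations(lexicon.items(), 2):
        if len(p1) == len(p2):
            diffs = [(a, b) for a, b in zip(p1, p2) if a != b]
            if len(diffs) == 1:
                pairs.append((w1, w2, diffs[0]))
    return pairs

for w1, w2, (a, b) in minimal_pairs(TRANSCRIPTIONS):
    print(f"{w1} / {w2}: /{a}/ vs /{b}/")
```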
-
Allophones are different physical realisations of the same phoneme that do not change meaning. The "p" in "pin" (aspirated, with a puff of air) and the "p" in "spin" (unaspirated) are allophones of /p/ in English; a native speaker treats them as the same sound.
-
The International Phonetic Alphabet (IPA) provides a standardised notation for phonemes across all languages. The word "cat" is transcribed as /kæt/. IPA is the bridge between written text and speech systems.
-
Prosody covers the rhythm, stress, and intonation of speech. "I didn't say he stole the money" has seven different meanings depending on which word is stressed. Prosody carries information that text alone loses, which is why text-to-speech systems must model it carefully.
-
In NLP, phonological knowledge appears in text-to-speech (grapheme-to-phoneme conversion), speech recognition (mapping acoustic signals to phonemes), and even in spelling correction and transliteration.
Coding Tasks (use Colab or notebook)
- Build a simple morphological analyser that splits English words into likely morphemes using a list of common prefixes and suffixes.
```python
prefixes = ['un', 're', 'pre', 'dis', 'mis', 'over', 'under', 'out', 'non']
suffixes = ['ing', 'ed', 'ly', 'ness', 'ment', 'tion', 'able', 'ible', 'er', 'est', 'ful', 'less', 'ous']

def analyse_morphemes(word):
    """Simple morpheme analysis using known affixes."""
    parts = []
    remaining = word.lower()
    # Check prefixes
    for p in sorted(prefixes, key=len, reverse=True):
        if remaining.startswith(p) and len(remaining) > len(p) + 2:
            parts.append(f"[prefix: {p}]")
            remaining = remaining[len(p):]
            break
    # Check suffixes
    for s in sorted(suffixes, key=len, reverse=True):
        if remaining.endswith(s) and len(remaining) > len(s) + 2:
            root = remaining[:-len(s)]
            parts.append(f"[root: {root}]")
            parts.append(f"[suffix: {s}]")
            remaining = None
            break
    if remaining is not None:
        parts.append(f"[root: {remaining}]")
    return parts

for word in ['unhappiness', 'reusable', 'disconnected', 'overreacting', 'kindness']:
    print(f"{word:20s} → {' + '.join(analyse_morphemes(word))}")
```
- Implement a simple context-free grammar parser using recursive descent. Define a small grammar and parse a sentence into a constituency tree.
```python
class CFGParser:
    """Recursive descent parser for a tiny English grammar."""

    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    def peek(self):
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None

    def consume(self, expected=None):
        tok = self.peek()
        if expected and tok != expected:
            return None
        self.pos += 1
        return tok

    def parse_det(self):
        if self.peek() in ('the', 'a'):
            return ('Det', self.consume())
        return None

    def parse_noun(self):
        if self.peek() in ('cat', 'dog', 'mat', 'man'):
            return ('N', self.consume())
        return None

    def parse_verb(self):
        if self.peek() in ('sat', 'chased', 'saw'):
            return ('V', self.consume())
        return None

    def parse_prep(self):
        if self.peek() in ('on', 'under', 'with'):
            return ('P', self.consume())
        return None

    def parse_np(self):
        save = self.pos
        det = self.parse_det()
        noun = self.parse_noun()
        if det and noun:
            # Check for optional PP
            pp = self.parse_pp()
            if pp:
                return ('NP', det, noun, pp)
            return ('NP', det, noun)
        self.pos = save
        return None

    def parse_pp(self):
        save = self.pos
        prep = self.parse_prep()
        np = self.parse_np()
        if prep and np:
            return ('PP', prep, np)
        self.pos = save
        return None

    def parse_vp(self):
        save = self.pos
        verb = self.parse_verb()
        if verb:
            np = self.parse_np()
            if np:
                return ('VP', verb, np)
            pp = self.parse_pp()
            if pp:
                return ('VP', verb, pp)
        self.pos = save
        return None

    def parse_sentence(self):
        np = self.parse_np()
        vp = self.parse_vp()
        if np and vp and self.pos == len(self.tokens):
            return ('S', np, vp)
        return None


def print_tree(tree, indent=0):
    if isinstance(tree, str):
        print(' ' * indent + tree)
    elif isinstance(tree, tuple):
        print(' ' * indent + tree[0])
        for child in tree[1:]:
            print_tree(child, indent + 2)


sentences = [
    "the cat sat on the mat",
    "a dog chased the cat",
]
for sent in sentences:
    tokens = sent.split()
    parser = CFGParser(tokens)
    tree = parser.parse_sentence()
    print(f"\n'{sent}':")
    if tree:
        print_tree(tree)
    else:
        print("  (no parse found)")
```
- Explore lexical relations by building a simple word graph. Given a small vocabulary with synonym, antonym, and hypernym relations, find paths between words.
```python
from collections import defaultdict, deque

relations = {
    ('big', 'large'): 'synonym',
    ('big', 'small'): 'antonym',
    ('small', 'tiny'): 'synonym',
    ('dog', 'animal'): 'hypernym',
    ('cat', 'animal'): 'hypernym',
    ('puppy', 'dog'): 'hypernym',
    ('happy', 'glad'): 'synonym',
    ('happy', 'sad'): 'antonym',
    ('hot', 'cold'): 'antonym',
    ('hot', 'warm'): 'synonym',
}
# An entry (w1, w2): rel reads "w2 is a <rel> of w1", e.g. animal is a hypernym of dog.
# Walking an edge backwards needs the inverse relation (dog is a hyponym of animal).
INVERSE = {'synonym': 'synonym', 'antonym': 'antonym', 'hypernym': 'hyponym'}

# Build adjacency list
graph = defaultdict(list)
for (w1, w2), rel in relations.items():
    graph[w1].append((w2, rel))
    graph[w2].append((w1, INVERSE[rel]))

def find_path(start, end):
    """BFS to find a path between two words through the relation graph."""
    queue = deque([(start, [(start, None)])])
    visited = {start}
    while queue:
        node, path = queue.popleft()
        if node == end:
            return path
        for neighbor, rel in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [(neighbor, rel)]))
    return None

pairs = [('big', 'tiny'), ('puppy', 'cat'), ('happy', 'sad')]
for w1, w2 in pairs:
    path = find_path(w1, w2)
    if path:
        steps = " → ".join(f"{w}({r})" if r else w for w, r in path)
        print(f"{w1} → {w2}: {steps}")
    else:
        print(f"{w1} → {w2}: no path found")
```
Text Processing and Classic NLP
-
Raw text is messy. Before any NLP model can work with language, the text must be cleaned, normalised, and converted into a structured representation. This file covers the pipeline from raw characters to features that models can consume, along with the classical NLP algorithms that dominated before deep learning.
-
Text normalisation transforms raw text into a canonical form. The goal is to reduce irrelevant variation so that variants such as "Hello", "hello", "HELLO" and "héllo" map to the same form whenever the distinction does not matter for the task.
-
Case folding converts text to lowercase. This collapses "The" and "the" into one token. It helps for most tasks, but destroys useful information in some cases: "US" (the country) vs "us" (the pronoun), or "Apple" (the company) vs "apple" (the fruit).
-
Unicode normalisation handles the fact that the same character can be encoded multiple ways. The character "é" can be a single code point (U+00E9) or a base "e" plus a combining accent (U+0065 + U+0301). NFC normalisation composes them into one code point; NFD decomposes them. Without normalisation, two identical-looking strings may not match.
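Python's standard `unicodedata` module exposes these normalisation forms directly. A minimal demonstration with the word "café", written once with the precomposed code point and once with a combining accent:

```python
import unicodedata

composed = "caf\u00e9"       # é as a single code point, U+00E9
decomposed = "cafe\u0301"    # e followed by the combining acute accent U+0301

print(composed == decomposed)            # False: the code points differ
print(len(composed), len(decomposed))    # 4 5

nfc = unicodedata.normalize("NFC", decomposed)   # compose into single code points
nfd = unicodedata.normalize("NFD", composed)     # decompose into base + combining marks
print(nfc == composed, nfd == decomposed)        # True True
```

In practice a pipeline picks one form (NFC is a common choice) and applies it to every string before comparing or tokenising.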
-
Edit distance measures how different two strings are. The Levenshtein distance counts the minimum number of single-character insertions, deletions, and substitutions needed to transform one string into another. "kitten" → "sitting" has edit distance 3 (k→s, e→i, insert g).
-
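Edit distance is computed using dynamic programming (covered in the algorithms chapter). Define $D[i][j]$ as the distance between the first $i$ characters of string $s$ and the first $j$ characters of string $t$. The base cases are $D[i][0] = i$ and $D[0][j] = j$ (transforming to or from the empty string), and every other cell takes the cheapest of the three edits:
$$D[i][j] = \min\begin{cases} D[i-1][j] + 1 & \text{(delete } s_i\text{)} \\ D[i][j-1] + 1 & \text{(insert } t_j\text{)} \\ D[i-1][j-1] + [s_i \neq t_j] & \text{(substitute, or keep a match)} \end{cases}$$
where $s_i$ is the $i$-th character of $s$ and $[s_i \neq t_j]$ is 1 if the characters differ and 0 otherwise. For strings of lengths $m$ and $n$, the answer is $D[m][n]$.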
-
Edit distance powers spelling correction, fuzzy matching, and DNA sequence alignment. In NLP, it is used to handle typos and find similar words.
-
Tokenisation splits text into discrete units (tokens) that a model can process. This is the first and arguably most important preprocessing step. The choice of tokenisation strategy profoundly affects model behaviour.
-
Whitespace tokenisation splits on spaces. Simple but naive: "New York" becomes two tokens, "don't" is one token (or split into "don" and "'t" depending on the splitter), and languages like Chinese and Japanese have no spaces between words at all.
-
Rule-based tokenisation uses handcrafted patterns (regular expressions) to handle contractions, punctuation, and special cases. "I'm" → "I" + "'m", "U.S.A." stays as one token. Every language needs its own rules, which is labour-intensive.
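A rule-based tokeniser can be sketched as a single regular expression whose alternatives are tried in order at each position. The rules below are hypothetical and deliberately tiny (abbreviations, a few English clitics split in the Penn Treebank style "do" + "n't", words, punctuation); a real rule set handles far more cases, and this abbreviation rule will, for instance, swallow a sentence-final period after "U.S.A.".

```python
import re

# Hypothetical toy rules, tried left to right at each position (not any library's rule set).
TOKEN_RE = re.compile(r"""
      (?:[A-Za-z]\.){2,}        # abbreviations such as U.S.A. stay whole
    | n't                       # negative clitic: don't -> do + n't
    | '(?:m|s|re|ve|ll|d)\b     # other clitics: I'm -> I + 'm
    | \w+?(?=n't)               # the host word in front of n't
    | \w+                       # ordinary words and numbers
    | [^\w\s]                   # any other single punctuation character
""", re.VERBOSE | re.IGNORECASE)

def tokenize(text):
    """Return the non-overlapping rule matches, left to right."""
    return TOKEN_RE.findall(text)

print(tokenize("I'm in the U.S.A. and don't know why."))
```

Ordering matters: the abbreviation rule must come before the word and punctuation rules, otherwise "U.S.A." would be split into letters and dots.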
-
Subword tokenisation is the modern solution. Instead of splitting at word boundaries, it learns a vocabulary of frequent subword units from data. This elegantly handles unknown words: if "unhappiness" is not in the vocabulary, it might be split into "un" + "happi" + "ness", preserving morphological structure.
-
Byte-Pair Encoding (BPE) starts with individual characters as the vocabulary. It repeatedly finds the most frequent adjacent pair and merges them into a new token. After enough merges, common words are single tokens and rare words are split into frequent subword pieces.
-
The BPE algorithm:
- Initialise the vocabulary with all individual characters in the training corpus
- Count the frequency of every adjacent token pair
- Merge the most frequent pair into a new token
- Repeat the counting and merging steps until the desired number of merges (and hence vocabulary size) is reached
-
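For example, take a corpus with "l o w" (5 times), "l o w e r" (2 times) and "n e w e s t" (6 times). The adjacent pair counts are "w e": 8, "l o": 7, "o w": 7, and 6 each for "n e", "e w", "e s" and "s t", so the first merge is "w e" → "we". After recounting, "l o" (7) is the most frequent pair and becomes "lo". Merging continues until the target vocabulary size is reached; the final vocabulary contains both full words and subword pieces.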
-
WordPiece (used by BERT) is similar to BPE but selects merges based on likelihood rather than frequency. It merges the pair that maximises the language model likelihood of the training data. Subword tokens that are not word-initial are prefixed with "##" (e.g., "playing" → "play" + "##ing").
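Once a WordPiece vocabulary has been learned, BERT-style tokenisers segment each word greedily, repeatedly taking the longest vocabulary entry that matches at the current position. The sketch below covers only that inference step (not the likelihood-based training), uses a tiny hypothetical vocabulary, and maps the whole word to `[UNK]` when no segmentation exists.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first segmentation of one word into WordPiece tokens."""
    tokens, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        # Try the longest remaining substring first, shrinking until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate     # non-initial pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]                          # no valid segmentation: the word is unknown
        tokens.append(piece)
        start = end
    return tokens

# A toy vocabulary (hypothetical, far smaller than a real one)
VOCAB = {"play", "##ing", "##ed", "un", "##happi", "##ness", "happy"}
for w in ["playing", "played", "unhappiness", "xyz"]:
    print(w, wordpiece_tokenize(w, VOCAB))
```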
-
Unigram (used by SentencePiece) takes the opposite approach: start with a large vocabulary and iteratively remove tokens whose removal least hurts the training data likelihood. The final vocabulary is the set of subword units that best explain the corpus.
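Given a trained Unigram vocabulary with a probability for each piece, the most likely segmentation of a word can be found with a Viterbi-style dynamic programme over split points. A minimal sketch, assuming hypothetical hand-set piece probabilities; the training loop that prunes the vocabulary is not shown.

```python
import math

# Hypothetical piece probabilities from an already-trained Unigram model
PIECE_PROBS = {
    "un": 0.1, "happi": 0.05, "ness": 0.1, "unhappi": 0.001,
    "u": 0.01, "n": 0.01, "h": 0.01, "a": 0.01, "p": 0.01, "i": 0.01, "e": 0.01, "s": 0.01,
}
LOGP = {piece: math.log(p) for piece, p in PIECE_PROBS.items()}

def unigram_segment(word, logp):
    """Most probable split of `word` into vocabulary pieces (Viterbi over split points)."""
    n = len(word)
    best = [0.0] + [float("-inf")] * n   # best[i]: best log-prob of any segmentation of word[:i]
    back = [0] * (n + 1)                  # back[i]: where the last piece of that segmentation starts
    for end in range(1, n + 1):
        for start in range(end):
            piece = word[start:end]
            if piece in logp and best[start] + logp[piece] > best[end]:
                best[end] = best[start] + logp[piece]
                back[end] = start
    if best[n] == float("-inf"):
        return None                       # some part of the word is not covered by any piece
    pieces, end = [], n
    while end > 0:
        pieces.append(word[back[end]:end])
        end = back[end]
    return pieces[::-1]

print(unigram_segment("unhappiness", LOGP))
```

Here "un" + "happi" + "ness" (probability 0.1 × 0.05 × 0.1) beats "unhappi" + "ness" (0.001 × 0.1), even though the latter uses fewer pieces.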
-
SentencePiece is a language-agnostic tokenisation library that treats the input as a raw stream of Unicode characters, with no pre-tokenisation on spaces; whitespace itself is encoded as an ordinary symbol (▁). This makes it work for any language, including those without spaces. It implements both BPE and Unigram algorithms.
-
The vocabulary size is a key hyperparameter. Typical choices range from about 30,000 tokens (BERT) to 100,000 or more in recent large language models. Larger vocabularies mean fewer tokens per sequence (more efficient) but a larger embedding table. Smaller vocabularies mean more subword splits and longer sequences.
-
Stemming and lemmatisation both reduce words to a base form, but they differ in approach.
-
Stemming chops off suffixes using crude rules. The Porter stemmer reduces "running" to "run", "happiness" to "happi", and "studies" to "studi". It is fast but imprecise: "university" and "universe" both stem to "univers" despite being unrelated.
-
Lemmatisation uses vocabulary and morphological analysis to find the true dictionary form (lemma). "Running" → "run", "better" → "good", "mice" → "mouse". It requires knowing the part of speech: "saw" as a verb lemmatises to "see", but as a noun it stays "saw".
-
Modern subword tokenisation has largely replaced stemming and lemmatisation in neural NLP, but they remain useful in information retrieval and when working with smaller models or limited data.
-
Part-of-speech (POS) tagging assigns a grammatical category to each word: noun, verb, adjective, determiner, etc. This is one of the oldest NLP tasks and is fundamental to syntactic analysis.
-
The Penn Treebank tagset is the most common for English, with 36 tags (NN for singular noun, NNS for plural noun, VB for base verb, VBD for past tense, JJ for adjective, etc.).
-
POS tagging is tricky because many words are ambiguous. "Book" can be a noun ("the book") or a verb ("book a flight"). "Run" has dozens of senses across parts of speech. Context is essential.
-
Early taggers used Hidden Markov Models (HMMs) from chapter 05. The hidden states are POS tags, the observations are words. The transition probabilities capture tag sequences (a determiner is likely followed by a noun or adjective), and the emission probabilities capture which words appear with which tags. The Viterbi algorithm finds the most likely tag sequence.
-
The HMM model for POS tagging:
$$\hat{t}_{1:n} = \arg\max_{t_{1:n}} \prod_{i=1}^{n} P(w_i \mid t_i) \cdot P(t_i \mid t_{i-1})$$
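The Viterbi algorithm computes this arg max by dynamic programming: for each position and tag it keeps the best-scoring path ending in that tag, plus a back-pointer. A minimal sketch with three tags and hand-set probabilities (all numbers are hypothetical; a real tagger estimates them by counting a tagged corpus, and `START` plays the role of $P(t_1 \mid t_0)$ for a start state $t_0$). It works in log space to avoid underflow on long sentences.

```python
import math

TAGS = ["DET", "NOUN", "VERB"]
# Hypothetical hand-set HMM parameters
START = {"DET": 0.5, "NOUN": 0.2, "VERB": 0.3}              # P(t_1 | start)
TRANS = {                                                    # P(t_i | t_{i-1})
    "DET":  {"DET": 0.01, "NOUN": 0.98, "VERB": 0.01},
    "NOUN": {"DET": 0.2,  "NOUN": 0.3,  "VERB": 0.5},
    "VERB": {"DET": 0.6,  "NOUN": 0.3,  "VERB": 0.1},
}
EMIT = {                                                     # P(w_i | t_i); unlisted words get 0
    "DET":  {"the": 0.6, "a": 0.4},
    "NOUN": {"book": 0.3, "flight": 0.4, "dog": 0.3},
    "VERB": {"book": 0.4, "sat": 0.6},
}

def log(p):
    return math.log(p) if p > 0 else float("-inf")

def viterbi(words):
    """Most likely tag sequence under a first-order HMM, computed in log space."""
    # best[i][t]: log-prob of the best tag path for words[:i+1] that ends in tag t
    best = [{t: log(START[t]) + log(EMIT[t].get(words[0], 0)) for t in TAGS}]
    back = [{}]
    for w in words[1:]:
        scores, pointers = {}, {}
        for t in TAGS:
            prev = max(TAGS, key=lambda p: best[-1][p] + log(TRANS[p][t]))
            scores[t] = best[-1][prev] + log(TRANS[prev][t]) + log(EMIT[t].get(w, 0))
            pointers[t] = prev
        best.append(scores)
        back.append(pointers)
    # Trace back from the best final tag
    tag = max(TAGS, key=lambda t: best[-1][t])
    path = [tag]
    for pointers in reversed(back[1:]):
        tag = pointers[tag]
        path.append(tag)
    return path[::-1]

print(viterbi(["the", "book"]))            # "book" tagged as a noun
print(viterbi(["book", "a", "flight"]))    # "book" tagged as a verb
```

The same word gets different tags because the transition probabilities favour a noun after a determiner and a determiner after a verb.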
-
Modern POS taggers use neural networks (bidirectional LSTMs or transformers) and achieve over 97% accuracy on English, approaching human performance.
-
Named Entity Recognition (NER) identifies and classifies proper names and other specific entities in text: persons, organisations, locations, dates, monetary amounts, etc.
-
In "Apple CEO Tim Cook announced the event in Cupertino on Monday," an NER system should identify: Apple (ORG), Tim Cook (PER), Cupertino (LOC), Monday (DATE).
-
NER is typically framed as sequence labelling using BIO tagging (also called IOB tagging). Each token gets a tag:
- B-TYPE: beginning of an entity of type TYPE
- I-TYPE: inside (continuation of) an entity of type TYPE
- O: outside any entity
-
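Turning BIO tags back into entity spans is a short left-to-right scan. A minimal sketch (the function name `bio_to_spans` is illustrative); as a design choice, an `I-` tag that does not continue an open entity of the same type is treated as if it were `B-`, which is one common lenient convention.

```python
def bio_to_spans(tokens, tags):
    """Collect (entity_type, text) spans from a BIO-tagged token sequence."""
    spans, current_type, current_tokens = [], None, []
    for token, tag in zip(tokens, tags):
        if tag.startswith("I-") and tag[2:] == current_type:
            current_tokens.append(token)           # continue the open entity
            continue
        # Any other tag closes the open entity (if there is one) ...
        if current_type is not None:
            spans.append((current_type, " ".join(current_tokens)))
        # ... then B-X (or a stray I-X) opens a new one, while O leaves nothing open
        if tag == "O":
            current_type, current_tokens = None, []
        else:
            current_type, current_tokens = tag[2:], [token]
    if current_type is not None:
        spans.append((current_type, " ".join(current_tokens)))
    return spans

tokens = ["Ada", "Lovelace", "met", "Charles", "Babbage", "in", "London"]
tags = ["B-PER", "I-PER", "O", "B-PER", "I-PER", "O", "B-LOC"]
print(bio_to_spans(tokens, tags))
```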
"Tim Cook visited New York" becomes: Tim/B-PER Cook/I-PER visited/O New/B-LOC York/I-LOC. The B tag marks where a new entity starts, which is important when two entities of the same type are adjacent.
-
Classical NER used Conditional Random Fields (CRFs) from chapter 05, which model the conditional probability of the entire tag sequence given the input. Unlike HMMs, which are generative ($P(x, y)$), CRFs are discriminative and model $P(y \mid x)$ directly. A linear-chain CRF defines:
$$P(y_{1:n} \mid x_{1:n}) = \frac{1}{Z(x)} \exp\!\left(\sum_{i=1}^{n} \left[\sum_k \lambda_k f_k(y_i, x, i) + \sum_j \mu_j g_j(y_i, y_{i-1}, x, i)\right]\right)$$
-
Here $f_k$ are emission features (how likely tag $y_i$ is given the input at position $i$) and $g_j$ are transition features (how likely tag $y_i$ is given the previous tag $y_{i-1}$).
-
The partition function $Z(x) = \sum_{y'} \exp(\ldots)$ sums over all possible tag sequences to normalise the distribution. Training maximises the conditional log-likelihood, which requires computing $Z(x)$ efficiently using the forward algorithm (chapter 05).
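The forward recursion for $\log Z(x)$ can be checked against brute-force enumeration on a small example. A minimal sketch, assuming the weighted feature sums have already been collapsed into two score tables: `emission[i][y]` holds $\sum_k \lambda_k f_k(y, x, i)$ and `transition[y_prev][y]` holds $\sum_j \mu_j g_j(y, y_{\text{prev}}, x, i)$ (transition scores are assumed not to depend on $x$ or $i$, as in many neural CRF layers). The numbers are arbitrary.

```python
import itertools
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))

def path_score(tags, emission, transition):
    """Unnormalised log-score of one tag sequence: the exponent in the CRF formula."""
    score = emission[0][tags[0]]
    for i in range(1, len(tags)):
        score += transition[tags[i - 1]][tags[i]] + emission[i][tags[i]]
    return score

def log_partition(emission, transition):
    """log Z(x) via the forward algorithm: O(n * K^2) instead of O(K^n)."""
    num_tags = len(emission[0])
    # alpha[k]: log of the summed exp-scores of all prefixes that end in tag k
    alpha = list(emission[0])
    for i in range(1, len(emission)):
        alpha = [
            logsumexp([alpha[j] + transition[j][k] for j in range(num_tags)]) + emission[i][k]
            for k in range(num_tags)
        ]
    return logsumexp(alpha)

# Arbitrary scores for a 4-token input and 3 tags
emission = [[1.0, 0.2, -0.5], [0.3, 1.2, 0.0], [-0.4, 0.1, 0.9], [0.5, -0.2, 0.3]]
transition = [[0.5, -1.0, 0.2], [0.1, 0.4, -0.3], [-0.6, 0.2, 0.7]]

all_paths = list(itertools.product(range(3), repeat=len(emission)))
brute = logsumexp([path_score(p, emission, transition) for p in all_paths])
print(brute, log_partition(emission, transition))   # equal up to floating-point error
```

With $K$ tags and $n$ tokens, brute force enumerates $K^n$ sequences, while the forward pass only touches $n \cdot K^2$ tag pairs.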
-
The key advantage over independently classifying each token: the CRF's transition features enforce structural constraints (e.g., I-PER should only follow B-PER or I-PER, never appear after O).
-
Modern NER stacks a CRF on top of a neural encoder (BiLSTM-CRF or BERT-CRF), where the neural network produces the emission scores and the CRF layer learns transition structure.
-
Syntactic parsing converts a sentence into its syntactic structure, either a constituency tree or a dependency tree (both from file 01).
-
The CYK algorithm (Cocke-Younger-Kasami) parses sentences with context-free grammars using dynamic programming.
-
It requires the grammar to be in Chomsky Normal Form (every rule has either two non-terminals or one terminal on the right side). It fills a triangular table bottom-up: cells represent spans of the sentence, and each cell stores the non-terminals that can generate that span.
-
CYK runs in $O(n^3 \cdot |G|)$ time, where $n$ is the sentence length and $|G|$ is the grammar size. This is exact but slow for large grammars.
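A compact way to see CYK at work is to count parse trees instead of just recognising the sentence. The sketch below assumes a small hypothetical grammar already in Chomsky Normal Form (the binary rules NP → NP PP and VP → VP PP stand in for the ternary NP → Det N PP used earlier) and fills each cell with a count per non-terminal. Run on the telescope sentence, it finds the two readings discussed earlier.

```python
from collections import defaultdict

# Hypothetical toy grammar in Chomsky Normal Form: binary rules A -> B C ...
BINARY_RULES = [
    ("S", "NP", "VP"),
    ("NP", "Det", "N"),
    ("NP", "NP", "PP"),
    ("VP", "V", "NP"),
    ("VP", "VP", "PP"),
    ("PP", "P", "NP"),
]
# ... and lexical rules A -> word
LEXICON = {
    "I": ["NP"], "saw": ["V"], "the": ["Det"],
    "man": ["N"], "telescope": ["N"], "with": ["P"],
}

def cyk_count(words):
    """chart[i][j][A] = number of distinct parse trees rooted at A that span words[i:j]."""
    n = len(words)
    chart = [[defaultdict(int) for _ in range(n + 1)] for _ in range(n + 1)]
    for i, w in enumerate(words):                  # spans of length 1: lexical rules
        for a in LEXICON.get(w, []):
            chart[i][i + 1][a] += 1
    for length in range(2, n + 1):                 # shorter spans are always filled first
        for i in range(n - length + 1):
            j = i + length
            for k in range(i + 1, j):              # every split point of words[i:j]
                left, right = chart[i][k], chart[k][j]
                for a, b, c in BINARY_RULES:
                    ways = left.get(b, 0) * right.get(c, 0)
                    if ways:
                        chart[i][j][a] += ways
    return chart

words = "I saw the man with the telescope".split()
chart = cyk_count(words)
print(chart[0][len(words)]["S"])   # number of parses of the whole sentence
```

The three nested loops over span length, start position and split point are where the $O(n^3)$ factor comes from; the innermost loop over rules contributes the $|G|$ factor.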
-
Shift-reduce parsing processes the sentence left to right, maintaining a stack. At each step, it either shifts (pushes the next word onto the stack) or reduces (pops elements from the stack and replaces them with a phrase). A trained classifier decides the action at each step. This runs in $O(n)$ time, making it much faster than CYK.
-
Dependency parsing is now more common than constituency parsing in practice. Transition-based dependency parsers (like shift-reduce) and graph-based parsers (which score all possible edges and find the maximum spanning tree) are the two main approaches. Neural dependency parsers using BiLSTMs or transformers achieve state-of-the-art results.
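The transition-based idea can be made concrete with the arc-standard system: SHIFT moves the next word onto the stack, LEFT-ARC makes the top of the stack the head of the word beneath it, and RIGHT-ARC does the reverse. In place of a trained classifier, the sketch below uses a static oracle that reads the correct action off a known gold tree, which is how training examples for such a classifier are typically produced. It assumes the tree is projective, and the heads for "the cat sat on the mat" follow the analysis given earlier ("on" attaches to "sat", "mat" to "on"). Each word is shifted once and attached at most once, so $n$ words take exactly $2n - 1$ transitions, the linear cost mentioned above.

```python
def arc_standard_oracle(words, gold_heads):
    """Replay the arc-standard transitions that rebuild a projective gold dependency tree.

    gold_heads[i] is the index of word i's head, or -1 for the root word.
    Returns the action sequence and the heads those actions assign.
    """
    n = len(words)
    heads = [None] * n
    # Number of dependents each word still has to collect before it may be attached itself
    pending = [0] * n
    for h in gold_heads:
        if h >= 0:
            pending[h] += 1
    stack, buffer, actions = [], list(range(n)), []
    while buffer or len(stack) > 1:
        if len(stack) >= 2:
            s0, s1 = stack[-1], stack[-2]
            if gold_heads[s1] == s0:                           # LEFT-ARC: s0 heads s1
                heads[s1] = s0
                pending[s0] -= 1
                stack.pop(-2)
                actions.append("LEFT-ARC")
                continue
            if gold_heads[s0] == s1 and pending[s0] == 0:      # RIGHT-ARC: s1 heads s0
                heads[s0] = s1
                pending[s1] -= 1
                stack.pop()
                actions.append("RIGHT-ARC")
                continue
        if not buffer:
            raise ValueError("no valid transition: the tree is not projective")
        stack.append(buffer.pop(0))                            # SHIFT
        actions.append("SHIFT")
    heads[stack[0]] = -1                                       # the last word left is the root
    return actions, heads

words = ["the", "cat", "sat", "on", "the", "mat"]
gold = [1, 2, -1, 2, 5, 3]   # the->cat, cat->sat, sat=root, on->sat, the->mat, mat->on
actions, heads = arc_standard_oracle(words, gold)
print(actions)
print(heads == gold, len(actions))
```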
-
Before embeddings, NLP represented documents as vectors using simple counting methods.
-
The bag-of-words (BoW) model represents a document as a vector of word counts, ignoring word order entirely. If the vocabulary has $V$ words, each document is a vector in $\mathbb{R}^V$ (connecting back to vector spaces from chapter 01). The entry for word $w$ is the number of times $w$ appears in the document.
-
BoW is simple but surprisingly effective for tasks like document classification and spam filtering. Its main weakness is that it treats every word as equally important: "the" and "revolutionary" get equal weight.
-
TF-IDF (Term Frequency-Inverse Document Frequency) fixes this by weighting words based on how informative they are. Words that appear frequently in one document but rarely across the corpus are likely important for that document.
$$\text{TF-IDF}(t, d) = \text{TF}(t, d) \times \text{IDF}(t)$$
-
Term frequency $\text{TF}(t, d)$ is often the raw count of term $t$ in document $d$ (or its log: $1 + \log(\text{count})$).
-
Inverse document frequency $\text{IDF}(t) = \log\frac{N}{|\{d : t \in d\}|}$, where $N$ is the total number of documents and the denominator counts the documents that contain $t$. Words appearing in every document (like "the") get an IDF of exactly $\log 1 = 0$. Rare words get high IDF.
-
TF-IDF vectors can be compared using cosine similarity (from chapter 01) to measure document similarity. This is the foundation of classical information retrieval and search engines.
-
A language model assigns a probability to a sequence of words. It answers: how likely is this sentence? Language models are central to machine translation, speech recognition, spelling correction, and text generation.
-
The probability of a sentence $w_1, w_2, \ldots, w_n$ is, by the chain rule of probability (chapter 05):
$$P(w_1, w_2, \ldots, w_n) = \prod_{i=1}^{n} P(w_i \mid w_1, \ldots, w_{i-1})$$
-
This is exact but impractical: you would need to store probabilities for every possible history. The Markov assumption (chapter 05) truncates the history to the last $n-1$ words, giving an n-gram model.
-
A bigram model ($n = 2$) conditions only on the previous word:
$$P(w_i \mid w_1, \ldots, w_{i-1}) \approx P(w_i \mid w_{i-1})$$
-
A trigram model ($n = 3$) conditions on the previous two words. N-gram probabilities are estimated by counting in a corpus:
$$P(w_i \mid w_{i-1}) = \frac{\text{count}(w_{i-1}, w_i)}{\text{count}(w_{i-1})}$$
-
Perplexity measures how well a language model predicts a test set. It is the inverse probability of the test set, normalised by the number of words:
$$\text{PPL} = P(w_1, \ldots, w_N)^{-1/N} = \exp\!\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i \mid w_{<i})\right)$$
-
Lower perplexity means the model is less "surprised" by the test data and therefore better. A model that assigns uniform probability over a 10,000-word vocabulary has perplexity 10,000. A good bigram model might achieve perplexity around 200. Modern neural language models achieve perplexity below 20. These numbers are only comparable between models evaluated on the same test set with the same vocabulary, since perplexity depends on both.
-
Notice that perplexity is the exponentiated cross-entropy (from chapter 05's information theory). Minimising cross-entropy loss during training directly minimises perplexity.
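Both forms of the perplexity formula, and the uniform-model figure quoted above, can be checked numerically. A minimal sketch; the per-token probabilities are made-up values standing in for a model's $P(w_i \mid w_{<i})$.

```python
import math

def perplexity(token_probs):
    """Perplexity as exp of the average negative log-probability (the cross-entropy)."""
    n = len(token_probs)
    cross_entropy = -sum(math.log(p) for p in token_probs) / n
    return math.exp(cross_entropy)

# A uniform model over a 10,000-word vocabulary gives every token probability 1/10,000
V = 10_000
print(perplexity([1 / V] * 50))        # ≈ 10000, up to floating-point rounding

# Made-up per-token probabilities for a 4-token test sequence
probs = [0.2, 0.05, 0.5, 0.1]
direct = math.prod(probs) ** (-1 / len(probs))   # P(w_1, ..., w_N) ** (-1/N)
print(direct, perplexity(probs))                  # the two forms agree
```

Working with the average log-probability, as the second form does, is also what keeps the computation stable on long test sets, where the raw product would underflow to 0.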
-
Smoothing handles the zero-probability problem: if an n-gram never appeared in training, the model assigns it probability 0, which makes the entire sentence probability 0. Laplace smoothing (add-1) adds one to every n-gram count and adds $V$ (the vocabulary size) to the denominator so the probabilities still sum to 1:
$$P_{\text{Laplace}}(w_i \mid w_{i-1}) = \frac{\text{count}(w_{i-1}, w_i) + 1}{\text{count}(w_{i-1}) + V}$$
-
This is too aggressive for large vocabularies (it steals too much probability from observed n-grams). Kneser-Ney smoothing is the gold standard for n-gram models. It combines two ideas: absolute discounting and a continuation probability for backoff.
-
First, absolute discounting subtracts a fixed discount $d$ (typically $d \approx 0.75$) from each observed count, rather than adding pseudocounts. The freed probability mass is redistributed through a lower-order distribution, which is what gives unseen n-grams non-zero probability. The interpolated form is:
$$P_{\text{KN}}(w_i \mid w_{i-1}) = \frac{\max(\text{count}(w_{i-1}, w_i) - d,\; 0)}{\text{count}(w_{i-1})} + \lambda(w_{i-1}) \cdot P_{\text{cont}}(w_i)$$
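where $\lambda(w_{i-1}) = \frac{d}{\text{count}(w_{i-1})}\,|\{w : \text{count}(w_{i-1}, w) > 0\}|$ is a normalising constant that distributes the discounted mass: $d$ was subtracted once for each distinct word seen after $w_{i-1}$, so $\lambda$ is exactly the probability mass removed. The key innovation is the continuation probability $P_{\text{cont}}(w_i)$, which measures how many different contexts $w_i$ appears in, rather than how often it appears overall: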
$$P_{\text{cont}}(w_i) = \frac{|\{w' : \text{count}(w', w_i) > 0\}|}{|\{(w', w'') : \text{count}(w', w'') > 0\}|}$$
-
The numerator counts how many distinct words precede $w_i$ in the corpus. A word like "Francisco" appears in few contexts (almost always after "San"), so even if "San Francisco" is very frequent, "Francisco" gets a low continuation probability and will not be predicted spuriously in other contexts.
-
Conversely, common words like "the" appear after many different words and get high continuation probability. This captures the intuition that a word's versatility matters more than its raw frequency for backoff estimation.
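Interpolated Kneser-Ney for bigrams fits in a short function. This is a minimal sketch: it uses $d = 0.75$, counts $\text{count}(w_{i-1})$ as the number of times $w_{i-1}$ occurs as a history (so each conditional distribution sums to 1), falls back to $P_{\text{cont}}$ for unseen histories, and runs on a tiny made-up corpus in which "francisco" and "glasses" are equally frequent but only "glasses" follows several different words.

```python
from collections import Counter

def kneser_ney_bigram(tokens, d=0.75):
    """Interpolated Kneser-Ney bigram model; returns prob(w, prev) and the vocabulary."""
    bigrams = Counter(zip(tokens[:-1], tokens[1:]))
    history_counts = Counter(tokens[:-1])             # count(w_{i-1}) as a bigram history
    followers = Counter(prev for prev, _ in bigrams)  # distinct words seen after each history
    left_contexts = Counter(w for _, w in bigrams)    # distinct words seen before each word
    num_bigram_types = len(bigrams)
    vocab = sorted(set(tokens))

    def p_cont(w):
        return left_contexts[w] / num_bigram_types

    def prob(w, prev):
        c_prev = history_counts[prev]
        if c_prev == 0:
            return p_cont(w)                          # unseen history: continuation probability only
        discounted = max(bigrams[(prev, w)] - d, 0) / c_prev
        lam = d * followers[prev] / c_prev            # mass removed by discounting
        return discounted + lam * p_cont(w)

    return prob, vocab

tokens = ("san francisco is big . i like san francisco . san francisco is far . "
          "i need glasses . my glasses are old . new glasses help .").split()
prob, vocab = kneser_ney_bigram(tokens)
print(tokens.count("francisco"), tokens.count("glasses"))   # equally frequent: 3 3
print(prob("francisco", "i"), prob("glasses", "i"))         # but "glasses" is more versatile
```

Neither "i francisco" nor "i glasses" occurs in the corpus, so both probabilities come entirely from the backoff term, and "glasses" wins because it has three distinct left contexts against one for "francisco".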
-
N-gram models were the state of the art for decades. They are fast, interpretable, and need no gradient-based training (just counting). But they struggle with long-range dependencies: in "The keys that I left on the table are missing", choosing the plural verb "are" requires remembering the subject "keys", which lies well outside a trigram's two-word window. Neural language models, starting with RNNs and culminating in transformers, address this limitation.
Coding Tasks (use Colab or notebook)
- Implement the Levenshtein edit distance using dynamic programming. Test it on word pairs and use it for simple spelling correction.
```python
def edit_distance(s, t):
    """Compute Levenshtein edit distance using DP."""
    m, n = len(s), len(t)
    D = [[0] * (n + 1) for _ in range(m + 1)]
    # Base cases: transforming to or from the empty string
    for i in range(m + 1):
        D[i][0] = i
    for j in range(n + 1):
        D[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if s[i-1] == t[j-1]:
                D[i][j] = D[i-1][j-1]
            else:
                # 1 + cheapest of deletion, insertion, substitution
                D[i][j] = 1 + min(D[i-1][j], D[i][j-1], D[i-1][j-1])
    return D[m][n]

# Test
pairs = [("kitten", "sitting"), ("sunday", "saturday"), ("hello", "hallo")]
for s, t in pairs:
    print(f"d('{s}', '{t}') = {edit_distance(s, t)}")

# Simple spelling correction
dictionary = ["the", "their", "there", "then", "than", "this", "that", "these", "those"]
misspelled = "thier"
corrections = sorted(dictionary, key=lambda w: edit_distance(misspelled, w))
print(f"\nClosest to '{misspelled}': {corrections[:3]}")
```
- Implement BPE tokenisation from scratch. Start with character-level tokens and iteratively merge the most frequent pairs.
```python
import re
from collections import Counter

def get_pairs(corpus):
    """Count adjacent token pairs across all words."""
    pairs = Counter()
    for word, freq in corpus.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i+1])] += freq
    return pairs

def merge_pair(pair, corpus):
    """Merge all occurrences of a pair in the corpus."""
    # Match the pair only as two whole symbols. A plain str.replace would also fire
    # across symbol boundaries, e.g. merging ('lo', 'w') would turn 'lo we r _' into 'lowe r _'.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    replacement = ''.join(pair)
    new_corpus = {}
    for word, freq in corpus.items():
        new_word = pattern.sub(replacement, word)
        new_corpus[new_word] = freq
    return new_corpus

# Training corpus with word frequencies
text = "low low low low low lower lower newest newest newest newest newest newest"
word_freqs = Counter(text.split())
# Initialise: split each word into characters with end-of-word marker
corpus = {' '.join(word) + ' _': freq for word, freq in word_freqs.items()}
print("Initial corpus:")
for word, freq in corpus.items():
    print(f"  {word}: {freq}")

# Run BPE for 10 merges
for i in range(10):
    pairs = get_pairs(corpus)
    if not pairs:
        break
    best_pair = max(pairs, key=pairs.get)
    corpus = merge_pair(best_pair, corpus)
    print(f"\nMerge {i+1}: {best_pair} (freq={pairs[best_pair]})")
    for word, freq in corpus.items():
        print(f"  {word}: {freq}")
```
- Build a bigram language model and compute perplexity on a test sentence. Experiment with Laplace smoothing.
```python
from collections import Counter, defaultdict
import math

# Training corpus
train = """the cat sat on the mat . the dog chased the cat .
the cat ran from the dog . a dog sat on a mat .""".split()
# Count bigrams and unigrams
bigrams = Counter(zip(train[:-1], train[1:]))
unigrams = Counter(train)
vocab_size = len(set(train))

def bigram_prob(w2, w1, alpha=0):
    """P(w2 | w1) with optional Laplace smoothing."""
    return (bigrams[(w1, w2)] + alpha) / (unigrams[w1] + alpha * vocab_size)

# Compute perplexity
test = "the cat sat on a mat .".split()
for alpha in [0, 1, 0.1]:
    log_prob = 0
    for w1, w2 in zip(test[:-1], test[1:]):
        p = bigram_prob(w2, w1, alpha=alpha)
        if p > 0:
            log_prob += math.log(p)
        else:
            log_prob += float('-inf')
    ppl = math.exp(-log_prob / (len(test) - 1)) if log_prob > float('-inf') else float('inf')
    print(f"Smoothing α={alpha}: perplexity = {ppl:.2f}")
```
- Implement TF-IDF from scratch and use cosine similarity to find the most similar document to a query.
```python
import jax.numpy as jnp
import math
from collections import Counter

documents = [
    "the cat sat on the mat",
    "the dog chased the cat around the park",
    "a mat was placed on the floor by the door",
    "the quick brown fox jumped over the lazy dog",
]

# Build vocabulary
vocab = sorted(set(word for doc in documents for word in doc.split()))
word_to_idx = {w: i for i, w in enumerate(vocab)}
V = len(vocab)
N = len(documents)

# Document frequency: number of documents containing each word
doc_freq = Counter()
for doc in documents:
    for word in set(doc.split()):
        doc_freq[word] += 1

# Compute TF-IDF matrix: TF = 1 + log(count), IDF = log(N / df)
tfidf_matrix = jnp.zeros((N, V))
for i, doc in enumerate(documents):
    word_counts = Counter(doc.split())
    for word, count in word_counts.items():
        tf = 1 + math.log(count)
        idf = math.log(N / doc_freq[word])
        j = word_to_idx[word]
        tfidf_matrix = tfidf_matrix.at[i, j].set(tf * idf)

# Query (words outside the vocabulary are ignored)
query = "cat on the mat"
query_vec = jnp.zeros(V)
query_counts = Counter(query.split())
for word, count in query_counts.items():
    if word in word_to_idx:
        tf = 1 + math.log(count)
        idf = math.log(N / doc_freq.get(word, 1))
        query_vec = query_vec.at[word_to_idx[word]].set(tf * idf)

# Cosine similarity (from chapter 01)
def cosine_sim(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8)

print(f"Query: '{query}'\n")
for i, doc in enumerate(documents):
    sim = float(cosine_sim(query_vec, tfidf_matrix[i]))
    print(f"  Doc {i} (sim={sim:.3f}): '{doc}'")
```
Embeddings and Sequence Models
-
In file 01, we introduced the distributional hypothesis: words that appear in similar contexts tend to have similar meanings. In file 02, we represented text using sparse, hand-crafted features like TF-IDF vectors. These vectors live in very high-dimensional spaces (one dimension per vocabulary word) and are mostly zeros. Word embeddings compress this information into dense, low-dimensional vectors that capture semantic relationships, and they are learned directly from data.
-
Word2Vec (Mikolov et al., 2013) learns word embeddings by training a shallow neural network on a simple prediction task. There are two architectures.
-
The Continuous Bag of Words (CBOW) model predicts a target word from its surrounding context words. Given a window of context words (e.g., "the cat ___ on the"), the model averages their embedding vectors and passes the result through a linear layer and a softmax over the vocabulary to predict the missing word ("sat"). The training objective maximises:
$$P(w_t \mid w_{t-k}, \ldots, w_{t-1}, w_{t+1}, \ldots, w_{t+k})$$
- The Skip-gram model does the reverse: given a target word, predict the surrounding context words. For the target word "sat", the model tries to predict "the", "cat", "on", "the" in separate predictions. The objective maximises:
$$P(w_{t+j} \mid w_t) \quad \text{for each } j \in [-k, k],\; j \neq 0$$
-
Skip-gram tends to work better for rare words because each word generates multiple training examples (one per context position). CBOW is faster and slightly better for frequent words because it averages over multiple context signals.
-
Training on the full vocabulary is expensive because the softmax denominator sums over all $V$ words. Negative sampling approximates this by turning the problem into binary classification: distinguish the true context word (positive sample) from randomly sampled noise words (negative samples). Instead of computing the full softmax, the model only updates embeddings for the target, the true context word, and a handful of negatives:
$$\mathcal{L} = \log \sigma(v_{w_O}^T v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n} \left[\log \sigma(-v_{w_i}^T v_{w_I})\right]$$
-
Here $v_{w_I}$ is the input word embedding, $v_{w_O}$ is the output (context) word embedding, and $P_n$ is the noise distribution, typically the unigram frequency raised to the 3/4 power (which downweights very frequent words like "the").
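Here is a minimal sketch of the noise distribution $P_n$. It uses a made-up table of word counts (`counts` is hypothetical) and a helper name of our own (`noise_distribution`). Raising frequencies to the 3/4 power moves probability mass away from very frequent words and toward rarer ones, and negatives are then drawn from the result.

```python
import random

def noise_distribution(counts, power=0.75):
    """P_n(w) proportional to count(w) ** power."""
    weights = {w: c ** power for w, c in counts.items()}
    total = sum(weights.values())
    return {w: x / total for w, x in weights.items()}

def sample_negatives(dist, k, rng):
    """Draw k negative words (with replacement) from the noise distribution."""
    words = list(dist)
    return rng.choices(words, weights=[dist[w] for w in words], k=k)

counts = {"the": 1000, "cat": 50, "sat": 20, "mat": 5}  # hypothetical counts
total = sum(counts.values())
unigram = {w: c / total for w, c in counts.items()}
p_n = noise_distribution(counts)
for w in counts:
    print(f"{w:>4}: unigram={unigram[w]:.3f}  P_n={p_n[w]:.3f}")
print(sample_negatives(p_n, k=5, rng=random.Random(0)))
```

"the" loses probability under $P_n$ and "mat" gains it. The ranking of words is unchanged.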
-
Why does this simple objective produce meaningful embeddings? Levy and Goldberg (2014) showed that skip-gram with negative sampling is implicitly factorising a shifted pointwise mutual information (PMI) matrix. At convergence, the dot product of two word vectors approximates:
$$v_w^T v_c \approx \text{PMI}(w, c) - \log k$$
-
where $\text{PMI}(w, c) = \log \frac{P(w, c)}{P(w) P(c)}$ measures how much more likely words $w$ and $c$ co-occur than expected by chance (chapter 05 information theory), and $k$ is the number of negative samples. Words that co-occur much more than chance have high PMI and therefore high dot product (similar embeddings). Words that co-occur less than expected have negative PMI and dissimilar embeddings. This reveals that Word2Vec is doing the same thing as classical distributional semantics methods like latent semantic analysis (SVD on co-occurrence matrices), but in a more scalable, online fashion.
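The PMI quantity can be computed directly from co-occurrence counts. The sketch below uses a small hypothetical table (`counts`, chosen for illustration) and prints both PMI and the shifted value $\text{PMI} - \log k$ that, by the result above, skip-gram's dot products approximate. Only pairs with a positive count are stored, so the logarithm is always defined.

```python
import math

def pmi_table(pair_counts):
    """PMI(w, c) = log[P(w, c) / (P(w) P(c))] from a dict {(w, c): count > 0}."""
    total = sum(pair_counts.values())
    w_totals, c_totals = {}, {}
    for (w, c), n in pair_counts.items():
        w_totals[w] = w_totals.get(w, 0) + n
        c_totals[c] = c_totals.get(c, 0) + n
    return {
        (w, c): math.log((n / total) / ((w_totals[w] / total) * (c_totals[c] / total)))
        for (w, c), n in pair_counts.items()
    }

# Hypothetical (word, context) co-occurrence counts
counts = {("ice", "cold"): 8, ("ice", "hot"): 2,
          ("steam", "cold"): 2, ("steam", "hot"): 8}
k = 5  # number of negative samples
for pair, value in pmi_table(counts).items():
    print(pair, f"PMI={value:+.3f}", f"shifted={value - math.log(k):+.3f}")
```

"ice"/"cold" co-occur more than chance (positive PMI) and "ice"/"hot" less (negative PMI).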
-
The most surprising property of Word2Vec embeddings is that they capture analogies through vector arithmetic. The vector $v_{\text{king}} - v_{\text{man}} + v_{\text{woman}}$ is closest to $v_{\text{queen}}$. This works because the embedding space encodes semantic relationships as approximately linear directions: the "royalty" direction is roughly $v_{\text{king}} - v_{\text{man}}$, and adding it to $v_{\text{woman}}$ lands near $v_{\text{queen}}$. This connects to the linear algebra of chapter 01: semantic relationships are vector translations.
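The analogy arithmetic can be sketched on hand-made vectors. The 3-d vectors below are hypothetical: dimension 0 loosely plays "royalty" and dimension 1 "gender". The query asks "a is to b as c is to ?" by searching for the word nearest (in cosine similarity) to $v_b - v_a + v_c$. Excluding the three query words from the candidates is a common convention in analogy evaluation; without it, the nearest neighbour is often one of the inputs.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def analogy(vectors, a, b, c):
    """Return the word closest to v_b - v_a + v_c, excluding a, b and c."""
    target = [vb - va + vc for va, vb, vc in zip(vectors[a], vectors[b], vectors[c])]
    candidates = [w for w in vectors if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(vectors[w], target))

# Hand-made vectors (hypothetical): dim 0 ~ "royalty", dim 1 ~ "gender"
vectors = {
    "king":  [0.9, 0.8, 0.1],
    "queen": [0.9, -0.7, 0.1],
    "man":   [0.1, 0.8, 0.0],
    "woman": [0.1, -0.7, 0.0],
    "apple": [-0.5, 0.0, 0.9],
}
print(analogy(vectors, "man", "king", "woman"))  # queen
```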
-
GloVe (Global Vectors for Word Representation, Pennington et al., 2014) takes a different approach. Instead of learning from local context windows one at a time, it builds a global word co-occurrence matrix $X$ where $X_{ij}$ counts how often word $j$ appears in the context of word $i$ across the entire corpus. The model then learns embeddings whose dot product approximates the log co-occurrence:
$$w_i^T \tilde{w}_j + b_i + \tilde{b}_j = \log X_{ij}$$
- The loss function weights each pair by a capping function $f(X_{ij})$ that prevents very frequent co-occurrences from dominating:
$$\mathcal{L} = \sum_{i,j=1}^{V} f(X_{ij}) \left(w_i^T \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij}\right)^2$$
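The text does not spell out $f$. The GloVe paper uses $f(x) = (x / x_{\max})^{\alpha}$ for $x < x_{\max}$ and $1$ otherwise, with $x_{\max} = 100$ and $\alpha = 3/4$. Since $f(0) = 0$, pairs that never co-occur (where $\log X_{ij}$ is undefined) contribute nothing, and in practice the sum runs over nonzero entries only. A minimal sketch of the weight and of one pair's loss term follows. The helper names are our own.

```python
import math

def glove_weight(x, x_max=100.0, alpha=0.75):
    """f(x) = (x / x_max) ** alpha below x_max, else 1 (defaults from the GloVe paper)."""
    return (x / x_max) ** alpha if x < x_max else 1.0

def glove_pair_loss(w_i, w_tilde_j, b_i, b_tilde_j, x_ij):
    """Weighted squared error for one pair with x_ij > 0."""
    dot = sum(a * b for a, b in zip(w_i, w_tilde_j))
    return glove_weight(x_ij) * (dot + b_i + b_tilde_j - math.log(x_ij)) ** 2

for x in [1, 10, 50, 100, 5000]:
    print(x, round(glove_weight(x), 3))
```

Rare pairs get small weights, and every pair at or above $x_{\max}$ gets the same weight of 1. This is the "capping" that keeps pairs like ("of", "the") from dominating.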
-
GloVe combines the benefits of global matrix factorisation (like latent semantic analysis) with the local context learning of Word2Vec. In practice, GloVe and Word2Vec produce embeddings of similar quality.
-
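FastText (Bojanowski et al., 2017) extends skip-gram by representing each word as a bag of character n-grams, with the boundary symbols "<" and ">" added to mark the start and end of the word. The word "where" with $n = 3$ becomes: "<wh", "whe", "her", "ere", "re>", plus the whole-word token "<where>". In practice, all n-grams with $3 \le n \le 6$ are used. The word's embedding is the sum of all its n-gram embeddings.
-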
This has a crucial advantage: FastText can produce embeddings for words it has never seen during training. The word "whereabouts" shares n-grams with "where", so its embedding will be reasonable even if "whereabouts" never appeared in the training data. This is especially useful for morphologically rich languages (file 01) where words have many inflected forms.
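The n-gram extraction and the out-of-vocabulary behaviour can be sketched in a few lines. `char_ngrams` follows the "<" / ">" boundary convention above. `word_vector` is a hypothetical stand-in: real FastText hashes n-grams into a fixed number of buckets, whereas this sketch fills a plain dictionary with small random vectors on first use.

```python
import random

def char_ngrams(word, n_min=3, n_max=6):
    """Character n-grams of '<word>' plus the whole-word token '<word>'."""
    padded = f"<{word}>"
    grams = [padded[i:i + n]
             for n in range(n_min, n_max + 1)
             for i in range(len(padded) - n + 1)]
    if padded not in grams:
        grams.append(padded)  # whole-word token
    return grams

def word_vector(word, table, dim=4):
    """Sum of n-gram vectors; `table` is a hypothetical n-gram -> vector lookup."""
    vec = [0.0] * dim
    for g in char_ngrams(word):
        if g not in table:
            rng = random.Random(g)  # deterministic per n-gram, for illustration only
            table[g] = [rng.uniform(-0.1, 0.1) for _ in range(dim)]
        vec = [a + b for a, b in zip(vec, table[g])]
    return vec

print(char_ngrams("where", 3, 3))  # ['<wh', 'whe', 'her', 'ere', 're>', '<where>']
shared = set(char_ngrams("where", 3, 3)) & set(char_ngrams("whereabouts", 3, 3))
print(sorted(shared))
print(word_vector("whereabouts", {}))  # defined even if the word was never seen
```

Because "whereabouts" shares "<wh", "whe", "her" and "ere" with "where", part of its vector comes from n-grams that training on "where" would have updated.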
-
Embedding evaluation typically uses two types of benchmarks. Analogy tasks test whether $v_a - v_b + v_c \approx v_d$ (e.g., "Paris" $-$ "France" $+$ "Italy" $\approx$ "Rome"). Similarity benchmarks compare the cosine similarity (chapter 01) between word pairs to human judgements. Common datasets include WordSim-353, SimLex-999, and the Google analogy test set. A practical caveat: embeddings that excel at analogies may not be best for downstream tasks like sentiment classification. The best evaluation is often the task itself.
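Similarity benchmarks such as WordSim-353 and SimLex-999 are usually scored with Spearman's rank correlation between the model's cosine similarities and the human ratings. Only the ordering of word pairs matters, not the raw values. A minimal sketch, assuming no tied values; the ratings and similarities are hypothetical numbers, not benchmark data:

```python
def ranks(values):
    """Rank of each value (1 = smallest); assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0] * len(values)
    for rank, i in enumerate(order, start=1):
        r[i] = rank
    return r

def spearman(x, y):
    """Spearman's rho = 1 - 6 * sum(d^2) / (n (n^2 - 1)), valid without ties."""
    rx, ry = ranks(x), ranks(y)
    n = len(x)
    d2 = sum((a - b) ** 2 for a, b in zip(rx, ry))
    return 1 - 6 * d2 / (n * (n ** 2 - 1))

human = [9.0, 7.5, 3.0, 1.2]      # hypothetical human ratings for four word pairs
model = [0.81, 0.62, 0.40, 0.05]  # hypothetical cosine similarities for the same pairs
print(spearman(human, model))     # 1.0: identical ordering
```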
-
In chapter 06, we introduced RNNs, LSTMs, and GRUs as architectures for sequential data. Here we focus on how they are applied to language tasks specifically.
-
A language model RNN reads tokens one at a time and predicts the next token at each step. The hidden state $h_t$ compresses the entire history $w_1, \ldots, w_t$ into a fixed-size vector, and a linear layer plus softmax maps $h_t$ to a distribution over the vocabulary. Training uses cross-entropy loss against the true next token, which is equivalent to minimising perplexity (file 02), because perplexity is the exponential of the average per-token cross-entropy. The key limitation: the fixed-size hidden state must encode everything about the history, and information from early tokens gets progressively overwritten.
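The link between the training loss and perplexity is a one-liner: perplexity is `exp` of the mean negative log-probability of the true tokens. The probability lists below are hypothetical model outputs.

```python
import math

def perplexity(probs_of_true_tokens):
    """exp of the average negative log-probability (cross-entropy in nats per token)."""
    nll = -sum(math.log(p) for p in probs_of_true_tokens) / len(probs_of_true_tokens)
    return math.exp(nll)

print(perplexity([0.25, 0.25, 0.25, 0.25]))  # uniform over 4 options: about 4
print(perplexity([0.9, 0.5, 0.7, 0.2]))      # hypothetical model outputs
```

Since `exp` is monotonic, any parameter change that lowers the average cross-entropy also lowers perplexity.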
-
Bidirectional RNNs process the sequence in both directions: one RNN reads left-to-right, another reads right-to-left. At each position $t$, the forward hidden state $\overrightarrow{h}_t$ and backward hidden state $\overleftarrow{h}_t$ are concatenated to form a context-aware representation $h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t]$. This gives the model access to both past and future context, which is powerful for tasks like POS tagging and NER (file 02) where a word's label depends on words both before and after it. Bidirectional RNNs cannot be used directly for left-to-right language modelling: the backward state at position $t$ has already read the very tokens the model is supposed to predict.
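The concatenation, and the reason it leaks future tokens, can be seen with a toy scalar RNN. The scalar weights `w` and `u` are hypothetical; real layers use weight matrices and vector states. The backward pass runs over the reversed input and is then re-aligned so that both states refer to the same position.

```python
import math

def rnn(xs, w, u):
    """Toy scalar RNN: h_t = tanh(w * x_t + u * h_{t-1}), starting from h_0 = 0."""
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(w * x + u * h)
        states.append(h)
    return states

def birnn(xs):
    """Per-position concatenation h_t = [forward_t ; backward_t]."""
    fwd = rnn(xs, w=0.5, u=0.8)               # reads left to right
    bwd = rnn(xs[::-1], w=-0.3, u=0.6)[::-1]  # reads right to left, re-aligned
    return [[f, b] for f, b in zip(fwd, bwd)]

a = birnn([1.0, 2.0, 3.0])
b = birnn([1.0, 2.0, -3.0])  # only the last token differs
print(a[0], b[0])  # same forward state at t=0, different backward state
```

Changing the last token alters the representation at position 0 through the backward half. This is useful for tagging, but for next-token prediction it would mean reading the answer.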
-
Deep stacked RNNs place multiple RNN layers on top of each other. The hidden states of layer $l$ at all time steps become the input sequence for layer $l + 1$. Stacking 2-4 layers typically improves performance by building hierarchical representations, similar to how deeper CNNs build feature hierarchies (chapter 06). Beyond 4 layers, vanishing gradients and overfitting become problems unless residual connections are added between layers.
-
The sequence-to-sequence (seq2seq) architecture (Sutskever et al., 2014) maps a variable-length input sequence to a variable-length output sequence. It consists of an encoder RNN that reads the input and compresses it into a context vector (the final hidden state), and a decoder RNN that generates the output one token at a time, conditioned on this context vector.
-
Seq2seq was the breakthrough architecture for machine translation. The encoder reads a French sentence and the decoder produces the English translation. The decoder starts with a special start-of-sequence token and generates tokens autoregressively until it produces an end-of-sequence token. A practical trick: reversing the input sequence (feeding "chat le" instead of "le chat") improved results because it placed the first input word closer to the first output word in the computation graph, shortening the gradient path.
-
The bottleneck problem: the entire input must be compressed into a single fixed-size vector. For long sentences, this vector cannot capture all the information, and performance degrades. This motivated attention mechanisms.
-
Chapter 06 introduced the modern Q, K, V formulation of attention. The original attention mechanisms for NLP were formulated differently, as alignment models between encoder and decoder states.
-
Bahdanau attention (additive attention, Bahdanau et al., 2015) computes an alignment score between the previous decoder hidden state $s_{t-1}$ and each encoder hidden state $h_i$ using a learned feed-forward network:
$$e_{ti} = v^T \tanh(W_s s_{t-1} + W_h h_i)$$
- The scores are normalised to attention weights via softmax, and the context vector is a weighted sum of encoder states:
$$\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}, \quad c_t = \sum_i \alpha_{ti} h_i$$
-
The decoder then uses both $s_{t-1}$ and $c_t$ to produce the next output. The key insight: instead of one fixed context vector for the entire sentence, each decoder step gets a different weighted combination of encoder states, allowing the model to "look back" at the relevant parts of the input.
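The three equations above translate almost line for line into code. This plain-Python sketch uses 2-d vectors and identity matrices as hypothetical stand-ins for the learned $W_s$, $W_h$ and $v$. $W_s s_{t-1}$ is computed once and shared across all encoder positions.

```python
import math

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def additive_attention(s_prev, enc_states, W_s, W_h, v):
    """e_i = v^T tanh(W_s s_{t-1} + W_h h_i); alpha = softmax(e); c = sum_i alpha_i h_i."""
    ws = matvec(W_s, s_prev)
    scores = []
    for h in enc_states:
        hidden = [math.tanh(a + b) for a, b in zip(ws, matvec(W_h, h))]
        scores.append(sum(vi * x for vi, x in zip(v, hidden)))
    alpha = softmax(scores)
    dim = len(enc_states[0])
    context = [sum(a * h[d] for a, h in zip(alpha, enc_states)) for d in range(dim)]
    return context, alpha

I = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical: identity matrices for W_s and W_h
enc_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
context, alpha = additive_attention([0.5, 0.5], enc_states, I, I, v=[1.0, 1.0])
print([round(a, 3) for a in alpha], [round(c, 3) for c in context])
```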
-
Luong attention (multiplicative attention, Luong et al., 2015) simplifies the score computation. The dot variant uses $e_{ti} = s_t^T h_i$. The general variant uses $e_{ti} = s_t^T W h_i$. These are cheaper than Bahdanau's additive score because there is no tanh hidden layer: the scores for all encoder positions come from a single matrix-vector product. Luong attention also scores against the current decoder state $s_t$ (computed before attention) rather than $s_{t-1}$, and then combines the resulting context vector with $s_t$ to predict the output.
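A sketch of the two Luong score functions, followed by the same softmax as before. `W` is a hypothetical learned matrix. With $W$ equal to the identity, the general score reduces to the dot score.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def luong_scores(s_t, enc_states, W=None):
    """Dot score s_t^T h_i, or general score s_t^T W h_i when W is given."""
    if W is None:
        return [dot(s_t, h) for h in enc_states]
    return [dot(s_t, [dot(row, h) for row in W]) for h in enc_states]

s_t = [0.2, 1.0]
enc_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W = [[2.0, 0.0], [0.0, 0.5]]  # hypothetical learned matrix
print(softmax(luong_scores(s_t, enc_states)))     # dot variant
print(softmax(luong_scores(s_t, enc_states, W)))  # general variant
```

The general variant lets the model reweight dimensions (here it boosts the first and shrinks the second), and it also works when $s_t$ and $h_i$ differ in size if $W$ is rectangular.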
-
Attention weights are often visualised as heatmaps showing which input tokens the decoder focuses on when producing each output token. In translation, these heatmaps roughly trace the word alignment between source and target languages, with the diagonal pattern broken by reordering (e.g., adjective-noun order differs between French and English).
-
At inference time, the decoder must choose a token at each step. Greedy decoding picks the highest-probability token at each position, but this can lead to suboptimal sequences: a locally good choice may force the model into a globally bad sentence. Beam search maintains the top $k$ (the beam width) partial sequences at each step, expanding each by all possible next tokens and keeping the best $k$ overall.
-
With beam width $k = 1$, beam search reduces to greedy decoding. Typical values are $k = 4$ to $k = 10$. Larger beams find sequences with higher model probability but cost proportionally more computation. Beam search also needs length normalisation to avoid favouring shorter sequences, which naturally have higher total probability because they multiply fewer terms, each at most 1. The normalised score is:
$$\text{score}(y) = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log P(y_t \mid y_{<t})$$
-
where $|y|$ is the sequence length and $\alpha$ (typically 0.6-0.7) controls the strength of the length penalty. With $\alpha = 0$, there is no length normalisation. With $\alpha = 1$, the score is the average per-token log-probability (the log of the geometric mean of the token probabilities). Intermediate values balance favouring concise outputs against truncating too early.
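A minimal beam search sketch with the length-normalised score above. The toy "model" is a hypothetical lookup table of next-token probabilities, chosen so that greedy decoding commits to a worse sequence. Counting the end token `</s>` in $|y|$ and dropping finished hypotheses from the beam are choices made for this sketch.

```python
import math

def beam_search(next_probs, beam_width, max_len=10, alpha=0.0, eos="</s>"):
    """Keep the beam_width best prefixes by total log-prob; rank finished
    sequences by logP / |y| ** alpha."""
    beams = [((), 0.0)]  # (token tuple, sum of log-probabilities)
    finished = []
    for _ in range(max_len):
        candidates = [(seq + (tok,), logp + math.log(p))
                      for seq, logp in beams
                      for tok, p in next_probs(seq).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, logp in candidates[:beam_width]:
            if seq[-1] == eos:
                finished.append((seq, logp))
            else:
                beams.append((seq, logp))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_len
    return max(finished, key=lambda c: c[1] / len(c[0]) ** alpha)

# Hypothetical toy model: next-token probabilities given the prefix
toy_table = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"A": 0.4, "B": 0.3, "</s>": 0.3},
    ("B",): {"</s>": 0.9, "A": 0.05, "B": 0.05},
}

def toy(seq):
    return toy_table.get(seq, {"</s>": 1.0})

print(beam_search(toy, beam_width=1)[0])             # greedy: ('A', 'A', '</s>'), P = 0.24
print(beam_search(toy, beam_width=2)[0])             # ('B', '</s>'), P = 0.36
print(beam_search(toy, beam_width=2, alpha=1.0)[0])  # ('A', 'A', '</s>')
```

Greedy decoding takes the locally best "A" and ends up with probability 0.24. A beam of 2 keeps "B" alive and finds the 0.36 sequence. With $\alpha = 1$, the longer sequence has the better per-token average and wins again, which shows how strongly the length penalty can change the choice.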
-
While RNNs process text sequentially, 1D CNNs process it in parallel by sliding filters across the token sequence. Each filter detects a local pattern (an n-gram feature).
-
TextCNN (Kim, 2014) applies multiple 1D convolutional filters of different widths (e.g., 3, 4, 5 tokens) to the input embedding matrix. Each filter produces a feature map, and max-over-time pooling takes the single maximum value from each feature map, capturing whether the pattern was detected anywhere in the text regardless of position. The pooled features from all filters are concatenated and passed to a classifier.
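The convolution-plus-max-over-time step for a single filter can be sketched in plain Python. The 2-d embeddings and the two hand-set filters are hypothetical; in a trained TextCNN both are learned. A ReLU follows each window's dot product, as in Kim's setup.

```python
def conv_max_pool(embeddings, filt, bias=0.0):
    """Slide a (width x dim) filter over the sequence, apply ReLU, then max over time."""
    width = len(filt)
    feature_map = []
    for start in range(len(embeddings) - width + 1):
        window = embeddings[start:start + width]
        z = bias + sum(f * x for f_row, x_row in zip(filt, window)
                       for f, x in zip(f_row, x_row))
        feature_map.append(max(0.0, z))  # ReLU
    return max(feature_map)

def textcnn_features(embeddings, filters):
    """One pooled value per filter, concatenated into the classifier input."""
    return [conv_max_pool(embeddings, f) for f in filters]

# Hypothetical 2-d embeddings and two hand-set filters of widths 2 and 3
emb = {"pad": [0.0, 0.0], "not": [0.0, 1.0], "good": [1.0, 0.0]}
not_good = [[0.0, 1.0], [1.0, 0.0]]             # fires on "not good"
good_x3 = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]  # counts "good" in a 3-token window
s1 = [emb[w] for w in ["pad", "not", "good", "pad", "pad"]]
s2 = [emb[w] for w in ["pad", "pad", "pad", "not", "good"]]
print(textcnn_features(s1, [not_good, good_x3]))
print(textcnn_features(s2, [not_good, good_x3]))
```

Both sentences give the same pooled features even though "not good" appears at different positions. That is the position invariance max-over-time pooling provides.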
-
TextCNN is fast and surprisingly effective for text classification tasks like sentiment analysis. It captures local n-gram patterns but cannot model long-range dependencies: a filter of width 5 only sees 5 consecutive tokens. Dilated causal convolutions address this by inserting gaps (dilations) between filter elements. Stacking layers with exponentially increasing dilation rates (1, 2, 4, 8, ...) makes the receptive field grow exponentially with depth while the parameter count grows only linearly, allowing a handful of layers to capture dependencies across hundreds of tokens.
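The receptive-field arithmetic is simple to check. For stride-1 convolutions with kernel size $K$ and dilations $d_1, \ldots, d_L$, the receptive field is $1 + (K - 1)\sum_l d_l$. The sketch assumes kernel size 2 and doubling dilations. The helper name is our own.

```python
def receptive_field(kernel_size, dilations):
    """Receptive field of stacked stride-1 dilated convolutions."""
    return 1 + (kernel_size - 1) * sum(dilations)

for n_layers in [1, 2, 4, 8, 10]:
    dilations = [2 ** i for i in range(n_layers)]  # 1, 2, 4, 8, ...
    print(n_layers, receptive_field(2, dilations))
```

With kernel size 2 the receptive field is $2^L$: 8 layers already see 256 tokens and 10 layers see 1024, while the number of filter weights grows only linearly with $L$.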
-
All the embeddings discussed so far (Word2Vec, GloVe, FastText) produce a single vector per word type regardless of context. "Bank" gets the same embedding whether it means a financial institution or a river bank. This is a fundamental limitation that contextual embeddings address.
-
ELMo (Embeddings from Language Models, Peters et al., 2018) produces contextual word representations by running a deep bidirectional LSTM language model on the input text. The forward LSTM predicts the next word at each position; a separate backward LSTM predicts the previous word. Both are trained as language models on large corpora.
-
At each position $k$, ELMo combines the hidden states from all $L + 1$ layers (the token embedding plus $L$ LSTM layers) using task-specific learned weights:
$$\text{ELMo}_k = \gamma \sum_{j=0}^{L} s_j \, h_{k,j}$$
-
Here $h_{k,j}$ is the hidden state at position $k$ and layer $j$ (layer 0 is the raw token embedding), $s_j$ are softmax-normalised scalar weights, and $\gamma$ is a task-specific scaling factor. Different layers capture different information: lower layers capture syntax (POS tags, word morphology), upper layers capture semantics (word sense, semantic role). By mixing all layers with learned weights, ELMo embeddings adapt to diverse downstream tasks.
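The layer mixing is a softmax-weighted sum scaled by $\gamma$. In this sketch the per-layer states are hypothetical 2-d vectors for one position, and `raw_weights` are the unnormalised scalars that would be learned per task.

```python
import math

def scalar_mix(layer_states, raw_weights, gamma=1.0):
    """ELMo_k = gamma * sum_j s_j * h_{k,j}, with s = softmax(raw_weights)."""
    m = max(raw_weights)
    exps = [math.exp(w - m) for w in raw_weights]
    total = sum(exps)
    s = [e / total for e in exps]
    dim = len(layer_states[0])
    return [gamma * sum(s_j * h[d] for s_j, h in zip(s, layer_states)) for d in range(dim)]

# Hypothetical states at one position k: layer 0 (token embedding) and 2 LSTM layers
h_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(scalar_mix(h_k, [0.0, 0.0, 0.0]))       # equal weights: the layer average
print(scalar_mix(h_k, [0.0, 0.0, 5.0], 2.0))  # mostly the top layer, scaled by gamma
```

A task that needs syntax can learn to put weight on lower layers, and a sense-disambiguation task can shift it upward, while the language model itself stays fixed.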
-
ELMo marked the beginning of the pre-train then fine-tune paradigm: train a large language model on massive unlabelled text, then use its representations for downstream tasks. ELMo specifically uses the pre-trained representations as fixed or lightly tuned features that are concatenated with task-specific inputs. BERT and GPT (file 04) push this further by fine-tuning the entire model end-to-end, which proves dramatically more effective.
-
The progression from Word2Vec to ELMo illustrates a recurring theme in NLP: moving from static to dynamic representations, from local to global context, and from shallow to deep models. Each step trades computational cost for richer representations. Transformers (file 04) complete this progression by replacing recurrence entirely with attention, enabling both deep contextualisation and parallel computation.
Coding Tasks (use Colab or a notebook)
- Implement Word2Vec skip-gram with negative sampling from scratch. Train on a small corpus and visualise the learned embeddings using PCA.
```python
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Small corpus
corpus = """the king ruled the kingdom . the queen ruled the kingdom .
the prince is the son of the king . the princess is the daughter of the queen .
a man worked in the castle . a woman worked in the castle .
the king and queen lived in the castle . the prince and princess played outside .""".lower().split()
vocab = sorted(set(corpus))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(vocab)

# Generate (target, context) skip-gram pairs with window size 2
window = 2
pairs = []
for i, word in enumerate(corpus):
    for j in range(max(0, i - window), min(len(corpus), i + window + 1)):
        if i != j:
            pairs.append((word2idx[word], word2idx[corpus[j]]))
pairs = jnp.array(pairs)
print(f"Vocabulary: {V} words, Training pairs: {len(pairs)}")

# Model parameters: separate input and output embedding tables
embed_dim = 16
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
W_in = jax.random.normal(k1, (V, embed_dim)) * 0.1   # input embeddings
W_out = jax.random.normal(k2, (V, embed_dim)) * 0.1  # output embeddings

# Negative sampling loss for one (target, context) pair
def neg_sampling_loss(W_in, W_out, target, context, neg_ids):
    v_in = W_in[target]     # (embed_dim,)
    v_out = W_out[context]  # (embed_dim,)
    v_neg = W_out[neg_ids]  # (k, embed_dim)
    pos_loss = -jax.nn.log_sigmoid(jnp.dot(v_in, v_out))
    neg_loss = -jnp.sum(jax.nn.log_sigmoid(-v_neg @ v_in))
    return pos_loss + neg_loss

# Training loop (plain SGD on one pair at a time)
num_neg = 5
lr = 0.05

@jax.jit
def train_step(W_in, W_out, target, context, neg_ids):
    loss, (g_in, g_out) = jax.value_and_grad(neg_sampling_loss, argnums=(0, 1))(
        W_in, W_out, target, context, neg_ids)
    return loss, W_in - lr * g_in, W_out - lr * g_out

key = jax.random.PRNGKey(0)
for epoch in range(50):
    total_loss = 0.0
    for i in range(len(pairs)):
        key, subkey = jax.random.split(key)
        # Simplification: negatives are drawn uniformly rather than from the
        # unigram^(3/4) distribution, and may occasionally hit the true context word
        neg_ids = jax.random.randint(subkey, (num_neg,), 0, V)
        loss, W_in, W_out = train_step(W_in, W_out, pairs[i, 0], pairs[i, 1], neg_ids)
        total_loss += loss
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: avg loss = {float(total_loss) / len(pairs):.4f}")

# Visualise with PCA (chapter 01)
embeddings = W_in
centered = embeddings - embeddings.mean(axis=0)
U, S, Vt = jnp.linalg.svd(centered, full_matrices=False)
coords = np.asarray(centered @ Vt[:2].T)  # project onto the top 2 principal components

plt.figure(figsize=(10, 8))
for i, word in idx2word.items():
    plt.scatter(coords[i, 0], coords[i, 1], c='#3498db', s=40)
    plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=9)
plt.title("Word2Vec Skip-gram Embeddings (PCA projection)")
plt.grid(alpha=0.3)
plt.show()
```
- Build a character-level RNN language model that learns to generate text from a small training string.
```python
import jax
import jax.numpy as jnp

# Tiny training text
text = "to be or not to be that is the question "
chars = sorted(set(text))
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}
V = len(chars)
data = jnp.array([char2idx[c] for c in text])

# RNN parameters
hidden_dim = 64
key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
params = {
    'Wx': jax.random.normal(k1, (V, hidden_dim)) * 0.1,
    'Wh': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.05,
    'bh': jnp.zeros(hidden_dim),
    'Wy': jax.random.normal(k3, (hidden_dim, V)) * 0.1,
    'by': jnp.zeros(V),
}

def rnn_step(params, h, x_idx):
    x = jnp.eye(V)[x_idx]  # one-hot
    h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['bh'])
    logits = h @ params['Wy'] + params['by']
    return h, logits

def loss_fn(params, inputs, targets):
    """Average next-character cross-entropy over the sequence."""
    h = jnp.zeros(hidden_dim)
    total_loss = 0.0
    for t in range(len(inputs)):
        h, logits = rnn_step(params, h, inputs[t])
        log_probs = jax.nn.log_softmax(logits)
        total_loss -= log_probs[targets[t]]
    return total_loss / len(inputs)

grad_fn = jax.jit(jax.grad(loss_fn))

# Training: predict each next character from the ones before it
inputs = data[:-1]
targets = data[1:]
lr = 0.01
for step in range(500):
    grads = grad_fn(params, inputs, targets)
    params = {k: params[k] - lr * grads[k] for k in params}
    if (step + 1) % 100 == 0:
        l = loss_fn(params, inputs, targets)
        print(f"Step {step+1}: loss = {float(l):.4f}")

# Generate text by sampling one character at a time
def generate(params, seed_char, length=60):
    h = jnp.zeros(hidden_dim)
    idx = char2idx[seed_char]
    result = [seed_char]
    key = jax.random.PRNGKey(42)
    for _ in range(length):
        h, logits = rnn_step(params, h, idx)
        key, subkey = jax.random.split(key)
        idx = jax.random.categorical(subkey, logits)
        result.append(idx2char[int(idx)])
    return ''.join(result)

print(f"\nGenerated: {generate(params, 't')}")
```
- Implement a toy seq2seq model with Bahdanau attention for sequence reversal. Visualise the attention alignment matrix.
```python
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Task: reverse a sequence of digits (e.g., [3, 1, 4] -> [4, 1, 3])
vocab_size = 10    # digits 0-9
SOS, EOS = 10, 11  # special tokens (EOS is unused: target lengths are known here)
total_vocab = 12
embed_dim, hidden_dim = 16, 32
max_len = 5

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 10)
params = {
    'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1,
    'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1,
    'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05,
    'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1,
    'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05,
    'dec_Wc': jax.random.normal(keys[8], (hidden_dim, hidden_dim)) * 0.1,  # context -> decoder
    # Bahdanau attention
    'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1,
    'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1,
    'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1,
    # Output projection (from hidden + context to vocab), with its own key
    'Wo': jax.random.normal(keys[9], (hidden_dim * 2, total_vocab)) * 0.1,
}

def encode(params, seq):
    """Encode input sequence, return all hidden states and the final one."""
    h = jnp.zeros(hidden_dim)
    states = []
    for t in range(len(seq)):
        x = params['embed'][seq[t]]
        h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh'])
        states.append(h)
    return jnp.stack(states), h

def bahdanau_attention(params, dec_state, enc_states):
    """Compute Bahdanau attention weights and context vector."""
    scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws'])
    e = scores @ params['v_att']  # (src_len,)
    alpha = jax.nn.softmax(e)
    context = alpha @ enc_states
    return context, alpha

def decode_step(params, dec_h, prev_token, enc_states):
    # Bahdanau ordering: attend using the previous decoder state s_{t-1} ...
    context, alpha = bahdanau_attention(params, dec_h, enc_states)
    # ... then update the state from the previous token and the context vector
    x = params['embed'][prev_token]
    dec_h = jnp.tanh(x @ params['dec_Wx'] + context @ params['dec_Wc']
                     + dec_h @ params['dec_Wh'])
    logits = jnp.concatenate([dec_h, context]) @ params['Wo']
    return dec_h, logits, alpha

def seq2seq_loss(params, src, tgt):
    enc_states, enc_final = encode(params, src)
    dec_h = enc_final
    loss = 0.0
    prev_token = SOS
    for t in range(len(tgt)):
        dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states)
        log_probs = jax.nn.log_softmax(logits)
        loss -= log_probs[tgt[t]]
        prev_token = tgt[t]  # teacher forcing
    return loss / len(tgt)

# Generate training data: random digit sequences and their reversals
key = jax.random.PRNGKey(0)
train_srcs, train_tgts = [], []
for _ in range(200):
    key, subkey = jax.random.split(key)
    length = jax.random.randint(subkey, (), 3, max_len + 1)
    key, subkey = jax.random.split(key)
    seq = jax.random.randint(subkey, (int(length),), 0, vocab_size)
    train_srcs.append(seq)
    train_tgts.append(seq[::-1])  # reverse

# Training: jit compiles once per distinct sequence length (3, 4 or 5)
loss_and_grad = jax.jit(jax.value_and_grad(seq2seq_loss))
lr = 0.01
for epoch in range(100):
    total_loss = 0.0
    for src, tgt in zip(train_srcs, train_tgts):
        loss, grads = loss_and_grad(params, src, tgt)
        params = {k: params[k] - lr * grads[k] for k in params}
        total_loss += float(loss)
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}: avg loss = {total_loss / len(train_srcs):.4f}")

# Visualise attention for one example (teacher-forced decoding)
test_src = jnp.array([3, 1, 4, 1, 5])
test_tgt = test_src[::-1]
enc_states, enc_final = encode(params, test_src)
dec_h = enc_final
attentions = []
prev_token = SOS
for t in range(len(test_tgt)):
    dec_h, logits, alpha = decode_step(params, dec_h, prev_token, enc_states)
    attentions.append(alpha)
    prev_token = test_tgt[t]
att_matrix = np.asarray(jnp.stack(attentions))  # ideally close to anti-diagonal

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(att_matrix, cmap='Blues')
ax.set_xlabel("Source position")
ax.set_ylabel("Target position")
src_labels = [str(int(x)) for x in test_src]
tgt_labels = [str(int(x)) for x in test_tgt]
ax.set_xticks(range(len(src_labels)))
ax.set_xticklabels(src_labels)
ax.set_yticks(range(len(tgt_labels)))
ax.set_yticklabels(tgt_labels)
for i in range(len(tgt_labels)):
    for j in range(len(src_labels)):
        ax.text(j, i, f"{att_matrix[i, j]:.2f}", ha='center', va='center', fontsize=9)
ax.set_title("Bahdanau Attention Alignment (sequence reversal)")
plt.colorbar(im)
plt.tight_layout()
plt.show()
```
Transformers and Language Models
-
In chapter 06, we introduced the Transformer architecture: self-attention, multi-head attention, positional encoding, and the encoder-decoder structure. Here we focus on how transformers are adapted for specific NLP paradigms, the models that defined modern NLP (BERT, GPT, T5), and the techniques that make them practical at scale.
-
Recall the core operation: scaled dot-product attention computes $\text{softmax}(QK^T / \sqrt{d_k}) V$, where queries, keys, and values are linear projections of the input. Multi-head attention runs $h$ parallel attention heads, each with different learned projections, and concatenates the results. The Transformer block wraps this with residual connections, layer normalisation, and a position-wise feed-forward network (chapter 06).
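As a reminder of the core operation, here is single-head scaled dot-product attention on lists of row vectors. The inputs are hypothetical. When all keys are identical, every query spreads its weight evenly and the output is simply the mean of the value rows.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, one query row at a time."""
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out

print(attention([[1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], [[0.0, 2.0], [4.0, 6.0]]))  # [[2.0, 4.0]]
print(attention([[3.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))
```

In the second call, the query aligns with the first key, so most of the weight goes to the first value row. The $\sqrt{d_k}$ divisor keeps these scores from growing with the head dimension.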
-
A subtle but important architectural choice is the placement of layer normalisation. The original Transformer uses post-norm: the residual and normalisation come after the sublayer, as $\text{LayerNorm}(x + \text{Sublayer}(x))$.
-
Most modern models use pre-norm: normalise before the sublayer, as $x + \text{Sublayer}(\text{LayerNorm}(x))$. Pre-norm is more stable during training because the residual path is a pure identity: gradients flow from the output straight back to earlier layers without passing through a normalisation step. This makes very deep models easier to train and reduces the reliance on careful learning-rate warmup.
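The difference is only in where LayerNorm sits. In this sketch, `sublayer` is a hypothetical stand-in for attention or the FFN, and LayerNorm omits its learned gain and bias. Post-norm re-normalises the whole residual stream after every block. Pre-norm adds an update to $x$ and leaves the residual path itself untouched.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm without the learned gain and bias (omitted for brevity)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    """Hypothetical stand-in for attention or the feed-forward network."""
    return [math.tanh(2.0 * v) for v in x]

def post_norm(x):
    # Original Transformer: LayerNorm(x + Sublayer(x))
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm(x):
    # Modern default: x + Sublayer(LayerNorm(x)); the identity path is untouched
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [1.0, 3.0, -2.0, 0.5]
print(post_norm(x))  # re-normalised: mean ~0, variance ~1
print(pre_norm(x))   # x plus an additive update
```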
-
The feed-forward sublayer in each Transformer block is a two-layer MLP applied independently to each token position:
$$\text{FFN}(x) = W_2 \cdot \text{GELU}(W_1 x + b_1) + b_2$$
-
The inner dimension is typically 4 times the model dimension (e.g., $d_{\text{model}} = 768$, $d_{\text{ff}} = 3072$). This FFN accounts for about two-thirds of the parameters in each block and is thought to function as a key-value memory that stores factual knowledge learned during training.
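The "two-thirds" figure follows from counting weights. The attention sublayer has four $d \times d$ projections (Q, K, V, output), about $4d^2$ parameters. The FFN has $d \times 4d$ plus $4d \times d$, about $8d^2$. So the FFN holds $8/12 = 2/3$ of the block. The sketch below includes biases, ignores LayerNorm parameters and embeddings, and assumes $d_{\text{ff}} = 4 d_{\text{model}}$ as in the text.

```python
def block_param_counts(d_model, ff_mult=4):
    """(attention, FFN) parameter counts for one block, weights plus biases."""
    d_ff = ff_mult * d_model
    attn = 4 * (d_model * d_model + d_model)               # Q, K, V, output projections
    ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model  # W_1, b_1, W_2, b_2
    return attn, ffn

attn, ffn = block_param_counts(768)
print(attn, ffn, round(ffn / (attn + ffn), 4))  # 2362368 4722432 0.6666
```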
-
Positional encoding gives the model information about token order, since attention itself is permutation-equivariant. The original sinusoidal encoding (chapter 06) uses fixed sine and cosine functions at different frequencies. Learned positional embeddings simply add a trainable vector for each position (used in BERT and GPT-2). Both are absolute encodings: position 5 gets the same vector regardless of context.
-
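Rotary Position Embedding (RoPE) encodes position by rotating the query and key vectors in 2D subspaces. For a pair of dimensions $(q_{2i}, q_{2i+1})$, the rotation by angle $m\theta_i$ (where $m$ is the position and $\theta_i = 10000^{-2i/d}$) applies the following, and keys are rotated the same way using their own position $n$:
$$\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}$$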
-
The beauty of RoPE is that the dot product $q'^T k'$ between rotated queries and keys depends only on the relative position $m - n$, not the absolute positions.
-
To see why, write the rotation as $q' = R_m q$ and $k' = R_n k$, where $R_m$ is a block-diagonal rotation matrix. The attention score becomes:
$$q'^T k' = (R_m q)^T (R_n k) = q^T R_m^T R_n \, k = q^T R_{n-m} \, k$$
-
The last step follows from the rotation group property: $R_m^T R_n = R_{n-m}$ (rotating back by $m$ then forward by $n$ equals rotating by $n - m$).
-
This means the attention score depends only on the relative distance $n - m$, not the absolute positions $m$ and $n$ individually.
-
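The model gains a natural notion of distance without any learned position parameters. Because the rotation is defined for any position, RoPE can in principle be applied to sequences longer than those seen during training, although in practice quality degrades well beyond the training length unless extra techniques such as position interpolation are used.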
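The relative-position property can be checked numerically. This sketch rotates consecutive pairs $(x_{2i}, x_{2i+1})$ exactly as in the rotation formula. Some implementations pair dimensions differently, which changes the layout but not the property. The vectors `q` and `k` are arbitrary hypothetical values.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each pair (x_{2i}, x_{2i+1}) by angle pos * theta_i, theta_i = base^(-2i/d)."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out += [c * x0 - s * x1, s * x0 + c * x1]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.6, 0.9]
print(dot(rope(q, 3), rope(k, 7)))      # positions (3, 7): offset 4
print(dot(rope(q, 103), rope(k, 107)))  # positions (103, 107): same offset, same score
```

Rotations also preserve vector length, so RoPE changes only the angle between queries and keys, never their norms.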
-
ALiBi (Attention with Linear Biases) takes an even simpler approach: it adds a fixed linear penalty to the attention scores based on distance, as $\text{score}_{ij} = q_i^T k_j - m \cdot |i - j|$, where $m$ is a head-specific slope. Different heads use different slopes, allowing some heads to focus locally and others globally. ALiBi requires no learned parameters for position and generalises well to sequences longer than those seen during training.
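The ALiBi paper sets the head slopes to a fixed geometric sequence. For $h$ heads (a power of two) they are $2^{-8/h}, 2^{-16/h}, \ldots, 2^{-8}$, so 8 heads get $1/2, 1/4, \ldots, 1/256$. A minimal sketch of the slopes and of the bias matrix added to one head's scores follows. The helper names are our own.

```python
def alibi_slopes(n_heads):
    """Geometric slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n_heads a power of two."""
    start = 2 ** (-8 / n_heads)
    return [start ** (h + 1) for h in range(n_heads)]

def alibi_bias(seq_len, slope):
    """Bias added to score_ij: -slope * |i - j|."""
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]

print(alibi_slopes(8))  # [0.5, 0.25, ..., 0.00390625]
for row in alibi_bias(4, 0.5):
    print(row)
```

A head with slope 1/2 penalises a token 10 positions away by 5 logits and so behaves locally. A head with slope 1/256 barely notices distance and can attend globally.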
-
The three dominant paradigms for Transformer-based language models are encoder-only, decoder-only, and encoder-decoder. They differ in what the model can see (the attention mask) and how they are trained.
-
BERT (Bidirectional Encoder Representations from Transformers, Devlin et al., 2019) is the canonical encoder-only model. It processes text with full bidirectional attention: every token can attend to every other token, both left and right. This gives BERT rich contextual representations but means it cannot generate text autoregressively.
-
BERT is pre-trained with two objectives. Masked language modelling (MLM) randomly selects 15% of input token positions and trains the model to predict the original tokens at those positions. Of the selected tokens, 80% are replaced with a [MASK] token, 10% with a random word, and 10% are left unchanged (to prevent the model from learning to only predict when it sees [MASK]). The training objective is:
$$\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(w_i \mid w_{\backslash \mathcal{M}})$$
- where $\mathcal{M}$ is the set of masked positions and $w_{\backslash \mathcal{M}}$ is the sentence with those positions masked. This is a denoising objective: the model learns to reconstruct corrupted input.
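The 80/10/10 corruption rule is easy to express in code. In this sketch each position is selected independently with probability 0.15, which is a simplification of BERT's procedure. `labels` holds the original token at selected positions and `None` elsewhere, meaning no loss there. The vocabulary is hypothetical.

```python
import random

def mask_tokens(tokens, vocab, rng, select_prob=0.15, mask_token="[MASK]"):
    """Return (corrupted, labels) following BERT's 80/10/10 rule."""
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= select_prob:
            continue  # not selected: no prediction at this position
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_token         # 80%: [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: random word
        # else 10%: keep the original token
    return corrupted, labels

vocab = ["the", "cat", "sat", "on", "mat", "dog"]
print(mask_tokens("the cat sat on the mat".split(), vocab, random.Random(0)))
```

The "random" and "unchanged" cases still carry labels, so the model must predict every selected position without knowing which of the three treatments it received.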
-
Next Sentence Prediction (NSP) trains BERT to predict whether two sentences are consecutive in the original text. A special [CLS] token at the start of the input is used for this binary classification. NSP was included to help with tasks like question answering that require understanding sentence relationships, though later work (RoBERTa) showed it contributes little and can be dropped.
-
BERT's pre-trained representations are adapted to downstream tasks by adding a task-specific head (a simple linear layer) on top and fine-tuning the entire model. For classification tasks, the [CLS] token representation is used. For token-level tasks (NER, POS tagging), each token's representation is used. This fine-tuning approach transfers the linguistic knowledge learned during pre-training to new tasks with relatively little labelled data.
-
GPT (Generative Pre-trained Transformer, Radford et al., 2018) is the canonical decoder-only model. It uses causal (autoregressive) attention: each token can only attend to tokens at earlier positions (and itself). This is enforced by masking future positions in the attention matrix (setting their scores to $-\infty$ before the softmax). The training objective is simple causal language modelling: predict the next token given all previous tokens.
$$\mathcal{L}_{\text{CLM}} = -\sum_{i=1}^{n} \log P(w_i \mid w_1, \ldots, w_{i-1})$$
-
This is the same n-gram language model objective from file 02, but with a Transformer parameterisation that can condition on the entire preceding context rather than just the last $k-1$ tokens.
-
GPT-2 scaled this up to 1.5 billion parameters and demonstrated strong zero-shot performance: without any fine-tuning, it could perform tasks by conditioning on a natural language prompt ("Translate English to French: ...").
-
GPT-3 (175 billion parameters) showed that scale alone could enable in-context learning: by providing a few input-output examples in the prompt, the model could perform new tasks without any gradient updates.
-
Encoder-decoder models like T5 (Text-to-Text Transfer Transformer, Raffel et al., 2020) frame every NLP task as text-to-text: the input is a text string (possibly with a task prefix like "translate English to German:") and the output is a text string. The encoder processes the input with bidirectional attention, and the decoder generates the output autoregressively with cross-attention to the encoder.
-
T5 is pre-trained with span corruption: random contiguous spans of tokens are each replaced with a unique sentinel token, and the model must generate the dropped spans, each preceded by its sentinel. For example, "The cat sat on the mat" might become "The [X] on [Y]" as input, and the target is "[X] cat sat [Y] the mat [Z]", where the final sentinel [Z] marks the end of the target. This is a generalisation of BERT's MLM to spans rather than individual tokens.
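A sketch of span corruption on the example sentence, assuming the spans have already been chosen (T5 samples them at random) and using `[X]`, `[Y]`, `[Z]` as stand-ins for T5's sentinel tokens. Following T5's formulation, one extra sentinel is appended to close the target. `span_corrupt` is a hypothetical helper.

```python
def span_corrupt(tokens, spans, sentinels=("[X]", "[Y]", "[Z]")):
    """Replace each (start, end) span (end exclusive, sorted, non-overlapping)
    with a sentinel in the input; the target lists each sentinel followed by
    the tokens it replaced, plus one closing sentinel."""
    inputs, targets = [], []
    prev = 0
    for idx, (start, end) in enumerate(spans):
        inputs.extend(tokens[prev:start])   # keep the uncorrupted text
        inputs.append(sentinels[idx])       # one sentinel per dropped span
        targets.append(sentinels[idx])
        targets.extend(tokens[start:end])   # dropped tokens move to the target
        prev = end
    inputs.extend(tokens[prev:])
    targets.append(sentinels[len(spans)])   # closing sentinel ends the target
    return inputs, targets

tokens = "The cat sat on the mat".split()
inp, tgt = span_corrupt(tokens, [(1, 3), (4, 6)])
print(" ".join(inp))  # The [X] on [Y]
print(" ".join(tgt))  # [X] cat sat [Y] the mat [Z]
```

Note that the target is much shorter than the input, which is one reason span corruption is cheaper to train on than reconstructing the whole sequence.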
-
BART (Lewis et al., 2020) is another encoder-decoder model pre-trained with a denoising objective, but it applies a broader set of corruption strategies: token masking, token deletion, span masking, sentence permutation, and document rotation. The diversity of corruption forces the model to learn more robust representations.
-
As language models grow larger, full fine-tuning (updating all parameters) becomes impractical: with Adam, which keeps two 32-bit statistics per parameter, a 175B parameter model needs well over a terabyte just for the optimizer states (175B × 2 × 4 bytes ≈ 1.4 TB), before counting weights, gradients, and activations. Parameter-efficient fine-tuning (PEFT) methods adapt only a small fraction of parameters.
-
Adapters insert small bottleneck modules (typically two linear layers with a nonlinearity: down-project to a small dimension, then up-project back, wrapped in a residual connection) inside each Transformer layer. Only the adapter weights are trained; the original model weights are frozen. This typically adds only a few percent new parameters per task while coming close to full fine-tuning performance on benchmarks such as GLUE.
-
LoRA (Low-Rank Adaptation) modifies the weight matrices themselves without adding new layers. Instead of updating the full weight matrix $W$, LoRA learns a low-rank decomposition of the update: $W' = W + BA$, where $B$ is $d \times r$ and $A$ is $r \times d$ with $r \ll d$ (typically $r = 4$ to $r = 64$). The original $W$ is frozen; only $A$ and $B$ are trained. At inference time, the update can be merged into the original weights with no additional latency:
$$W' = W + BA$$
-
Prefix tuning prepends a sequence of learnable "virtual tokens" to the key and value matrices of each attention layer. The model attends to these prefix vectors as if they were real tokens, and only the prefix parameters are trained. This is similar to prompt tuning but operates in the activation space rather than the embedding space.
-
Prompt engineering is the art of designing input text that elicits the desired behaviour from a pre-trained model without any parameter updates.
-
Zero-shot prompting describes the task in natural language ("Classify the sentiment of the following review:").
-
Few-shot prompting provides input-output examples before the actual query.
-
Chain-of-thought (CoT) prompting adds "Let's think step by step" or includes reasoning traces in the examples. In sufficiently large models this substantially improves performance on arithmetic and logical reasoning tasks, because the model decomposes the problem into intermediate steps before answering.
-
In-context learning (ICL) is the phenomenon where large language models can learn to perform tasks from examples provided in the prompt, without any gradient updates. The model's weights do not change; it uses the examples as a kind of implicit specification.
-
How ICL works mechanically remains an active research question; one hypothesis is that the attention layers implement a form of gradient descent in their forward pass, effectively "training" on the in-context examples.
-
Scaling laws describe predictable relationships between model size, data size, compute budget, and performance (measured by loss). Kaplan et al. (2020) found that loss follows a power law in each variable:
$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$
where $N$ is the number of parameters, $D$ is the dataset size, and $C$ is the compute budget. These power laws hold over many orders of magnitude and suggest that simply scaling up yields predictable improvements.
-
The Chinchilla scaling laws (Hoffmann et al., 2022) revised this by showing that most large models of the time were undertrained. For a fixed compute budget $C$, the optimal allocation scales model size and training data in equal proportion:
$$N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}$$
-
This means that if you double your compute budget, you should increase both model size and dataset size by a factor of $\sqrt{2}$, not just make the model bigger.
-
Kaplan et al. had recommended scaling $N$ faster than $D$, which led to very large but undertrained models. Chinchilla (70B parameters, 1.4T tokens) outperformed Gopher (280B parameters, 300B tokens) with the same compute budget, demonstrating that the earlier models were severely data-starved.
-
The practical rule of thumb: train on roughly 20 tokens per parameter.
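The allocation arithmetic can be checked directly. This sketch assumes the common approximation that training costs about $C \approx 6ND$ FLOPs (an assumption not stated above) together with the 20-tokens-per-parameter rule; solving $C = 6N \cdot 20N$ gives $N = \sqrt{C/120}$. Plugging in Chinchilla's own sizes recovers them, and doubling $C$ scales both $N$ and $D$ by $\sqrt{2}$.

```python
import math

def chinchilla_allocation(compute_flops, tokens_per_param=20.0):
    """Compute-optimal (N, D), assuming C ~= 6 * N * D and D = tokens_per_param * N."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    return n_params, tokens_per_param * n_params

C = 6 * 70e9 * 1.4e12                  # Chinchilla's budget under the 6ND approximation
N, D = chinchilla_allocation(C)
print(f"N ~ {N:.3g} parameters, D ~ {D:.3g} tokens")

N2, D2 = chinchilla_allocation(2 * C)
print(f"doubling C multiplies N by {N2 / N:.3f} and D by {D2 / D:.3f}")  # 1.414 both
```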
-
Mixture of Experts (MoE) is an architecture that scales model capacity without proportionally scaling computation. Instead of one large feed-forward layer, MoE uses multiple expert FFN layers and a gating network (router) that selects which experts to activate for each token.
-
The gating function computes a routing score for each expert and selects the top-$k$ (typically $k = 1$ or $k = 2$):
$$G(x) = \text{TopK}(\text{softmax}(W_g x))$$
Only the selected experts process the token, so the computational cost scales with $k$ (the number of active experts) rather than the total number of experts $E$. For example, a layer with 8 experts and top-2 routing holds 8 times the feed-forward parameters of a single expert but spends only about 2 times one expert's feed-forward computation per token.
-
A critical challenge in MoE is load balancing: if the router sends most tokens to a few popular experts, the others sit idle and their capacity is wasted. Training adds an auxiliary load balancing loss that encourages uniform expert utilisation:
$$\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} f_i \cdot p_i$$
-
where $f_i$ is the fraction of tokens assigned to expert $i$ and $p_i$ is the average routing probability for expert $i$. This product is minimised when both the token fractions and probabilities are uniform (each equal to $1/E$).
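A plain-Python sketch of top-$k$ routing and the auxiliary loss above. It assumes $f_i$ counts token-to-expert assignments divided by $T \cdot k$ (for $k = 1$ this is simply the fraction of tokens), omits the small weighting coefficient usually applied to this loss, and uses made-up router logits for four experts.

```python
import math

def softmax(logits):
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Indices of the k largest router probabilities."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

def load_balance_loss(batch_logits, k=1):
    """E * sum_i f_i * p_i over a batch of per-token router logits."""
    num_tokens, num_experts = len(batch_logits), len(batch_logits[0])
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for logits in batch_logits:
        probs = softmax(logits)
        for i in top_k(probs, k):
            counts[i] += 1                   # token is dispatched to expert i
        for i, p in enumerate(probs):
            prob_sums[i] += p
    f = [c / (num_tokens * k) for c in counts]   # fraction of assignments per expert
    p = [s / num_tokens for s in prob_sums]      # mean router probability per expert
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))

# Hypothetical router logits: 8 tokens, 4 experts
balanced = [[5.0 if e == t % 4 else 0.0 for e in range(4)] for t in range(8)]
skewed = [[5.0, 0.0, 0.0, 0.0] for _ in range(8)]
print(load_balance_loss(balanced))  # ~1.0: every expert gets a quarter of the tokens
print(load_balance_loss(skewed))    # ~3.92: expert 0 receives everything
```

$f_i$ comes from a hard top-$k$ decision and has no gradient; the loss trains the router through $p_i$, pushing probability mass away from overloaded experts.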
-
Expert parallelism distributes different experts across different accelerators. During the forward pass, an all-to-all communication step routes tokens to the device hosting their assigned expert, then routes the results back. This communication cost is the main engineering challenge of MoE at scale. Models like Switch Transformer, Mixtral, and GShard use MoE to achieve strong performance with practical inference costs.
-
Building models is half the job; measuring whether they work is the other half. NLP evaluation is uniquely difficult because language is ambiguous, subjective, and open-ended.
-
A translation can be correct in many different ways. A summary can be good even if it shares no exact words with a reference.
-
A chatbot response can be helpful, harmless, and honest, yet reasonable humans will disagree.
-
Exact match (EM) is the simplest metric: does the model's output exactly match the gold answer? It is used for tasks with short, unambiguous answers like extractive question answering (SQuAD) or closed-form maths.
-
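EM is harsh: "New York City" and "new york city" fail to match unless both strings are normalised first (SQuAD's evaluation script lowercases them and strips punctuation, articles, and extra whitespace). The flip side of this harshness is that EM is simple and unambiguous.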
-
Token-level metrics treat NLP as a classification problem at the token level, using precision, recall, and F1 from chapter 06.
-
Precision measures what fraction of the model's predicted tokens are correct: $P = \text{TP} / (\text{TP} + \text{FP})$. A model that predicts very few entities but gets them all right has high precision.
-
Recall measures what fraction of the gold tokens the model found: $R = \text{TP} / (\text{TP} + \text{FN})$. A model that predicts every token as an entity has perfect recall but terrible precision.
-
F1 is the harmonic mean of precision and recall:
$$F_1 = \frac{2PR}{P + R}$$
-
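The harmonic mean (rather than arithmetic) penalises imbalance: if either $P$ or $R$ is low, $F_1$ is low. For NER (file 02), F1 is usually computed at the entity level, where a prediction counts as correct only if both its span and its type match a gold entity; the headline CoNLL-style score is micro-averaged over all entities and is often reported alongside per-type scores. For POS tagging, token-level accuracy is more common because every token gets a tag.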
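Entity-level scoring can be sketched as follows. Entities are represented as hypothetical `(start, end, type)` tuples, and the sketch reports per-type scores plus both micro- and macro-averaged F1, since conventions differ between evaluation scripts. Note how an entity with the right span but the wrong type counts as both a false positive and a false negative.

```python
def prf(tp, fp, fn):
    """Precision, recall, and F1 from raw counts (0.0 when undefined)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

def entity_scores(predicted, gold):
    """Entities are (start, end, type) tuples; a match needs span and type to agree."""
    pred_set, gold_set = set(predicted), set(gold)
    per_type = {}
    for t in sorted({e[2] for e in pred_set | gold_set}):
        p_t = {e for e in pred_set if e[2] == t}
        g_t = {e for e in gold_set if e[2] == t}
        per_type[t] = prf(len(p_t & g_t), len(p_t - g_t), len(g_t - p_t))
    # Micro: pool counts over all entities; macro: average the per-type F1 scores
    micro = prf(len(pred_set & gold_set), len(pred_set - gold_set), len(gold_set - pred_set))
    macro_f1 = sum(f1 for _, _, f1 in per_type.values()) / len(per_type)
    return per_type, micro, macro_f1

gold = [(0, 2, "PER"), (5, 6, "LOC"), (9, 11, "ORG")]
pred = [(0, 2, "PER"), (5, 6, "ORG"), (9, 11, "ORG")]  # (5, 6) has the wrong type
per_type, micro, macro_f1 = entity_scores(pred, gold)
print(per_type)
print(f"micro F1 = {micro[2]:.3f}, macro F1 = {macro_f1:.3f}")  # 0.667, 0.556
```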
-
Span-level F1 (used in SQuAD) compares the bag of tokens in the predicted span with the bag in the gold span. This is more forgiving than exact match: if the gold answer is "the Eiffel Tower in Paris" and the model predicts "Eiffel Tower", then after SQuAD's normalisation (which drops the article "the") precision is 2/2 = 1 and recall is 2/4 = 0.5, giving a span F1 of about 0.67 even though EM is zero.
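A sketch of SQuAD-style exact match and token-level F1. The normalisation (lowercase, strip punctuation and the articles a/an/the, collapse whitespace) mirrors what SQuAD's official evaluation script does, but this is a re-implementation for illustration, not that script.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)

def normalize_answer(text):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, gold):
    return float(normalize_answer(prediction) == normalize_answer(gold))

def span_f1(prediction, gold):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    # Bag-of-tokens overlap: Counter & keeps the minimum count of each token
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("New York City.", "new york city"))                   # 1.0
print(exact_match("Eiffel Tower", "the Eiffel Tower in Paris"))        # 0.0
print(round(span_f1("Eiffel Tower", "the Eiffel Tower in Paris"), 3))  # 0.667
```

When several gold answers exist, SQuAD-style evaluation takes the maximum EM and F1 over them for each question.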
-
BLEU (Bilingual Evaluation Understudy, Papineni et al., 2002) is the classic metric for machine translation. It measures n-gram overlap between the candidate translation and one or more reference translations. The score combines precision at multiple n-gram levels (unigram through 4-gram) with a brevity penalty:
$$\text{BLEU} = \text{BP} \cdot \exp\!\left(\sum_{n=1}^{N} w_n \log p_n\right)$$
-
where $p_n$ is the modified n-gram precision: the count of each n-gram in the candidate is clipped to its maximum count in any reference, preventing a degenerate candidate like "the the the the" from scoring high. The weights $w_n$ are typically uniform ($w_n = 1/N$, with $N = 4$).
-
The brevity penalty $\text{BP} = \min(1, \exp(1 - r/c))$ penalises candidates shorter than the reference ($c$ is candidate length, $r$ is reference length). Without this, a model could achieve high precision by outputting very few, very safe words.
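A minimal sentence-level BLEU sketch that implements the clipped precision and brevity penalty described above, with uniform weights $w_n = 1/4$ and no smoothing, so any zero $p_n$ gives a score of 0. It takes pre-tokenised input and uses the closest reference length as $r$. Corpus-level BLEU instead sums clipped counts and lengths over all sentences before combining them, and toolkits such as sacreBLEU also fix the tokenisation.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def modified_precision(candidate, references, n):
    """(clipped n-gram matches, total candidate n-grams)."""
    cand_counts = ngrams(candidate, n)
    max_ref_counts = Counter()
    for ref in references:
        max_ref_counts |= ngrams(ref, n)  # element-wise max count over references
    clipped = sum(min(count, max_ref_counts[g]) for g, count in cand_counts.items())
    return clipped, sum(cand_counts.values())

def sentence_bleu(candidate, references, max_n=4):
    """Unsmoothed BLEU with uniform weights w_n = 1 / max_n."""
    log_p_sum = 0.0
    for n in range(1, max_n + 1):
        clipped, total = modified_precision(candidate, references, n)
        if clipped == 0:
            return 0.0                    # some p_n = 0, so log p_n is -inf
        log_p_sum += math.log(clipped / total) / max_n
    c = len(candidate)
    # Effective reference length: the one closest to c (the shorter wins ties)
    r = min((len(ref) for ref in references), key=lambda length: (abs(length - c), length))
    bp = min(1.0, math.exp(1 - r / c))
    return bp * math.exp(log_p_sum)

ref = "the cat is on the mat".split()
print(modified_precision("the the the the".split(), [ref], 1))  # (2, 4)
print(sentence_bleu(ref, [ref]))                                 # 1.0
print(sentence_bleu("the cat is on".split(), [ref]))             # exp(-0.5) ~ 0.607
```

The last line shows the brevity penalty at work: every n-gram of the short candidate is correct, yet it loses about 40% of its score for being shorter than the reference.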
-
BLEU correlates reasonably with human judgement at the corpus level (averaged over many sentences) but poorly at the sentence level.
-
It rewards exact n-gram matches and misses valid paraphrases: "the cat is on the mat" and "a feline sits atop the rug" have zero bigram overlap despite meaning the same thing.
-
BLEU also has no recall term: the brevity penalty only checks overall length, so a candidate of roughly the right length can omit important content without being penalised for it.
-
ROUGE (Recall-Oriented Understudy for Gisting Evaluation, Lin, 2004) is the standard metric for summarisation. Unlike BLEU, which emphasises precision, ROUGE emphasises recall: what fraction of the reference n-grams appear in the candidate?
-
ROUGE-N computes recall of n-grams: $\text{ROUGE-N} = \frac{|\text{n-grams}_{\text{ref}} \cap \text{n-grams}_{\text{cand}}|}{|\text{n-grams}_{\text{ref}}|}$. ROUGE-1 (unigram) and ROUGE-2 (bigram) are most common.
-
ROUGE-L uses the longest common subsequence (LCS) between candidate and reference, which captures sentence-level word ordering without requiring consecutive matches.
-
The LCS length normalised by reference length gives recall, normalised by candidate length gives precision, and the F-measure combines them.
-
LCS is computed via dynamic programming in $O(mn)$ time (similar to edit distance from file 02):
$$R_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{m}, \quad P_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{n}, \quad F_{\text{LCS}} = \frac{(1 + \beta^2) R_{\text{LCS}} P_{\text{LCS}}}{R_{\text{LCS}} + \beta^2 P_{\text{LCS}}}$$
-
where $m$ and $n$ are the lengths of reference and candidate, and $\beta$ is typically set to favour recall ($\beta \to \infty$ gives pure recall).
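The LCS dynamic programme and the three ROUGE-L quantities can be sketched as follows. The default $\beta = 1$ is chosen here only to make the arithmetic easy to check; as noted above, practical setups weight recall more heavily. The last line shows the ordering sensitivity: reversing the reference keeps every unigram, so ROUGE-1 would be perfect, but the LCS drops to 3.

```python
def lcs_length(x, y):
    """Longest common subsequence length via the O(mn) dynamic programme."""
    table = [[0] * (len(y) + 1) for _ in range(len(x) + 1)]
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            if x[i - 1] == y[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1  # extend a common subsequence
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[len(x)][len(y)]

def rouge_l(reference, candidate, beta=1.0):
    """Returns (R_LCS, P_LCS, F_LCS) for token lists."""
    lcs = lcs_length(reference, candidate)
    if lcs == 0:
        return 0.0, 0.0, 0.0
    recall = lcs / len(reference)       # divide by m
    precision = lcs / len(candidate)    # divide by n
    f = (1 + beta ** 2) * recall * precision / (recall + beta ** 2 * precision)
    return recall, precision, f

ref = "the cat sat on the mat".split()
cand = "the cat was sitting on the mat".split()
print(rouge_l(ref, cand))           # LCS = 5: R = 5/6, P = 5/7, F1 = 10/13
print(lcs_length(ref, ref[::-1]))   # 3: same words, reversed order
```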
-
METEOR (Metric for Evaluation of Translation with Explicit ORdering, Banerjee and Lavie, 2005) addresses BLEU's weaknesses by incorporating synonyms, stemming, and word order.
-
It first aligns words between candidate and reference using exact matches, stem matches (via Porter stemming from file 02), and synonym matches (via WordNet from file 01).
-
Then it computes a harmonic mean of unigram precision and recall weighted toward recall, and applies a fragmentation penalty that penalises candidates where matched words appear in a different order than the reference.
-
ChrF (Character n-gram F-score) computes an F-score over character n-grams rather than word n-grams. This makes it robust to morphological variation (important for agglutinative languages from file 01) and partially handles tokenisation differences. ChrF++ adds word unigrams and bigrams to the character n-grams.
-
It has become a recommended metric for machine translation alongside BLEU, especially for morphologically rich languages.
-
Perplexity (file 02) measures how well a language model predicts a held-out test set. It is the standard intrinsic metric for language models: $\text{PPL} = \exp(-\frac{1}{N} \sum_{i} \log P(w_i \mid w_{<i}))$. Lower is better.
-
Perplexity is comparable only between models using the same tokenisation, since different tokenisers produce different sequence lengths $N$ for the same text.
-
A model with a larger vocabulary splits the same text into fewer, longer tokens. Each token then carries more information, so its per-token perplexity tends to be higher even when it models the text just as well.
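To make the tokenisation dependence concrete, the sketch below scores the same text under two hypothetical tokenisations that assign it the same total log-likelihood (20 nats): one splits it into 10 tokens, the other into 5. The log-probabilities are invented for illustration. Per-token perplexity differs by a factor of $e^2$, while bits-per-byte, which divides the total in bits by the UTF-8 byte count, is identical.

```python
import math

def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood (natural log) per token."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def bits_per_byte(token_logprobs, text):
    """Total negative log-likelihood converted to bits, divided by UTF-8 bytes."""
    total_bits = -sum(token_logprobs) / math.log(2)
    return total_bits / len(text.encode("utf-8"))

text = "the quick brown fox"      # 19 UTF-8 bytes
small_vocab = [-2.0] * 10         # hypothetical: 10 tokens, 20 nats in total
large_vocab = [-4.0] * 5          # hypothetical: 5 tokens, the same 20 nats
print(perplexity(small_vocab), perplexity(large_vocab))  # e^2 ~ 7.39 vs e^4 ~ 54.6
print(bits_per_byte(small_vocab, text) == bits_per_byte(large_vocab, text))  # True
```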
-
Bits-per-byte (BPB) normalises by the number of UTF-8 bytes in the text rather than the number of tokens, making it tokenisation-independent:
$$\text{BPB} = -\frac{1}{N_{\text{bytes}}} \sum_{i} \log_2 P(w_i \mid w_{<i})$$
-
where the sum is the total information content of the text in bits, whatever the tokenisation, and $N_{\text{bytes}}$ is the length of the text in UTF-8 bytes.
-
BERTScore (Zhang et al., 2020) moves beyond surface-level n-gram matching by computing similarity in embedding space. Each token in the candidate is matched to its most similar token in the reference using cosine similarity of contextual embeddings (typically from a pre-trained BERT model). The scores are aggregated into precision, recall, and F1:
$$R_{\text{BERT}} = \frac{1}{|r|} \sum_{r_i \in r} \max_{c_j \in c} \cos(r_i, c_j), \quad P_{\text{BERT}} = \frac{1}{|c|} \sum_{c_j \in c} \max_{r_i \in r} \cos(c_j, r_i)$$
-
where $r_i$ and $c_j$ are contextual embeddings of reference and candidate tokens. This captures semantic similarity that n-gram metrics miss: "automobile" and "car" score highly because their BERT embeddings are similar, even though the words themselves do not match.
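A sketch of BERTScore's greedy matching. Real BERTScore takes contextual embeddings from a pre-trained model and can add idf weighting and baseline rescaling; here hand-made 3-dimensional vectors (hypothetical values) stand in for the embeddings so that the matching logic is visible.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def bertscore(ref_embs, cand_embs):
    """Greedy-matching (P, R, F1) over per-token embeddings, without idf weighting."""
    # Recall: each reference token takes its best match among candidate tokens
    recall = sum(max(cosine(r, c) for c in cand_embs) for r in ref_embs) / len(ref_embs)
    # Precision: each candidate token takes its best match among reference tokens
    precision = sum(max(cosine(c, r) for r in ref_embs) for c in cand_embs) / len(cand_embs)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Hand-made vectors standing in for contextual embeddings (hypothetical values)
ref = {"the": [1.0, 0.0, 0.0], "car": [0.0, 1.0, 0.2]}
cand = {"the": [1.0, 0.0, 0.0], "automobile": [0.1, 0.9, 0.3]}
p, r, f = bertscore(list(ref.values()), list(cand.values()))
print(f"P={p:.3f} R={r:.3f} F1={f:.3f}")
```

Because "automobile" sits close to "car" in this toy space, the score stays near 1 even though no word n-gram beyond "the" is shared.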
-
BLEURT (Sellam et al., 2020) takes this further by fine-tuning a BERT model directly on human quality judgements. Given a reference and candidate pair, it outputs a scalar quality score. BLEURT is first pre-trained on synthetic data (random perturbations of reference sentences, labelled with automatic signals such as BLEU, ROUGE, and BERTScore) and then fine-tuned on human ratings. In its original evaluation it correlated better with human judgements than surface-level metrics such as BLEU.
-
COMET (Crosslingual Optimized Metric for Evaluation of Translation, Rei et al., 2020) is a learned metric for machine translation that conditions on the source sentence, reference, and candidate — not just reference and candidate. It uses a multilingual encoder (XLM-R) to embed all three and predicts a quality score. By seeing the source, COMET can detect meaning errors that reference-only metrics miss (e.g., a fluent but factually wrong translation).
-
LLM-as-judge is the modern approach to evaluation at scale. Instead of computing metrics against references, a powerful language model (GPT-4, Claude) is prompted to evaluate the quality of model outputs. The judge receives the input, the model's response, and optionally a reference answer, and produces a rating (e.g., 1-5) or a pairwise preference (response A is better than response B).
-
Pairwise comparison is generally a more reliable judging format than absolute scoring, and it is the format Chatbot Arena uses for its human votes. The judge sees two responses and picks the better one, rather than assigning absolute scores. This avoids calibration issues (different judges may have different baselines for "3 out of 5"). Results are aggregated into Elo ratings (from chess), where each model starts with a base rating and gains or loses points based on wins and losses against other models. The expected win probability of model $A$ against model $B$ is:
$$P(A \succ B) = \frac{1}{1 + 10^{(R_B - R_A) / 400}}$$
-
where $R_A, R_B$ are the Elo ratings. After each comparison, ratings are updated: $R_A' = R_A + K(S - P(A \succ B))$, where $S$ is the actual outcome for $A$ (1 for a win, 0 for a loss, 0.5 for a tie) and $K$ controls the update magnitude. Models that consistently beat strong opponents rise quickly; models that lose to weak opponents fall.
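The two Elo formulas in code. $K = 32$ is an assumed value (a common choice in chess), and B's rating moves by the opposite amount so that the total rating is conserved. An upset win moves ratings far more than an expected one.

```python
def expected_score(r_a, r_b):
    """P(A beats B) under the Elo model."""
    return 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400))

def elo_update(r_a, r_b, score_a, k=32):
    """score_a = 1 if A wins, 0 if B wins, 0.5 for a tie. Returns (new_r_a, new_r_b)."""
    delta = k * (score_a - expected_score(r_a, r_b))
    return r_a + delta, r_b - delta   # zero-sum: B moves by the opposite amount

print(elo_update(1000, 1000, 1))  # (1016.0, 984.0): evenly matched, A wins
print(elo_update(1000, 1200, 1))  # underdog win: A gains about 24 points
print(elo_update(1200, 1000, 1))  # favourite win: A gains only about 8 points
```

Sequential Elo updates depend on the order in which comparisons arrive, so leaderboards built from many votes often refit all ratings jointly instead of updating them one game at a time.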
-
Position bias is a known issue with LLM judges: they tend to prefer the response presented first (or in some models, the response presented second). Swapping (evaluating each pair twice with responses in both orders) and averaging the results mitigates this.
-
Verbosity bias is another: judges tend to prefer longer, more detailed responses even when a concise answer is better.
-
Self-consistency checks whether the judge gives the same rating across multiple evaluations of the same input. High variance indicates the evaluation signal is noisy.
-
Inter-annotator agreement (Cohen's kappa or Krippendorff's alpha) measures whether multiple judges agree, providing an upper bound on evaluation reliability.
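For two annotators (or two judge runs) labelling the same items, Cohen's kappa compares the observed agreement $p_o$ with the agreement $p_e$ expected by chance from each annotator's own label frequencies: $\kappa = (p_o - p_e)/(1 - p_e)$. A minimal sketch with toy labels:

```python
from collections import Counter

def cohens_kappa(labels_a, labels_b):
    """Chance-corrected agreement between two annotators on the same items."""
    n = len(labels_a)
    observed = sum(a == b for a, b in zip(labels_a, labels_b)) / n
    freq_a, freq_b = Counter(labels_a), Counter(labels_b)
    # Agreement expected if each annotator labelled at random with their own frequencies
    expected = sum(freq_a[label] * freq_b[label] for label in freq_a) / (n * n)
    if expected == 1.0:
        return 1.0  # convention chosen here: both used one identical label throughout
    return (observed - expected) / (1 - expected)

judge_1 = ["good", "good", "bad", "bad"]
judge_2 = ["good", "bad", "bad", "bad"]
print(cohens_kappa(judge_1, judge_2))  # observed 0.75, chance 0.5 -> kappa 0.5
```

A raw agreement of 75% sounds high, but half of it is expected by chance here, which is exactly what kappa corrects for.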
-
Contamination is a critical concern: if the evaluation data appeared in the model's training set, benchmark scores are inflated and meaningless.
-
This is especially problematic for LLMs trained on web-scraped data, where popular benchmarks are likely present. Mitigation strategies include: using held-out test sets that are not publicly released, creating dynamic benchmarks that regenerate questions periodically, canary strings (unique identifiers embedded in benchmark data to detect leakage), and comparing performance on contaminated vs clean subsets.
-
Standard NLU benchmarks evaluate language understanding across diverse tasks.
-
GLUE (General Language Understanding Evaluation) and SuperGLUE are multi-task benchmarks covering sentiment (SST-2), textual similarity (STS-B), natural language inference (MNLI, RTE), coreference (WSC), and question answering (BoolQ).
-
Models are evaluated on each task separately and scored by an aggregate metric. GLUE is now considered saturated (models exceed the human baseline on most tasks). SuperGLUE was designed to be harder, but top models have since surpassed its human baseline as well.
-
MMLU (Massive Multitask Language Understanding) evaluates knowledge and reasoning across 57 academic subjects (mathematics, history, law, medicine, computer science, etc.) using multiple-choice questions.
-
It tests whether a model has absorbed broad knowledge during pre-training. Scores are reported per subject and as a macro-average.
-
MMLU-Pro adds harder, multi-step reasoning questions with 10 answer choices instead of 4.
-
HellaSwag tests commonsense reasoning by asking the model to choose the most plausible continuation of a scenario. The wrong answers are generated adversarially (using models) to be superficially plausible but semantically wrong.
-
WinoGrande tests commonsense coreference resolution with minimal pairs that differ by one word.
-
ARC (AI2 Reasoning Challenge) uses grade-school science questions in easy and challenge sets, testing factual and reasoning ability.
-
Reasoning and maths benchmarks evaluate the problem-solving capabilities that separate strong LLMs from weak ones.
-
GSM8K (Grade School Math 8K) contains 8,500 elementary maths word problems requiring multi-step arithmetic reasoning. It is the standard benchmark for basic mathematical reasoning and for evaluating chain-of-thought prompting (file 04).
-
MATH is a harder dataset of competition-level maths problems across algebra, number theory, geometry, counting, and probability. Problems require multi-step symbolic reasoning, and MATH-500 is a commonly reported 500-problem subset.
-
AIME (American Invitational Mathematics Examination) problems are competition-level: solving them correctly requires deep mathematical reasoning over many steps. DeepSeek-R1 scores 79.8% on AIME 2024, demonstrating that RL-trained reasoning models (file 05) can approach strong human competitors.
-
HumanEval and MBPP (Mostly Basic Programming Problems) evaluate code generation by checking whether the model's code passes unit tests. HumanEval contains 164 Python problems with function signatures and docstrings; the model must generate the function body.
-
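The metric is pass@k: the probability that at least one of $k$ generated solutions passes all tests. It is estimated by drawing $n \ge k$ samples per problem, counting the $c$ that pass, and averaging the following unbiased estimator over problems: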
$$\text{pass@}k = 1 - \frac{\binom{n-c}{k}}{\binom{n}{k}}$$
-
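where $n$ is the total number of generated samples and $c$ is the number that pass. The ratio of binomial coefficients is the probability that a random subset of $k$ of the $n$ samples contains no passing solution. Estimating pass@k this way avoids the bias of plugging the empirical pass rate $c/n$ into $1 - (1 - c/n)^k$.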
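A sketch of the estimator for one problem, then averaged over a few problems. The per-problem pass counts are made up, and `math.comb` requires Python 3.8 or later. When fewer than $k$ samples fail, every size-$k$ subset contains a passing solution, so the estimate is exactly 1.

```python
import math

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate for one problem: n samples drawn, c of them pass."""
    if n - c < k:
        return 1.0  # every size-k subset contains at least one passing sample
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

# Hypothetical results: 20 samples per problem, number passing for each problem
passes = [0, 3, 20, 1]
for k in (1, 5, 10):
    score = sum(pass_at_k(20, c, k) for c in passes) / len(passes)
    print(f"pass@{k} = {score:.3f}")
```

For $k = 1$ the estimator reduces to $c/n$, the plain fraction of passing samples.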
-
SWE-bench goes further, evaluating whether models can resolve real GitHub issues by modifying existing codebases — a much harder test of practical software engineering ability.
-
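GPQA (Graduate-Level Google-Proof QA) contains multiple-choice questions in biology, physics, and chemistry written by domain experts. The questions are "Google-proof": skilled non-experts with unrestricted web access still answer most of them incorrectly, so high scores are hard to reach by lookup or surface pattern matching. The "Diamond" subset is the most carefully curated and the one most often reported.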
-
Safety and alignment benchmarks evaluate whether models are helpful, harmless, and honest.
-
TruthfulQA tests whether models reproduce common misconceptions. Questions are designed so that the most common internet answers are wrong. For example, for "What happens if you swallow gum?" the popular myth is that it stays in your stomach for seven years, whereas the truthful answer is that it passes through your digestive system. Models that have memorised popular but incorrect claims score poorly.
-
BBQ (Bias Benchmark for QA) tests for social biases across categories like age, gender, race, and religion. Questions are structured so that a biased model would systematically choose stereotypical answers. Toxigen evaluates the model's tendency to generate toxic content about specific demographic groups.
-
MT-Bench evaluates multi-turn conversation ability using 80 carefully designed questions across writing, roleplay, reasoning, maths, coding, extraction, STEM, and humanities. An LLM judge (GPT-4) scores responses on a 1-10 scale. The multi-turn format tests whether models can follow up, maintain context, and handle clarification requests.
-
Chatbot Arena (LMSYS) uses real users to conduct blind pairwise comparisons between anonymous models. Users submit prompts and vote for the better response without knowing which model produced it. The resulting Elo leaderboard is considered the most ecologically valid evaluation of general-purpose LLM quality because it reflects real user preferences on diverse, uncurated prompts.
-
AlpacaEval automates pairwise evaluation by comparing model outputs against a reference model (GPT-4) on a fixed set of instructions. A judge model determines the win rate.
-
AlpacaEval 2.0 uses length-controlled win rates to correct for verbosity bias.
-
Task-specific evaluation requires tailored metrics for specialised domains.
-
Word Error Rate (WER) for speech recognition: $\text{WER} = (S + D + I) / N$, where $S$, $D$, $I$ are substitution, deletion, and insertion errors and $N$ is the number of reference words. This is the edit distance (file 02) normalised by reference length, applied at the word level.
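The word-level edit distance behind WER, as a sketch with plain whitespace tokenisation (real ASR scoring usually normalises case and punctuation first, which is omitted here):

```python
def word_error_rate(reference, hypothesis):
    """(S + D + I) / N via word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    # dist[i][j] = minimum edits turning ref[:i] into hyp[:j]
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i                      # i deletions
    for j in range(len(hyp) + 1):
        dist[0][j] = j                      # j insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,        # deletion
                             dist[i][j - 1] + 1,        # insertion
                             dist[i - 1][j - 1] + sub)  # substitution or match
    return dist[len(ref)][len(hyp)] / len(ref)

print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))  # 2/6 ~ 0.333
```

Because insertions count as errors, WER can exceed 1: `word_error_rate("a", "a b c")` returns 2.0.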
-
Slot F1 for task-oriented dialogue systems measures whether the model correctly extracts structured information from user utterances (e.g., extracting "destination: Paris" and "date: tomorrow" from "Book me a flight to Paris tomorrow").
-
Citation accuracy for RAG systems (file 05) checks whether the model's generated citations actually support the claims made. A claim is verified against the retrieved passage, and the metric counts the fraction of claims that are fully, partially, or not supported.
-
Evaluation pitfalls are common and can invalidate entire benchmark comparisons.
-
Teaching to the test: optimising for benchmark performance rather than genuine capability. A model fine-tuned on MMLU-style multiple choice will score well on MMLU but may fail on the same questions posed in open-ended format.
-
Metric gaming: models can be optimised to produce outputs that score well on automatic metrics (high BLEU, low perplexity) without being genuinely good. The BLEU-optimal translation is often a safe, generic paraphrase rather than a natural, fluent one.
-
Benchmark saturation: when models approach or exceed human performance on a benchmark, the benchmark stops being informative. GLUE, SQuAD 1.1, and several others are now saturated.
-
The field continuously creates harder benchmarks, but the cycle of creation, saturation, and replacement makes longitudinal comparison difficult.
-
Human evaluation remains the gold standard but is expensive, slow, and hard to reproduce. Different annotator pools (crowdworkers vs domain experts, different cultures, different languages) produce different judgements. Reporting inter-annotator agreement and annotator demographics is essential for reproducibility.
Coding Tasks (use Colab or a notebook)
- Implement a full Transformer encoder block from scratch (multi-head attention, feed-forward, residual connections, layer norm). Apply it to a simple sequence classification task.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalise over the feature axis, then apply the learned scale and shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o, n_heads):
    B, T, D = Q.shape
    head_dim = D // n_heads
    q = Q @ W_q  # (B, T, D)
    k = K @ W_k
    v = V @ W_v
    # Reshape to (B, n_heads, T, head_dim)
    q = q.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)
    k = k.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)
    v = v.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)
    # Scaled dot-product attention, computed independently per head
    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)
    # Merge heads back to (B, T, D), then apply the output projection
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
    return out @ W_o, weights

def transformer_block(x, params):
    # Pre-norm multi-head self-attention
    normed = layer_norm(x, params['ln1_g'], params['ln1_b'])
    attn_out, weights = multi_head_attention(
        normed, normed, normed,
        params['W_q'], params['W_k'], params['W_v'], params['W_o'],
        n_heads=4
    )
    x = x + attn_out  # residual connection
    # Pre-norm feed-forward
    normed = layer_norm(x, params['ln2_g'], params['ln2_b'])
    ff = jax.nn.gelu(normed @ params['W1'] + params['b1'])
    ff = ff @ params['W2'] + params['b2']
    x = x + ff  # residual connection
    return x, weights

# Initialise parameters
d_model, d_ff, n_heads = 32, 128, 4
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 10)
params = {
    'W_q': jax.random.normal(keys[0], (d_model, d_model)) * 0.05,
    'W_k': jax.random.normal(keys[1], (d_model, d_model)) * 0.05,
    'W_v': jax.random.normal(keys[2], (d_model, d_model)) * 0.05,
    'W_o': jax.random.normal(keys[3], (d_model, d_model)) * 0.05,
    'ln1_g': jnp.ones(d_model), 'ln1_b': jnp.zeros(d_model),
    'ln2_g': jnp.ones(d_model), 'ln2_b': jnp.zeros(d_model),
    'W1': jax.random.normal(keys[4], (d_model, d_ff)) * 0.05,
    'b1': jnp.zeros(d_ff),
    'W2': jax.random.normal(keys[5], (d_ff, d_model)) * 0.05,
    'b2': jnp.zeros(d_model),
}

# Test with random input
x = jax.random.normal(keys[6], (2, 8, d_model))  # batch=2, seq_len=8
out, attn_weights = transformer_block(x, params)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {attn_weights.shape}")  # (B, n_heads, T, T)

# Visualise attention patterns for each head
fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))
for h in range(4):
    im = axes[h].imshow(attn_weights[0, h], cmap='Blues', vmin=0)
    axes[h].set_title(f"Head {h}")
    axes[h].set_xlabel("Key pos"); axes[h].set_ylabel("Query pos")
plt.suptitle("Multi-Head Attention Patterns")
plt.tight_layout(); plt.show()
```
- Implement causal (autoregressive) attention masking and compare it with bidirectional attention. Show how the mask prevents information from flowing from future to past tokens.
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    if mask is not None:
        # Blocked positions get a huge negative score, so softmax gives them ~0 weight
        scores = jnp.where(mask, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ V, weights

seq_len, d_model = 6, 8
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
Q = jax.random.normal(k1, (seq_len, d_model))
K = jax.random.normal(k2, (seq_len, d_model))
V = jax.random.normal(k3, (seq_len, d_model))

# Bidirectional (encoder-style): all positions visible
bidir_mask = jnp.ones((seq_len, seq_len), dtype=bool)
bidir_out, bidir_weights = attention(Q, K, V, bidir_mask)

# Causal (decoder-style): only past and current positions visible
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
causal_out, causal_weights = attention(Q, K, V, causal_mask)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
tokens = [f"t{i}" for i in range(seq_len)]
axes[0].imshow(bidir_weights, cmap='Blues', vmin=0, vmax=0.5)
axes[0].set_title("Bidirectional Attention\n(BERT-style)")
axes[0].set_xticks(range(seq_len)); axes[0].set_xticklabels(tokens)
axes[0].set_yticks(range(seq_len)); axes[0].set_yticklabels(tokens)
axes[1].imshow(causal_mask.astype(float), cmap='Greys', vmin=0, vmax=1)
axes[1].set_title("Causal Mask\n(1 = allowed, 0 = blocked)")
axes[1].set_xticks(range(seq_len)); axes[1].set_xticklabels(tokens)
axes[1].set_yticks(range(seq_len)); axes[1].set_yticklabels(tokens)
axes[2].imshow(causal_weights, cmap='Blues', vmin=0, vmax=0.5)
axes[2].set_title("Causal Attention\n(GPT-style)")
axes[2].set_xticks(range(seq_len)); axes[2].set_xticklabels(tokens)
axes[2].set_yticks(range(seq_len)); axes[2].set_yticklabels(tokens)
for ax in axes:
    ax.set_xlabel("Key"); ax.set_ylabel("Query")
plt.tight_layout(); plt.show()

# Verify: in causal attention, output at position i depends only on positions <= i
print("Causal attention weight at position 2 (should only attend to 0, 1, 2):")
print(f"  Weights: {causal_weights[2]}")
print(f"  Sum of future weights (should be ~0): {causal_weights[2, 3:].sum():.6f}")
```
- Implement LoRA (Low-Rank Adaptation) and show how it modifies a weight matrix with far fewer trainable parameters than full fine-tuning.
```python
import jax
import jax.numpy as jnp

d_model = 256
rank = 4  # LoRA rank (much smaller than d_model)
key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)

# Original frozen weight matrix
W_frozen = jax.random.normal(k1, (d_model, d_model)) * 0.02
# LoRA matrices (only these are trainable)
B = jnp.zeros((d_model, rank))  # initialised to zero
A = jax.random.normal(k2, (rank, d_model)) * 0.01  # random init

# Forward pass: W_effective = W_frozen + B @ A
x = jax.random.normal(k3, (8, d_model))
# Without LoRA
y_original = x @ W_frozen.T
# With LoRA
W_effective = W_frozen + B @ A
y_lora = x @ W_effective.T

# Parameter counts
full_params = d_model * d_model
lora_params = d_model * rank + rank * d_model  # B + A
print(f"Model dimension: {d_model}")
print(f"LoRA rank: {rank}")
print(f"Full fine-tuning parameters: {full_params:,}")
print(f"LoRA parameters: {lora_params:,}")
print(f"Parameter reduction: {full_params / lora_params:.1f}x")
print(f"\nSince B is initialised to zeros, initial LoRA output matches original:")
print(f"  Max difference: {jnp.abs(y_original - y_lora).max():.2e}")

# Simulate training: only update A and B
def lora_forward(A, B, W_frozen, x):
    return x @ (W_frozen + B @ A).T

def dummy_loss(A, B, W_frozen, x, target):
    pred = lora_forward(A, B, W_frozen, x)
    return jnp.mean((pred - target) ** 2)

# Target: some transformation of x
target = x @ jax.random.normal(jax.random.PRNGKey(99), (d_model, d_model)).T * 0.02
grad_fn = jax.jit(jax.grad(dummy_loss, argnums=(0, 1)))  # gradients w.r.t. A and B only
lr = 0.01
for step in range(200):
    gA, gB = grad_fn(A, B, W_frozen, x, target)
    A = A - lr * gA
    B = B - lr * gB

loss_before = dummy_loss(jnp.zeros_like(A), jnp.zeros_like(B), W_frozen, x, target)
loss_after = dummy_loss(A, B, W_frozen, x, target)
print(f"\nLoss before LoRA: {loss_before:.6f}")
print(f"Loss after LoRA: {loss_after:.6f}")
print(f"Effective weight change rank: {jnp.linalg.matrix_rank(B @ A)}")
```
Advanced Text Generation
-
Standard autoregressive generation (file 04) produces text one token at a time, left to right. This is simple and effective, but it is inherently sequential, allows no global planning, and gives limited control over the output. This file covers methods that go beyond vanilla autoregressive decoding: diffusion models for text, optical character recognition, controllable generation through human feedback, handling long contexts, retrieval-augmented generation, and speculative decoding for faster inference.
-
Text diffusion models apply the diffusion framework (introduced for images in chapter 08) to discrete text. The core challenge is that text is discrete: you cannot add continuous Gaussian noise to tokens the way you add noise to pixels. Several approaches address this.
-
D3PM (Discrete Denoising Diffusion Probabilistic Models, Austin et al., 2021) defines a forward corruption process directly over discrete tokens using transition matrices. At each forward step, a token has some probability of being replaced by another token (uniform noise), masked (absorbing state), or staying the same. The reverse process learns to denoise, predicting the clean token from the corrupted one. The transition matrix $Q_t$ at step $t$ controls corruption:
$$q(x_t \mid x_{t-1}) = \text{Cat}(x_t ;\, x_{t-1} Q_t)$$
where $\text{Cat}$ denotes a categorical distribution and $x$ is a one-hot row vector. The multi-step forward process $q(x_t \mid x_0)$ has a closed form: $q(x_t \mid x_0) = \text{Cat}(x_t ;\, x_0 \bar{Q}_t)$, where $\bar{Q}_t = Q_1 Q_2 \cdots Q_t$ is the product of all transition matrices up to step $t$. Training minimises the negative variational lower bound (ELBO), which decomposes across timesteps, similar to the continuous case (chapter 08):
$$\mathcal{L}_{\text{D3PM}} = D_{\text{KL}}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big) + \sum_{t=2}^{T} D_{\text{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big) - \log p_\theta(x_0 \mid x_1)$$
-
The first term ensures the fully corrupted distribution matches the prior (uniform or all-mask). The sum of KL terms trains the model to reverse each corruption step: the true reverse posterior $q(x_{t-1} \mid x_t, x_0)$ can be computed in closed form using Bayes' rule and the known transition matrices, and the model $p_\theta(x_{t-1} \mid x_t)$ is trained to match it.
-
Since both distributions are categorical, the KL divergence is a simple sum over vocabulary entries. The final term measures reconstruction quality from the least corrupted state.
-
MDLM (Masked Diffusion Language Models, Sahoo et al., 2024) simplifies D3PM by using masking as the only corruption operation: the forward process gradually replaces tokens with a [MASK] token, and the reverse process predicts the original tokens. This connects text diffusion to masked language modelling (BERT, file 04), with the diffusion timestep controlling what fraction of tokens are masked. At $t = 0$ the text is fully clean; at $t = T$ it is fully masked.
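As a small numerical check of the closed form $q(x_t \mid x_0) = \text{Cat}(x_t ;\, x_0 \bar{Q}_t)$, the sketch below builds two D3PM-style transition matrices over a toy 5-token vocabulary. One uses uniform noise. The other uses an absorbing [MASK] state, which is the masking-only corruption MDLM uses. The sketch multiplies the matrices into $\bar{Q}_T$ and compares the result with the expected mixtures. The noise schedule, vocabulary size, and mask id are arbitrary assumptions, and NumPy stands in for a real training framework.

```python
import numpy as np

def uniform_Q(beta, K):
    """Uniform noise: keep the token w.p. 1 - beta, else resample uniformly over K tokens."""
    return (1.0 - beta) * np.eye(K) + beta * np.ones((K, K)) / K

def absorbing_Q(beta, K, mask_id):
    """Absorbing state: jump to [MASK] w.p. beta; [MASK] never leaves."""
    Q = (1.0 - beta) * np.eye(K)
    Q[:, mask_id] += beta
    Q[mask_id] = 0.0
    Q[mask_id, mask_id] = 1.0
    return Q

K, MASK = 5, 4                       # toy vocabulary: ids 0-3 are tokens, id 4 is [MASK]
betas = np.linspace(0.05, 0.5, 10)   # hypothetical schedule beta_1 .. beta_T (T = 10)
x0 = np.eye(K)[2]                    # one-hot row vector for clean token id 2

Qbar_abs, Qbar_uni = np.eye(K), np.eye(K)
for b in betas:                      # Qbar_T = Q_1 Q_2 ... Q_T
    Qbar_abs = Qbar_abs @ absorbing_Q(b, K, MASK)
    Qbar_uni = Qbar_uni @ uniform_Q(b, K)

q_abs = x0 @ Qbar_abs                # q(x_T | x_0) with masking
q_uni = x0 @ Qbar_uni                # q(x_T | x_0) with uniform noise
keep = np.prod(1.0 - betas)          # probability the token survived every step
print(q_abs.round(4), q_uni.round(4), round(keep, 4))

rng = np.random.default_rng(0)
x_T = rng.choice(K, p=q_abs)         # one corrupted sample: either token 2 or [MASK]
```

For these structured matrices the product has a closed form. The masked case puts mass $\bar{\alpha}_T = \prod_s (1-\beta_s)$ on the clean token and the rest on [MASK]. The uniform case gives $\bar{\alpha}_T x_0 + (1-\bar{\alpha}_T)/K$. As a result, $\bar{Q}_t$ can be computed directly instead of by repeated multiplication.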
-
Continuous text diffusion sidesteps the discrete problem by working in the continuous embedding space. Tokens are first mapped to their embedding vectors (chapter 06), noise is added in this continuous space, and a denoising model (typically a Transformer) learns to reverse the process. At generation time, the model produces continuous vectors that are mapped back to discrete tokens by finding the nearest embedding. The challenge is that small errors in continuous space can map to completely wrong tokens, so careful rounding and clamping are needed.
-
The appeal of text diffusion is that it generates all tokens simultaneously through iterative refinement, rather than left-to-right. This allows global coherence and easy infilling (generating missing text in the middle of a passage), but current text diffusion models still lag behind autoregressive models in generation quality for long-form text.
-
Text OCR (Optical Character Recognition) is the task of extracting text from images. While not traditionally grouped with language generation, modern OCR systems are deeply integrated with NLP and increasingly use language model components.
-
Scene text detection locates text regions in natural images (street signs, product labels, licence plates). This is challenging because text in the wild appears at arbitrary angles, scales, fonts, and against cluttered backgrounds. Detection methods typically use CNN or Transformer backbones to produce bounding boxes or segmentation masks around text regions.
-
CRNN (Convolutional Recurrent Neural Network, Shi et al., 2017) is a classic text recognition architecture. A CNN extracts visual features from the text image, the feature map is sliced into a sequence of columns (one per horizontal position), and a bidirectional LSTM reads this sequence to model context. The output is decoded using CTC (Connectionist Temporal Classification), which handles the alignment between input columns and output characters without requiring explicit segmentation.
-
The fundamental problem CTC solves: the model produces $T$ output distributions (one per input column), but the target text has $L \leq T$ characters.
-
We do not know which columns correspond to which characters. CTC introduces a blank token $\epsilon$ and defines a many-to-one mapping $\mathcal{B}$ that collapses repeated characters and removes blanks: $\mathcal{B}(\text{"HH-ee-ll-ll-oo"}) = \text{"Hello"}$ (where "-" is blank).
-
The probability of the target sequence $y$ is the sum over all input alignments that collapse to $y$:
$$P(y \mid x) = \sum_{\pi \in \mathcal{B}^{-1}(y)} \prod_{t=1}^{T} P(\pi_t \mid x)$$
-
where $\pi$ is an alignment path of length $T$ (one label per column, including blanks). Naively summing over all paths is exponential, but the forward algorithm (chapter 05 HMMs) computes this sum efficiently in $O(T \cdot L)$ time using dynamic programming.
-
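The blank token is essential. Because $\mathcal{B}$ merges repeats, the "ll" in "Hello" would otherwise collapse to a single "l"; a blank between the two l's keeps them separate. Training maximises $\log P(y \mid x)$. At inference time, greedy (best-path) decoding takes the most likely label in each column and applies $\mathcal{B}$. Beam search approximates the most probable label sequence more closely by summing over alignments that share a prefix.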
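The collapse map $\mathcal{B}$ and the forward recursion can be sketched in a few lines of NumPy. Following the usual CTC formulation, the target is extended with blanks ($\epsilon\, y_1\, \epsilon\, y_2 \cdots \epsilon$). The quantity $\alpha_t(s)$ then accumulates the probability of all path prefixes that end at extended position $s$ after column $t$. The brute-force enumeration is only there to check the dynamic program on a tiny example. Label id 0 as the blank and the random per-column distributions are assumptions of this sketch. Real implementations work in log space to avoid underflow.

```python
import itertools
import numpy as np

BLANK = 0  # label id reserved for the CTC blank (epsilon)

def ctc_collapse(path):
    """The many-to-one map B: merge runs of repeated labels, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != BLANK:
            out.append(p)
        prev = p
    return out

def ctc_forward(probs, target):
    """P(target | x) by dynamic programming. probs: (T, V) per-column distributions."""
    ext = [BLANK]
    for c in target:                        # blank-extended target: eps y1 eps y2 ... eps
        ext += [c, BLANK]
    T, S = probs.shape[0], len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, ext[0]]          # a path may start with a blank ...
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]      # ... or with the first label
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]                              # stay on the same symbol
            if s >= 1:
                a += alpha[t - 1, s - 1]                     # advance by one
            if s >= 2 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]                     # skip a blank between different labels
            alpha[t, s] = a * probs[t, ext[s]]
    # valid paths end on the last label or on the trailing blank
    return alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)

def ctc_brute_force(probs, target):
    """Sum over every length-T path that collapses to target (exponential; for checking only)."""
    T, V = probs.shape
    total = 0.0
    for path in itertools.product(range(V), repeat=T):
        if ctc_collapse(path) == list(target):
            total += np.prod([probs[t, p] for t, p in enumerate(path)])
    return total

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 3))                     # T = 5 columns, labels {blank, 1, 2}
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
target = [1, 1, 2]                                   # the repeated 1 forces a blank between them
print(ctc_forward(probs, target), ctc_brute_force(probs, target))
print(ctc_collapse(probs.argmax(axis=1).tolist()))   # greedy (best-path) decoding
```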
-
Document OCR processes structured documents (invoices, forms, scientific papers) and must understand layout in addition to recognising characters. Modern systems like LayoutLM combine text recognition with spatial position features: each token gets both its text embedding and a positional embedding encoding its $(x, y)$ coordinates on the page. This allows the model to understand that a number appearing below "Total:" is the total amount.
-
Vision-language OCR models like TrOCR treat text recognition as image-to-text generation: a Vision Transformer encoder processes the image, and a language model decoder generates the text token by token (TrOCR uses wordpiece tokens rather than single characters). This leverages the power of pre-trained vision and language models and handles diverse scripts, fonts, and layouts without handcrafted feature engineering.
-
Controllable generation is the challenge of steering a language model to produce outputs with desired properties: a particular style, topic, sentiment, safety level, or factual accuracy. The model should follow instructions while remaining fluent and coherent.
-
Classifier-free guidance (CFG) for text adapts a technique from image generation. During training, the conditioning signal (e.g., a prompt) is randomly dropped some fraction of the time, training both a conditional and unconditional model in one. At inference, the output logits are interpolated:
$$\text{logits}_{\text{guided}} = (1 + w) \cdot \text{logits}_{\text{conditional}} - w \cdot \text{logits}_{\text{unconditional}}$$
-
where $w > 0$ amplifies the influence of the condition. Higher $w$ makes the output more strongly follow the prompt but reduces diversity.
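The guidance formula is a one-line operation on logits. The sketch below applies it to hypothetical conditional and unconditional next-token logits. In a real model these would come from two forward passes, one with the prompt and one without. The output shows the guided distribution concentrating on the prompt-favoured token as $w$ grows.

```python
import numpy as np

def cfg_logits(logits_cond, logits_uncond, w):
    """Classifier-free guidance: (1 + w) * conditional - w * unconditional."""
    return (1.0 + w) * logits_cond - w * logits_uncond

def softmax(z):
    z = z - z.max()                 # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

cond = np.array([2.0, 1.0, 0.0])    # hypothetical next-token logits with the prompt
uncond = np.array([1.0, 1.0, 1.0])  # hypothetical logits with the prompt dropped
for w in (0.0, 1.0, 3.0):
    print(w, softmax(cfg_logits(cond, uncond, w)).round(3))
```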
-
RLHF (Reinforcement Learning from Human Feedback, Ouyang et al., 2022) is the dominant method for aligning language models with human preferences. The process has three stages:
-
First, supervised fine-tuning (SFT): fine-tune the base language model on a dataset of high-quality human-written responses to prompts.
-
Second, reward model training: collect human comparisons (given prompt $x$ and two responses $y_1, y_2$, which is better?) and train a reward model $r_\phi(x, y)$ to predict human preferences. The reward model is trained with a pairwise ranking loss:
$$\mathcal{L}_{\text{RM}} = -\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)$$
-
where $y_w$ is the preferred response and $y_l$ is the dispreferred one.
-
Third, RL fine-tuning: optimise the language model to maximise the reward while staying close to the SFT model (to prevent mode collapse). This uses PPO (Proximal Policy Optimisation, from chapter 06) with a KL penalty:
$$\mathcal{L}_{\text{RL}} = -\mathbb{E}\left[r_\phi(x, y) - \beta \, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{SFT}})\right]$$
The KL term prevents the model from drifting too far from the base model and exploiting quirks of the reward model ("reward hacking").
DPO (Direct Preference Optimisation, Rafailov et al., 2023) simplifies RLHF by eliminating the reward model entirely. The key mathematical insight is that the KL-constrained RL objective above has a closed-form optimal policy (writing $\pi_{\text{ref}}$ for the reference model, here $\pi_{\text{SFT}}$):
$$\pi^\ast(y \mid x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y \mid x) \exp\!\left(\frac{r(x, y)}{\beta}\right)$$
where $Z(x)$ is a normalising partition function. Rearranging this for the reward gives $r(x, y) = \beta \log \frac{\pi^\ast(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)$. Substituting this implicit reward into the Bradley-Terry preference model $P(y_w \succ y_l) = \sigma(r(x, y_w) - r(x, y_l))$ causes the intractable $Z(x)$ terms to cancel, yielding the DPO loss directly:
$$\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)$$
-
Under the Bradley-Terry preference model, this optimises the same KL-constrained objective as RLHF, but it collapses reward modelling and RL training into a single supervised step.
-
The expression inside the sigmoid can be read as: "increase the relative probability of the preferred response and decrease the relative probability of the dispreferred response, measured against the reference model."
-
The $\beta$ parameter controls how much the policy can deviate from the reference. In practice, DPO is simpler to implement (just compute log-probabilities under the current and reference models for both completions) and avoids the instabilities of PPO training.
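Both preference losses above reduce to a log-sigmoid of a score difference. The sketch below assumes that the sequence log-probabilities $\log \pi(y \mid x)$ have already been computed, summed over the completion's tokens, under both the policy and the frozen reference model. The numbers are made up for illustration, and $\beta = 0.1$ is an arbitrary default.

```python
import numpy as np

def log_sigmoid(z):
    """Numerically stable log(sigma(z)) = -log(1 + exp(-z))."""
    return -np.logaddexp(0.0, -z)

def reward_model_loss(r_w, r_l):
    """Pairwise Bradley-Terry loss: -log sigma(r(x, y_w) - r(x, y_l))."""
    return -log_sigmoid(r_w - r_l)

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss from sequence log-probs under the policy (logp_*) and the reference (ref_logp_*)."""
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return -log_sigmoid(margin)

print(dpo_loss(-12.0, -15.0, -12.0, -15.0))  # policy == reference: loss = log 2
print(dpo_loss(-10.0, -17.0, -12.0, -15.0))  # preferred up, dispreferred down: lower loss
```

The implicit reward only appears through the ratio to the reference model. A policy that simply copies the reference sits at $\log 2$ regardless of how likely the completions are.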
-
Constitutional AI (Bai et al., 2022) automates parts of the alignment process. Instead of collecting human comparisons, it uses the language model itself to critique and revise its own outputs according to a set of principles (the "constitution"), such as "choose the response that is less harmful." The AI-generated comparisons are then used for preference training (RLAIF: RL from AI Feedback).
-
Long-context methods address the $O(n^2)$ memory and compute cost of standard self-attention, which limits sequence length. As $n$ grows into the tens or hundreds of thousands of tokens, standard attention becomes infeasible.
-
Sparse attention replaces the dense $n \times n$ attention matrix with a sparse pattern where each token attends to only a subset of other tokens. Common patterns include local attention (each token attends to a fixed-size window of neighbours), strided attention (attend to every $k$-th token), and random attention (attend to a random subset). Combinations of these patterns reduce the quadratic cost. The local-plus-strided patterns of the Sparse Transformer give $O(n \sqrt{n})$. Longformer (local windows plus a few global tokens) and BigBird (local, global, and random attention) scale as $O(n)$. All of them still capture both local and global dependencies.
-
Sliding window attention restricts each token to attend only to the previous $w$ tokens (its local window). This is $O(nw)$ rather than $O(n^2)$, but long-range information must propagate through overlapping windows across layers. With $L$ layers and window size $w$, the effective receptive field is $L \times w$ tokens.
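A sliding-window mask is a banded lower-triangular matrix. The sketch below assumes that each token sees itself plus the previous $w$ tokens. It then composes the mask across layers to check the receptive-field claim: after $L$ layers, the last token can be influenced by tokens up to $L \times w$ positions back.

```python
import numpy as np

def sliding_window_mask(n, w):
    """mask[i, j] is True when query i may attend to key j: itself plus the previous w tokens."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j >= i - w)

def receptive_field(n, w, layers):
    """reach[i, j] is True when token j can influence token i after `layers` masked layers."""
    m = sliding_window_mask(n, w).astype(int)
    reach = np.eye(n, dtype=int)
    for _ in range(layers):
        reach = np.minimum(m @ reach, 1)   # i attends to k, and k already reaches j
    return reach.astype(bool)

n, w, L = 64, 4, 3
mask = sliding_window_mask(n, w)
print(mask.sum(axis=1).max())          # at most w + 1 keys per query: O(n * w) scores in total
reach = receptive_field(n, w, L)
print(n - 1 - np.flatnonzero(reach[n - 1]).min())  # how far back the last token can see: L * w = 12
```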
-
Ring attention distributes long sequences across multiple devices by arranging them in a ring topology. Each device holds a chunk of the sequence and computes attention for its chunk while simultaneously sending key-value blocks to the next device in the ring. This overlaps computation with communication and allows sequences of arbitrary length limited only by the total memory across all devices, not the memory of any single one.
-
Memory-augmented models extend context by equipping the Transformer with an external memory bank that one or more layers can read from (and, in some designs, write to) using attention. Memorizing Transformers cache key-value pairs from previous chunks and attend to them in subsequent chunks, effectively extending context beyond the training window. The retrieval is approximate (using $k$-nearest neighbours over cached keys) to keep it efficient.
-
The methods above are architectural solutions to long context. Equally important is how models are trained to use long contexts effectively.
-
Progressive context extension is the standard approach. Training on very long sequences from the start is prohibitively expensive ($O(n^2)$ attention cost), so models are pre-trained at a short context length (typically 4K–8K tokens) and then continued pre-training extends to the target length in stages.
-
Llama 3.1 extends from 8K to 128K over 800B tokens with gradually increasing sequence length. DeepSeek-V3 trains at 4K, then extends to 32K, then 128K.
-
Each stage uses a modest number of tokens (relative to the full pre-training budget) because the model only needs to learn how to use longer positions, not relearn language itself.
-
The position encoding must be adjusted during extension. RoPE interpolation scales down the position indices so that the model sees the same rotation angles it was trained on, just spread over a longer sequence. If the model was trained at length $L$ and you want to extend to $L' = 4L$, you divide all position indices by 4.
-
This means the model never sees a rotation angle it has not encountered, but the effective resolution between adjacent positions drops.
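Position interpolation only changes the position indices fed to RoPE. The sketch below assumes the common pairwise RoPE layout, in which consecutive even/odd features are rotated together, and a hypothetical 4K-to-16K extension. Every interpolated index stays inside the trained range, but neighbouring tokens are now only a quarter of a position apart.

```python
import numpy as np

def rope_angles(positions, d, base=10_000.0):
    """Angle p * theta_i for every position p and dimension pair i, theta_i = base^(-2i/d)."""
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    return np.outer(positions, theta)

def apply_rope(x, positions, base=10_000.0):
    """Rotate each (even, odd) feature pair of x, shape (n, d), by its RoPE angle."""
    ang = rope_angles(positions, x.shape[-1], base)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

L_train, scale, d = 4096, 4, 64              # hypothetical: trained at 4K, extended to 16K
pos_extrap = np.arange(L_train * scale)      # extrapolation: raw indices 0 .. 16383
pos_interp = pos_extrap / scale              # interpolation: 0, 0.25, 0.5, ..., 4095.75

print(pos_interp.max() < L_train)            # every index stays inside the trained range
print(pos_interp[1] - pos_interp[0])         # but neighbours are only 0.25 positions apart

q = np.random.default_rng(0).normal(size=(L_train * scale, d))
q_rot = apply_rope(q, pos_interp)            # queries (and keys) use interpolated positions
```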
-
RoPE extrapolation keeps the original position indices unchanged and simply applies RoPE to positions beyond $L$, relying on the model generalising to unseen angles.
-
Interpolation is much more stable; extrapolation degrades rapidly without base frequency adjustment (ABF).
-
YaRN (Yet another RoPE extensioN) improves on naive interpolation by recognising that not all RoPE dimensions should be treated equally.
-
High-frequency dimensions (small $i$ in $\theta_i = \theta_{\text{base}}^{-2i/d}$) rotate many times within the training length and can extrapolate well.
-
Low-frequency dimensions (large $i$) rotate slowly and are more sensitive to length extension.
-
YaRN interpolates only the low-frequency dimensions, extrapolates the high-frequency ones, and applies a temperature scaling $t$ to the attention logits to compensate for the distributional shift:
$$\text{score}'_{ij} = \frac{q_i^T k_j}{t \sqrt{d_k}}$$
-
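where $t$ is a temperature. YaRN sets it through $\sqrt{1/t} = 0.1 \ln s + 1$ for an extension factor $s$, so $t < 1$ whenever $s > 1$. The logits are therefore scaled up, which sharpens attention and counteracts the flatter, higher-entropy attention distribution that appears once positions are compressed over a longer context.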
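YaRN's per-dimension treatment, which the paper calls "NTK-by-parts" interpolation, can be written as a ramp over how many full rotations each RoPE pair completes within the training length. The sketch below follows that recipe. The ramp bounds of 1 and 32 rotations are the values the YaRN paper reports for LLaMA-family models and are assumptions here. For the temperature, the paper uses $\sqrt{1/t} = 0.1 \ln s + 1$ for extension factor $s$, so for $s > 1$ the logits are multiplied by a factor greater than one.

```python
import numpy as np

def yarn_frequencies(d, L_train, s, base=10_000.0, low=1.0, high=32.0):
    """Per-pair RoPE frequencies after YaRN-style "NTK-by-parts" scaling by factor s.
    Pairs completing more than `high` rotations within L_train keep theta (extrapolate),
    pairs completing fewer than `low` use theta / s (interpolate), and pairs in between
    blend the two with a linear ramp."""
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    rotations = L_train * theta / (2 * np.pi)            # training length / wavelength
    gamma = np.clip((rotations - low) / (high - low), 0.0, 1.0)
    return (1.0 - gamma) * theta / s + gamma * theta

def yarn_logit_scale(s):
    """Factor 1/t multiplying q.k / sqrt(d_k), from sqrt(1/t) = 0.1 ln(s) + 1."""
    return (0.1 * np.log(s) + 1.0) ** 2

d, L_train, s = 128, 4096, 32                # hypothetical: 4K -> 128K extension
theta = 10_000.0 ** (-2.0 * np.arange(d // 2) / d)
ratio = theta / yarn_frequencies(d, L_train, s)   # 1 = unchanged, s = fully interpolated
print(ratio[:2], ratio[-2:])                 # high-frequency pairs untouched, low-frequency pairs / 32
print(yarn_logit_scale(s))                   # about 1.81, i.e. t < 1: sharper attention
```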
-
Long-context data curation is a critical and often underestimated challenge. Most pre-training corpora consist of short documents (news articles, web pages, social media posts).
-
Long-context training requires a data mix that actually exercises the full context window: books, code repositories, long-form scientific articles, multi-turn conversation logs, and concatenated thematically related documents.
-
If the model is only trained on short documents padded or packed to fill the context window, it learns to ignore distant tokens because they are never relevant.
-
Sequence packing is a training efficiency technique: multiple documents are concatenated into a single training sequence to avoid padding waste, with attention masks preventing cross-document attention.
-
For long-context training, the packing strategy matters: packing many unrelated short documents teaches the model that distant tokens are noise, while packing fewer, genuinely long documents teaches it to use the full context.
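A packed sequence with cross-document masking is a causal mask intersected with a same-document mask. In the sketch below, `doc_ids` (a hypothetical input) records which document each token came from. The longest attention span in the packed sequence is bounded by the longest document. This is why packing many short documents never exercises distant positions, however long the packed sequence is.

```python
import numpy as np

def packed_causal_mask(doc_ids):
    """mask[i, j] is True when token i may attend to token j: j <= i and same document."""
    doc_ids = np.asarray(doc_ids)
    n = len(doc_ids)
    causal = np.tril(np.ones((n, n), dtype=bool))
    same_doc = doc_ids[:, None] == doc_ids[None, :]
    return causal & same_doc

doc_ids = [0, 0, 0, 1, 1, 2, 2, 2]     # three short documents packed into one sequence
mask = packed_causal_mask(doc_ids)
print(mask.astype(int))
# longest distance any token can look back: bounded by the longest packed document
span = max(i - np.flatnonzero(mask[i]).min() for i in range(len(doc_ids)))
print(span)                             # 2, even though the sequence has 8 tokens
```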
-
A known failure mode is the "lost in the middle" phenomenon (Liu et al., 2023): language models tend to use information at the beginning and end of the context window effectively but struggle with information placed in the middle.
-
This resembles the serial position effect in human memory (primacy and recency).
-
It arises partly from training data distributions (important information is often at the start or end of documents) and partly from attention patterns that concentrate on nearby and initial tokens.
-
Long-context training with diverse placement of key information mitigates but does not fully solve this.
-
Needle-in-a-haystack evaluation tests whether a model can retrieve a specific fact ("the needle") placed at various positions within a long distractor context ("the haystack").
-
A model with genuine long-context ability should achieve near-perfect retrieval regardless of where the needle is placed.
-
This test reveals the lost-in-the-middle effect clearly and is used to benchmark context extension methods.
-
Long-context fine-tuning after pre-training uses targeted SFT data: long multi-turn dialogues, document QA with evidence scattered across thousands of tokens, long-form summarisation, and repository-level code understanding.
-
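Qwen3 pairs YaRN with Dual Chunk Attention (DCA) at inference time. DCA splits a long sequence into chunks and remaps the relative positions used within a chunk, across chunks, and between neighbouring chunks, so that no query-key distance exceeds the range seen during training. The Qwen3 report credits the combination with a roughly four-fold increase in usable sequence length.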
-
State Space Models (SSMs) offer a fundamentally different approach to long-sequence modelling. Rather than modifying attention, they replace it entirely with a linear dynamical system inspired by continuous-time control theory.
-
An SSM maps an input sequence $u(t)$ to an output $y(t)$ through a latent state $x(t) \in \mathbb{R}^N$ governed by:
$$x'(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)$$
-
where $A \in \mathbb{R}^{N \times N}$ is the state transition matrix, $B \in \mathbb{R}^{N \times 1}$ is the input projection, $C \in \mathbb{R}^{1 \times N}$ is the output projection, and $D$ is a skip connection.
-
To apply this to discrete sequences (tokens), the continuous system is discretised using a step size $\Delta$. The zero-order hold discretisation gives:
$$\bar{A} = \exp(\Delta A), \quad \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$$
-
The discrete recurrence then becomes $x_k = \bar{A} x_{k-1} + \bar{B} u_k$, $y_k = C x_k + D u_k$, which looks like an RNN: process one token at a time with a hidden state.
-
Unlike RNNs, this recurrence can also be unrolled as a global convolution. Because the system is linear, the output is $y = \bar{K} \ast u + Du$, where the kernel $\bar{K} = (C\bar{B},\, C\bar{A}\bar{B},\, C\bar{A}^2\bar{B}, \ldots)$ depends only on the fixed parameters.
-
This dual view — recurrence for efficient autoregressive inference ($O(1)$ per step) and convolution for efficient parallel training ($O(n \log n)$ via FFT) — is the central insight of SSMs.
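The dual view can be checked numerically. The sketch below assumes a diagonal $A$, as in later S4 variants such as S4D and in Mamba, which makes the zero-order-hold formulas elementwise. It runs the same discretised SSM once as a recurrence and once as an FFT convolution with the kernel $\bar{K}$, and the two outputs agree. The sizes, step size, and random parameters are illustrative assumptions.

```python
import numpy as np

def discretize_zoh(A_diag, B, delta):
    """Zero-order hold for diagonal A: A_bar = exp(dA), B_bar = (dA)^-1 (exp(dA) - 1) dB, elementwise."""
    dA = delta * A_diag
    A_bar = np.exp(dA)
    B_bar = (A_bar - 1.0) / dA * (delta * B)
    return A_bar, B_bar

def ssm_recurrent(A_bar, B_bar, C, D, u):
    """RNN view: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k + D u_k."""
    x = np.zeros_like(A_bar)
    ys = []
    for u_k in u:
        x = A_bar * x + B_bar * u_k
        ys.append(C @ x + D * u_k)
    return np.array(ys)

def ssm_convolutional(A_bar, B_bar, C, D, u):
    """Convolution view: y = K * u + D u with K_j = C A_bar^j B_bar, computed with FFTs."""
    n = len(u)
    K = np.array([C @ (A_bar ** j * B_bar) for j in range(n)])
    L = 2 * n                          # zero-pad so the circular FFT product equals linear convolution
    y = np.fft.irfft(np.fft.rfft(u, L) * np.fft.rfft(K, L), L)[:n]
    return y + D * u

rng = np.random.default_rng(0)
N, n = 8, 32                           # state size and sequence length (illustrative)
A_diag = -np.arange(1.0, N + 1.0)      # negative real diagonal keeps the system stable
B, C = rng.normal(size=N), rng.normal(size=N)
D, delta = 0.5, 0.1
u = rng.normal(size=n)

A_bar, B_bar = discretize_zoh(A_diag, B, delta)
y_rec = ssm_recurrent(A_bar, B_bar, C, D, u)
y_conv = ssm_convolutional(A_bar, B_bar, C, D, u)
print(np.max(np.abs(y_rec - y_conv)))  # agrees up to floating-point error
```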
-
S4 (Structured State Spaces for Sequence Modeling, Gu et al., 2022) made SSMs practical by solving the key numerical challenge: the state matrix $A$ must capture long-range dependencies, but naively parameterising it leads to vanishing or exploding dynamics (the same problem as vanilla RNNs).
-
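S4 initialises $A$ using the HiPPO (High-order Polynomial Projection Operators) matrix, which is derived from the theory of optimal polynomial approximation of continuous signals. The HiPPO matrix has a specific structure that provably enables the state to maintain a compressed representation of the entire input history with graceful decay. For the HiPPO-LegS variant used in S4, its entries are:
$$A_{nk} = -\begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$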
-
This lower-triangular structure ensures that the state acts as an online approximation of the input signal using Legendre polynomials. Computing $\bar{A}^k$ for long kernels is expensive, so S4 uses the fact that the HiPPO matrix can be decomposed as a sum of low-rank and diagonal terms, enabling $O(n \log n)$ kernel computation.
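For reference, the HiPPO-LegS matrix that S4 uses has entries $A_{nk} = -\sqrt{(2n+1)(2k+1)}$ for $n > k$, $-(n+1)$ on the diagonal, and $0$ above it. The sketch below builds it and checks the structure S4 exploits. Adding the rank-1 term $\tfrac{1}{2} p p^\top$ (with $p_n = \sqrt{2n+1}$) to $A$ leaves $-\tfrac{1}{2}I$ plus a skew-symmetric matrix. That is a normal matrix, which can be diagonalised, so $A$ is "normal plus low-rank".

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS: A[n, k] = -sqrt((2n+1)(2k+1)) if n > k, -(n+1) if n == k, 0 if n < k."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    A = np.where(n > k, -np.sqrt((2.0 * n + 1) * (2.0 * k + 1)), 0.0)
    return A - np.diag(np.arange(1.0, N + 1.0))

N = 8
A = hippo_legs(N)
p = np.sqrt(2.0 * np.arange(N) + 1.0)
S = A + 0.5 * np.outer(p, p)           # add back the rank-1 part
skew = S + 0.5 * np.eye(N)

print(np.allclose(A, np.tril(A)))      # lower triangular
print(np.allclose(skew, -skew.T))      # S = -I/2 + skew-symmetric, so S is normal
print(np.linalg.eigvals(A).real.max()) # eigenvalues are the diagonal -(n+1) < 0: stable dynamics
```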
-
Mamba (Gu and Dao, 2023) introduces the critical innovation of selective state spaces: making the SSM parameters input-dependent. In S4, the matrices $A$, $B$, $C$, and the step size $\Delta$ are fixed — the same dynamics apply to every token regardless of content. Mamba makes $B$, $C$, and $\Delta$ functions of the input:
$$B_k = \text{Linear}(u_k), \quad C_k = \text{Linear}(u_k), \quad \Delta_k = \text{softplus}(\text{Linear}(u_k))$$
-
This selectivity allows the model to decide, at each position, what information to store in the state and what to ignore — analogous to how attention selects relevant tokens, but without the quadratic cost. The step size $\Delta_k$ controls the "gate": a large $\Delta$ causes the state to integrate the current input strongly (the continuous dynamics advance a large step, effectively resetting the state), while a small $\Delta$ preserves the existing state and ignores the current input.
-
The trade-off is that input-dependent parameters break the convolution view, because the kernel is no longer fixed, so Mamba cannot use FFT-based training. Instead, it uses a hardware-aware parallel scan. Each step $x_k = \bar{A}_k x_{k-1} + \bar{B}_k u_k$ is an affine map $(x_{k-1}, u_k) \mapsto x_k$ of the previous state, and composing affine maps is associative. All prefix states can therefore be computed with a parallel prefix sum (scan), analogous to parallel prefix addition in hardware design. A work-efficient scan needs $O(n)$ total work and $O(\log n)$ parallel depth on a GPU, nearly matching the efficiency of convolution.
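The scan trick can be seen with a scalar state. Each step $x_k = a_k x_{k-1} + b_k$ is an affine map, and composing two affine maps gives another one, so prefixes can be combined in any grouping. The sketch below uses the simple Hillis-Steele scan, which takes $\lceil \log_2 n \rceil$ rounds and $O(n \log n)$ total work. It is not a work-efficient $O(n)$ scan, nor Mamba's fused GPU kernel. The input-dependent $\Delta_k$ uses a hypothetical scalar weight in place of a learned projection.

```python
import numpy as np

def combine(first, second):
    """Compose affine maps x -> a*x + b: apply `first`, then `second`. Associative."""
    a1, b1 = first
    a2, b2 = second
    return a2 * a1, a2 * b1 + b2

def sequential_scan(a, b):
    """Reference recurrence x_k = a_k x_{k-1} + b_k with x_0 = 0, one step at a time."""
    x, out = 0.0, []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        out.append(x)
    return np.array(out)

def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: ceil(log2 n) rounds, each a batch of independent combines."""
    a, b = np.asarray(a, float).copy(), np.asarray(b, float).copy()
    n, step = len(a), 1
    while step < n:
        # element i absorbs element i - step; the first `step` items combine with the identity map
        a_prev = np.concatenate([np.ones(step), a[:-step]])
        b_prev = np.concatenate([np.zeros(step), b[:-step]])
        a, b = combine((a_prev, b_prev), (a, b))
        step *= 2
    return b        # each prefix map applied to x_0 = 0 leaves only its offset

rng = np.random.default_rng(0)
u = rng.normal(size=13)                 # input sequence (length need not be a power of two)
delta = np.log1p(np.exp(0.8 * u))       # softplus(w * u_k); w = 0.8 stands in for a learned projection
A = -1.0                                # one diagonal entry of A
a = np.exp(delta * A)                   # input-dependent decay A_bar_k in (0, 1)
b = delta * u                           # write B_bar_k u_k, approximating B_bar ~ delta * B with B = 1
print(np.allclose(sequential_scan(a, b), parallel_scan(a, b)))
```

JAX exposes the same idea as `jax.lax.associative_scan`, which takes a combine function like `combine` above.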
-
Mamba's inference is truly $O(1)$ per token: it updates the fixed-size state and keeps no KV cache that grows with context. This makes it fundamentally more memory-efficient than Transformers at long sequence lengths. The state size $N$ (typically 16 per channel) is much smaller than a Transformer's KV cache, which stores $O(n \cdot d)$ values. In the Mamba paper's language-modelling experiments, Mamba matched or exceeded Transformers of the same size, matched Transformers about twice its size, and ran inference significantly faster on long sequences.
-
Hybrid architectures combine SSM layers with attention layers, using SSMs for the majority of layers (efficient long-range propagation) and sprinkling in a few attention layers (precise content-based retrieval). Models like Jamba and Zamba interleave Mamba and Transformer blocks, achieving better quality than pure SSMs while maintaining much of the inference efficiency advantage. This suggests that attention and SSMs capture complementary capabilities: SSMs excel at smooth, long-range state propagation while attention excels at precise, content-dependent lookups.
-
Retrieval-Augmented Generation (RAG) addresses the knowledge limitations of language models by giving them access to an external knowledge base at inference time. Instead of relying solely on knowledge encoded in model parameters during training, RAG retrieves relevant documents and conditions generation on them.
-
The classic retriever-reader architecture has two components. The retriever takes a query and fetches the top-$k$ most relevant passages from a corpus. The reader (a language model) generates the answer conditioned on both the query and the retrieved passages. The retriever can use sparse methods (BM25, which extends TF-IDF from file 02) or dense methods.
-
Dense passage retrieval (DPR) uses a dual-encoder architecture: one encoder maps questions to vectors, another maps passages to vectors. Both are typically BERT-based. At indexing time, all passages are encoded and stored. At query time, the question is encoded and the nearest passages are found using approximate nearest neighbour search (such as FAISS). The similarity metric is the dot product between question and passage vectors.
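Retrieval with a dual encoder reduces to a maximum inner-product search once the embeddings exist. In the sketch below, random vectors stand in for the outputs of the question and passage encoders. These vectors are hypothetical: no real model is loaded. The exhaustive scan is what an approximate index such as FAISS speeds up.

```python
import numpy as np

def top_k_passages(q_vec, passage_vecs, k=3):
    """Exact maximum inner-product search over all passages (an ANN index approximates this)."""
    scores = passage_vecs @ q_vec          # dot-product similarity, one score per passage
    top = np.argsort(-scores)[:k]          # indices of the k highest scores
    return top, scores[top]

rng = np.random.default_rng(0)
passage_vecs = rng.normal(size=(1000, 64))              # stand-in for precomputed passage embeddings
q_vec = passage_vecs[42] + 0.1 * rng.normal(size=64)    # stand-in for a question embedding near passage 42
idx, scores = top_k_passages(q_vec, passage_vecs)
print(idx, scores.round(2))
```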
-
Chunking strategies affect retrieval quality significantly. Documents must be split into passages small enough for the retriever to handle, but large enough to contain complete ideas. Fixed-size chunking (e.g., 256 tokens with 50-token overlap) is simple but may split sentences awkwardly. Semantic chunking splits at paragraph or section boundaries. Hierarchical chunking creates a tree of summaries at different granularities.
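Fixed-size chunking with overlap is a sliding window with stride `size - overlap`. The sketch below uses the 256-token / 50-token figures from the text and operates on an already-tokenised list. The `tokens` list is a hypothetical stand-in for real tokenizer output.

```python
def fixed_size_chunks(tokens, size=256, overlap=50):
    """Split `tokens` into windows of `size`, each starting size - overlap tokens after the last."""
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")
    if not tokens:
        return []
    step = size - overlap
    # stop once the previous window already reaches the end of the document
    return [tokens[start:start + size] for start in range(0, max(len(tokens) - overlap, 1), step)]

tokens = [f"tok{i}" for i in range(600)]     # stand-in for the output of a real tokenizer
chunks = fixed_size_chunks(tokens)
print([len(c) for c in chunks])               # [256, 256, 188]
print(chunks[0][-50:] == chunks[1][:50])      # consecutive chunks share 50 tokens
```

Semantic or hierarchical chunking would replace the fixed stride with boundaries taken from the document's own structure.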
-
RAG provides several advantages: the knowledge base can be updated without retraining the model, the model can cite sources, and hallucination is reduced because the model can ground its answers in retrieved text. The main challenges are retrieval quality (if the wrong passages are retrieved, the model may produce wrong answers confidently) and latency (retrieval adds a step to inference).
-
Speculative decoding accelerates autoregressive generation by using a small, fast draft model to propose multiple tokens in parallel, which are then verified by the large target model in a single forward pass.
-
The algorithm works as follows: the draft model generates $k$ candidate tokens autoregressively (this is fast because the draft model is small).
-
The target model then scores all $k$ tokens simultaneously in a single forward pass (this is efficient because the work is batched).
-
For each candidate token $t$ sampled from the draft distribution $p_d(t)$, it is accepted with probability $\min(1,\, p_{\text{target}}(t) / p_d(t))$. If rejected, a corrected token is resampled from the adjusted distribution $p_{\text{adj}}(t) = \max(0,\, p_{\text{target}}(t) - p_d(t))$, normalised.
-
This acceptance-rejection scheme guarantees that the output distribution is identical to the target model alone.
-
To see why, consider the effective probability of emitting token $t$. It can be accepted directly (probability $p_d(t) \cdot \min(1, p_{\text{target}}(t)/p_d(t))$) or produced through resampling.
-
For tokens where $p_{\text{target}}(t) \leq p_d(t)$, the direct acceptance contributes $p_{\text{target}}(t)$. For tokens where $p_{\text{target}}(t) > p_d(t)$, direct acceptance contributes $p_d(t)$ and resampling contributes the remainder $p_{\text{target}}(t) - p_d(t)$ (after accounting for the rejection probability).
-
In both cases, the total probability of emitting $t$ equals $p_{\text{target}}(t)$. The draft model affects only speed, not quality.
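The argument can be checked exactly for a single drafted position. The sketch below computes the emitted distribution in closed form and also simulates the accept/resample step with hypothetical target and draft probabilities. A full implementation repeats this over the $k$ drafted positions, stops at the first rejection, and, when all $k$ tokens are accepted, samples one extra token from the target's last distribution.

```python
import numpy as np

def speculative_accept(p_target, p_draft, rng):
    """Verify one drafted position: sample t ~ p_draft, accept w.p. min(1, p_target[t] / p_draft[t]),
    otherwise resample from max(0, p_target - p_draft), renormalised."""
    t = rng.choice(len(p_draft), p=p_draft)
    if rng.random() < min(1.0, p_target[t] / p_draft[t]):
        return t
    residual = np.maximum(0.0, p_target - p_draft)
    return rng.choice(len(residual), p=residual / residual.sum())

def emitted_distribution(p_target, p_draft):
    """Exact output distribution: direct acceptance plus (rejection mass) x (residual distribution)."""
    accepted = np.minimum(p_draft, p_target)       # p_draft(t) * min(1, p_target(t) / p_draft(t))
    residual = np.maximum(0.0, p_target - p_draft)
    if residual.sum() == 0.0:                      # identical distributions: nothing is ever rejected
        return accepted
    return accepted + (1.0 - accepted.sum()) * residual / residual.sum()

p_target = np.array([0.5, 0.3, 0.15, 0.05])       # hypothetical target-model next-token probabilities
p_draft = np.array([0.25, 0.25, 0.25, 0.25])      # hypothetical draft-model probabilities
print(emitted_distribution(p_target, p_draft))    # equals p_target
print(np.minimum(p_target, p_draft).sum())        # acceptance rate at this position: about 0.7

rng = np.random.default_rng(0)
samples = [speculative_accept(p_target, p_draft, rng) for _ in range(20_000)]
print(np.bincount(samples, minlength=4) / len(samples))  # empirical frequencies, close to p_target
```

The acceptance rate $\sum_t \min(p_d(t), p_{\text{target}}(t))$ measures how well the draft matches the target, and it is what determines the speedup.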
-
The speedup depends on the acceptance rate. If the draft model is well aligned with the target model, most draft tokens are accepted, so each expensive forward pass of the target model yields several tokens instead of one. Reported speedups are typically 2-3x, with no change in the output distribution.
-
Medusa (Cai et al., 2024) takes a different approach: instead of a separate draft model, it adds multiple lightweight prediction heads to the target model itself. Each head predicts a different future token position simultaneously ($k = 1, 2, 3, \ldots$ steps ahead). At each step, Medusa proposes several candidate continuations using a tree structure, and a single forward pass through the target model's attention layers verifies which candidates are consistent. This avoids the need for a separate draft model entirely.
-
Parallel generation methods more broadly aim to break the sequential bottleneck of autoregressive decoding. Jacobi decoding initialises all positions with guesses and iteratively refines them in parallel until convergence, treating generation as a fixed-point iteration. Non-autoregressive models (NAT) generate all tokens simultaneously in a single forward pass but typically suffer quality degradation and require techniques like iterative refinement, CTC loss, or knowledge distillation from autoregressive teachers to close the gap.
-
The techniques above — alignment, long context, retrieval, efficient decoding, state space models — come together in modern production LLMs.
-
The remainder of this file surveys the architectural innovations in frontier models, showing how theoretical ideas from files 01–04 and the methods above are combined in practice.
-
Grouped Query Attention (GQA) is the most widely adopted attention efficiency technique. Standard multi-head attention (MHA) maintains separate key and value projections per head, requiring $n_{\text{heads}} \times d_{\text{head}}$ values cached per token. GQA groups multiple query heads to share a single key-value head.
-
With 64 query heads and 8 KV heads (as in Llama 3 70B and several large Qwen models), each KV head is shared by 8 query heads, reducing the KV cache by 8x compared to MHA.
-
The output quality is nearly identical to MHA because the queries can still attend to different patterns, they just share the same key-value subspace. Multi-query attention (MQA) is the extreme case with a single KV head for all queries, but GQA provides a better quality-efficiency trade-off.
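The sketch below shows the two halves of the idea. First, each cached KV head is repeated across its group of query heads at attention time. Second, the KV cache size scales with the number of KV heads rather than query heads. The small tensors are random, the causal mask is omitted for brevity, and the 80-layer, 128K-context configuration is hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_attention(q, k, v):
    """Grouped-query attention for one sequence (no causal mask).
    q: (n_q_heads, n, d_head); k, v: (n_kv_heads, n, d_head)."""
    group = q.shape[0] // k.shape[0]           # query heads per KV head
    k = np.repeat(k, group, axis=0)            # KV head g serves query heads g*group .. g*group + group - 1
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def kv_cache_values(n_kv_heads, d_head, n_layers, n_tokens):
    """Number of cached key + value entries."""
    return 2 * n_kv_heads * d_head * n_layers * n_tokens

rng = np.random.default_rng(0)
q = rng.normal(size=(8, 5, 16))                # 8 query heads, 5 tokens, d_head = 16
k = rng.normal(size=(2, 5, 16))                # only 2 KV heads are projected and cached
v = rng.normal(size=(2, 5, 16))
print(gqa_attention(q, k, v).shape)            # (8, 5, 16): one output per query head

# hypothetical 80-layer model, d_head = 128, 64 query heads, 128K-token context
mha = kv_cache_values(64, 128, 80, 128_000)
gqa = kv_cache_values(8, 128, 80, 128_000)
print(mha // gqa)                              # 8
```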
-
Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, achieves even more aggressive KV cache compression. Instead of caching the full key-value projections (even with GQA), MLA down-projects the hidden state into a low-rank latent vector $c_t \in \mathbb{R}^{d_c}$ with $d_c \ll n_{\text{heads}} \times d_{\text{head}}$:
$$c_t = W_{\text{down}} \, h_t$$
-
Only this compressed vector is cached. At attention time, the full key and value representations are reconstructed via up-projection: $k_t = W_{\text{up}}^K c_t$, $v_t = W_{\text{up}}^V c_t$. DeepSeek-V3 (671B total parameters, 37B active) has 128 heads of dimension 128, so full MHA would cache $2 \times 128 \times 128 = 32{,}768$ key and value entries per token per layer. MLA instead caches the $d_c = 512$ latent plus a 64-dimensional decoupled RoPE key, 576 entries in total, a reduction of about 98%.
-
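A subtlety: RoPE applies a position-dependent rotation to each key. That rotation would sit between $W_{\text{up}}^K$ and the query, which prevents folding $W_{\text{up}}^K$ into the query projection at inference time. MLA therefore uses decoupled RoPE: a small separate stream (64 dimensions per head for queries, and a single 64-dimensional key shared across heads) carries position information via RoPE, while the bulk of the representation flows through the compressed latent path.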
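The first half of the sketch below redoes the cache arithmetic with DeepSeek-V3's attention sizes: 128 heads of dimension 128, $d_c = 512$, and a 64-dimensional decoupled RoPE key. The second half uses small random matrices to illustrate the rewrite that lets MLA avoid materialising keys: $q^\top (W_{\text{up}}^K c_t) = ((W_{\text{up}}^K)^\top q)^\top c_t$. A rotation that depends on the key's position would sit between the two factors and break this rewrite, which is why position information travels in a separate RoPE stream.

```python
import numpy as np

# Cache entries per token per layer at DeepSeek-V3's attention sizes
n_heads, d_head = 128, 128
d_c, d_rope = 512, 64                   # MLA latent size and decoupled RoPE key size
mha = 2 * n_heads * d_head              # full keys and values
mla = d_c + d_rope                      # cached latent c_t plus the shared RoPE key
print(mha, mla, f"{1 - mla / mha:.1%}") # 32768 576 98.2%

# Why caching c_t is enough: W_up^K can be folded into the query side
rng = np.random.default_rng(0)
d_model, dc, dh = 64, 16, 8             # small hypothetical sizes
W_down = rng.normal(size=(dc, d_model))
W_up_k = rng.normal(size=(dh, dc))
h, q = rng.normal(size=d_model), rng.normal(size=dh)
c = W_down @ h                          # the only per-token tensor kept in the cache
score_explicit = q @ (W_up_k @ c)       # reconstruct the key, then take the dot product
score_absorbed = (W_up_k.T @ q) @ c     # or project the query into latent space once
print(np.isclose(score_explicit, score_absorbed))
```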
-
Position encoding at scale has diverged significantly from the original sinusoidal scheme. Most frontier models build on RoPE (file 04), but with key modifications for long context. The base frequency $\theta_{\text{base}}$ in the original RoPE formula $\theta_i = \theta_{\text{base}}^{-2i/d}$ is typically 10,000, which limits extrapolation beyond the training length.
-
Adjusted Base Frequency (ABF) simply increases $\theta_{\text{base}}$ to 500,000 (Llama 3) or 1,000,000 (Qwen3, Gemma 3), stretching the rotation periods so the model encounters fewer full rotations during training and can extrapolate further.
-
YaRN (Yet another RoPE extensioN) applies frequency-dependent interpolation: low-frequency dimensions are interpolated (scaled down), high-frequency dimensions are extrapolated, and a temperature factor adjusts the attention distribution. DeepSeek-V3, Qwen, and Kimi K2 all use YaRN-based extension to reach 128K context from models pre-trained at 4K–8K.
-
iRoPE (interleaved RoPE), introduced in Llama 4, takes a more radical approach: every 4th attention layer uses no positional encoding at all (NoPE), while the other layers use standard RoPE with chunked attention.
-
The NoPE layers can attend to all positions without any positional bias, while the RoPE layers provide local ordering. Combined with temperature scaling of attention at inference, this design underpins Llama 4 Scout's advertised 10M-token context window, far beyond the 128K contexts that the YaRN-extended models above reach.
-
Mixture of Experts at scale has become the dominant architecture for frontier models (file 04 introduced MoE fundamentals). The key design choices are the number of experts, routing sparsity, and load balancing.
-
Routing sparsity varies significantly: DeepSeek-V3 uses 256 experts with top-8 routing (32x sparsity), Qwen3 uses 128 experts with top-8 (16x sparsity), Mixtral uses 8 experts with top-2 (4x sparsity), and Llama 4 Maverick uses 128 experts with top-1 plus a shared expert (128x sparsity).
-
Higher sparsity means more total parameters for the same active compute, but requires more careful load balancing and communication infrastructure.
Auxiliary-loss-free load balancing (DeepSeek-V3) replaces the traditional load-balancing loss (file 04), which was found to degrade model quality. Instead, each expert maintains a bias term that is added to its routing score only when selecting the top-k experts. The bias is adjusted after every training step:

- Overloaded experts have their bias decreased, so they receive fewer tokens.
- Underloaded experts have their bias increased.

This balances routing without an auxiliary loss term competing with the main training signal. DeepSeek-V3 keeps only a very small sequence-level balance loss as a safeguard.
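The bias mechanism can be simulated in a few lines. The sketch assumes the following illustrative choices:

- 4 experts with top-1 routing.
- A toy router that systematically favours expert 0.
- 256 tokens per step and an update size $\gamma = 0.01$.

The bias only changes which experts are selected. In DeepSeek-V3 the gating weights are still computed from the unbiased scores, which this sketch does not model.

```python
import random

def route_with_bias(scores, bias, k):
    """Select the top-k experts by score + bias (the bias only affects selection)."""
    return sorted(range(len(scores)), key=lambda e: scores[e] + bias[e], reverse=True)[:k]

def update_bias(bias, load, gamma):
    """Sign-based update: overloaded experts go down by gamma, underloaded go up."""
    mean_load = sum(load) / len(load)
    new_bias = []
    for b, l in zip(bias, load):
        if l > mean_load:
            b -= gamma
        elif l < mean_load:
            b += gamma
        new_bias.append(b)
    return new_bias

rng = random.Random(0)
n_experts, k, gamma, tokens_per_step = 4, 1, 0.01, 256
bias = [0.0] * n_experts
first_load = None
for step in range(200):
    load = [0] * n_experts
    for _ in range(tokens_per_step):
        # Skewed toy router: expert 0 gets a systematic +1.0 score boost
        scores = [rng.gauss(0.0, 1.0) + (1.0 if e == 0 else 0.0) for e in range(n_experts)]
        for e in route_with_bias(scores, bias, k):
            load[e] += 1
    if first_load is None:
        first_load = load
    bias = update_bias(bias, load, gamma)
print("load at first step:", first_load)
print("load at last step: ", load)
print("final bias:", [round(b, 3) for b in bias])
```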
Shared experts appear in many recent MoE designs, though not all (Mixtral, for example, has none). A shared expert is an expert FFN that processes every token regardless of routing. Shared experts handle common patterns that all tokens need (basic syntax, function words), freeing the routed experts to specialise. Llama 4 uses 1 shared expert plus 1 routed expert per token (very sparse); DeepSeek-V3 uses 1 shared plus 8 routed.
Interleaving different layer types provides another design axis. In attention, Gemma 2 and 3 alternate local (sliding-window) and global layers. Gemma 3 uses a 5:1 ratio: its local layers use a 1,024-token sliding window, and only its global layers cache the full 128K context.
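The memory saving from a local/global mix is simple arithmetic. The sketch assumes every layer stores the same number of key/value heads of the same size. KV-cache size is then proportional to the number of cached tokens per layer, so the fraction relative to an all-global model is:

```python
def kv_cache_fraction(context_len, window, local_per_global):
    """KV-cache size of a local/global layer mix relative to an all-global model.

    Assumes every layer caches the same number of K/V heads of the same size, so
    cache size is proportional to the number of cached tokens per layer. In each
    group of (local_per_global + 1) layers, one global layer caches the whole
    context and each local layer caches at most `window` tokens.
    """
    local_tokens = min(window, context_len)
    per_group = context_len + local_per_global * local_tokens
    return per_group / ((local_per_global + 1) * context_len)

frac = kv_cache_fraction(context_len=128 * 1024, window=1024, local_per_global=5)
print(f"5:1 local/global, 1,024-token window, 128K context: {frac:.1%} of an all-global cache")
```

For the Gemma 3 numbers above, this comes to about 17% of an all-global cache. Almost all of that is the global layers' share (one sixth of the layers).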
In the FFN blocks, Llama 4 Maverick interleaves dense FFN layers with MoE layers, and Kimi K2 uses a single dense layer ahead of its MoE layers. This heterogeneous design allows different layers to serve different functions.
Multi-token prediction (MTP), used in DeepSeek-V3, trains the model to predict not just the next token but also the token after that. At each position, a secondary prediction module predicts one additional future token; it shares the main model's embedding layer and output head. The MTP loss is weighted at 0.1–0.3 relative to the main next-token loss. Beyond improving representation quality during training, the MTP module can serve as a draft head for speculative decoding at inference time. This gives a speedup without a separate draft model.
Knowledge distillation is a training strategy where a large "teacher" model's outputs guide the training of a smaller "student" model. Gemma 2 and 3 use distillation extensively. In Gemma 2, the smaller models (2B and 9B) are trained on more than 50x the compute-optimal amount of data, using the teacher's probability distributions as soft targets. Distillation is a large part of why Gemma 3-4B is reported to be competitive with Gemma 2-27B.
The distillation loss replaces or supplements the standard cross-entropy: the student minimises the KL divergence between its output distribution and the teacher's:
$$\mathcal{L}_{\text{distill}} = D_{\text{KL}}\big(p_{\text{teacher}}(\cdot \mid x) \,\|\, p_{\text{student}}(\cdot \mid x)\big)$$
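For a single position, the loss can be computed directly from the two models' logits, as in the plain-Python sketch below. The optional `temperature` softens both distributions before comparing them. It is a common knob in the distillation literature, not a detail confirmed for Gemma. Because the teacher's entropy does not depend on the student, minimising this KL is equivalent to minimising cross-entropy against the teacher's soft targets.

```python
import math

def softmax(logits, temperature=1.0):
    m = max(logits)  # shift for numerical stability (softmax is shift-invariant)
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distill_loss(teacher_logits, student_logits, temperature=1.0):
    """KL(p_teacher || p_student) for one position, computed from raw logits."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)

teacher = [2.0, 1.0, 0.1, -1.0]
uniform_student = [0.0, 0.0, 0.0, 0.0]
print(distill_loss(teacher, teacher))                         # identical distributions
print(distill_loss(teacher, uniform_student))                 # uniform student
print(distill_loss(teacher, uniform_student, temperature=4))  # softer teacher targets
```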
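DeepSeek-R1 distilled its 671B reasoning model into dense models as small as 1.5B, producing small models with disproportionately strong reasoning. It did so by fine-tuning Qwen and Llama models on about 800K curated samples generated with R1, mostly chain-of-thought reasoning traces. This is sequence-level distillation: the students learn from sampled text rather than from the teacher's full probability distributions.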
Reasoning via reinforcement learning is one of the most significant recent advances in LLM capabilities. DeepSeek-R1-Zero, the precursor to DeepSeek-R1, demonstrated that pure reinforcement learning on a base model (without supervised fine-tuning) can elicit chain-of-thought reasoning, self-verification, and error correction. These behaviours emerge spontaneously when the model is rewarded for correct final answers. DeepSeek-R1 itself adds a small cold-start supervised stage before RL to improve readability.
DeepSeek-R1 uses GRPO (Group Relative Policy Optimisation), which eliminates the value network required by PPO. For each prompt, GRPO samples a group of $G$ outputs, computes their rewards, and normalises advantages within the group:
$$A_i = \frac{r_i - \text{mean}(r_1, \ldots, r_G)}{\text{std}(r_1, \ldots, r_G)}$$
The policy gradient then uses these group-relative advantages in a clipped objective (as in PPO). A KL penalty keeps the policy close to a reference model.
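Both pieces fit in a few lines. The sketch makes these choices, all of them assumptions:

- It uses the population standard deviation.
- It uses a clipping range of $\epsilon = 0.2$.
- It adds a small constant, so a group with identical rewards gives zero advantage instead of dividing by zero.
- It omits any KL regularisation.

It shows the per-token objective, not a full training loop.

```python
import statistics

def group_advantages(rewards, eps=1e-8):
    """GRPO advantages: z-score each reward within its group of G samples."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

def clipped_surrogate(ratio, advantage, clip_eps=0.2):
    """PPO-style clipped objective for one token (to be maximised).

    ratio = pi_theta(token) / pi_old(token)
    """
    clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
    return min(ratio * advantage, clipped * advantage)

rewards = [1.0, 0.0, 1.0, 1.0]  # e.g. 3 of 4 sampled answers were correct
adv = group_advantages(rewards)
print("advantages:", [round(a, 3) for a in adv])
# A large ratio is clipped for positive advantages but not for negative ones
print(clipped_surrogate(1.5, adv[0]), clipped_surrogate(1.5, adv[1]))
```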
In PPO, the critic is typically a model of comparable size to the policy. Eliminating it substantially cuts the memory and compute requirements of RL training, making it practical to train 671B-parameter models with RL.
A critical design choice: DeepSeek-R1 uses rule-based rewards (checking mathematical answers against ground truth, running code test cases) rather than neural reward models, because neural reward models were found to be susceptible to reward hacking at this scale.
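Rule-based rewards of this kind can be sketched as follows. The `\boxed{...}` answer format, exact string matching, and the `code_reward` test harness are illustrative assumptions. Production systems normalise answers more carefully and run generated code in a sandbox rather than with `exec`.

```python
import re

def extract_boxed(text):
    """Return the content of the last \\boxed{...} in a completion (no nested braces)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def math_reward(completion, ground_truth):
    """1.0 if the final boxed answer matches the reference exactly, else 0.0."""
    answer = extract_boxed(completion)
    return 1.0 if answer is not None and answer == ground_truth.strip() else 0.0

def code_reward(source, func_name, test_cases):
    """1.0 if the generated function passes every (args, expected) test case, else 0.0."""
    namespace = {}
    try:
        exec(source, namespace)  # toy only: real systems run this in a sandbox
        func = namespace[func_name]
        return 1.0 if all(func(*args) == expected for args, expected in test_cases) else 0.0
    except Exception:
        return 0.0  # syntax errors, crashes, or a missing function earn nothing

print(math_reward(r"... so the answer is \boxed{42}", "42"))
print(code_reward("def add(a, b):\n    return a + b", "add", [((1, 2), 3), ((0, 0), 0)]))
```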
Qwen3's hybrid thinking mode integrates reasoning (step-by-step chain-of-thought wrapped in `<think>` tags) and fast direct responses into a single model. Users can control a "thinking budget" that trades latency for reasoning depth.
This is achieved by training on both thinking and non-thinking data, not through separate model checkpoints.
Training stabilisation at scale requires new techniques beyond standard practices. Logit soft-capping (Gemma 2) passes logits through $s \cdot \tanh(\text{logits} / s)$ to prevent unbounded growth. Gemma 2 uses a soft cap $s$ of 50 for attention logits and 30 for the final output logits.
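The cap behaves like the identity for small logits and saturates smoothly for large ones, as a quick check with a cap of 50 (chosen for illustration) shows:

```python
import math

def soft_cap(logit, cap):
    """Smoothly squash a logit so its magnitude never exceeds cap."""
    return cap * math.tanh(logit / cap)

for z in (1.0, 10.0, 50.0, 200.0, 1000.0):
    print(f"{z:7.1f} -> {soft_cap(z, 50.0):6.2f}")
```

Unlike hard clipping, the gradient never becomes exactly zero in practice, so very large logits are still nudged back rather than frozen.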
QK-Norm (Qwen3) applies RMSNorm to query and key vectors before computing attention scores, replacing the need for QKV bias. QK-Clip, part of Kimi K2's MuonClip optimiser, monitors the maximum attention logit during training. Whenever that logit exceeds a threshold, it rescales the query and key projection weights. This enabled stable pre-training of a 1T-parameter model without loss spikes.
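Attention logits are bilinear in the query and key projections. Scaling each of $W_q$ and $W_k$ by $\sqrt{\tau / S_{\max}}$ therefore scales every logit by $\tau / S_{\max}$, which brings the maximum back to exactly $\tau$. The single-head sketch below illustrates this under several simplifying assumptions:

- an illustrative threshold $\tau = 100$
- 2-dimensional toy weights
- no $1/\sqrt{d}$ scaling
- no per-head bookkeeping

```python
import math

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def max_logit(W_q, W_k, xs):
    """Largest raw attention score q_i . k_j over all token pairs (one head)."""
    qs = [matvec(W_q, x) for x in xs]
    ks = [matvec(W_k, x) for x in xs]
    return max(sum(a * b for a, b in zip(q, k)) for q in qs for k in ks)

def qk_clip(W_q, W_k, s_max, tau=100.0):
    """If the observed max logit exceeds tau, shrink W_q and W_k by sqrt(tau / s_max) each."""
    if s_max <= tau:
        return W_q, W_k
    factor = math.sqrt(tau / s_max)

    def scale(W):
        return [[factor * w for w in row] for row in W]

    return scale(W_q), scale(W_k)

W_q = [[10.0, 0.0], [0.0, 10.0]]
W_k = [[10.0, 0.0], [0.0, 10.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy token representations
s_max = max_logit(W_q, W_k, xs)
W_q, W_k = qk_clip(W_q, W_k, s_max)
print(f"max logit before: {s_max:.1f}, after: {max_logit(W_q, W_k, xs):.1f}")
```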
FP8 mixed-precision training (DeepSeek-V3) uses 8-bit floating point for the compute-intensive matrix multiplications in the forward and backward passes while keeping master weights in higher precision.
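FP8 tensor cores offer roughly twice the peak matrix-multiply throughput of BF16 on Hopper-class GPUs such as the H800. DeepSeek-V3 reports that its FP8 training stays within a relative loss error of 0.25% of a BF16 baseline. It trained its 671B-parameter model for only about 2.8M H800 GPU-hours, a fraction of what comparable models are reported to use. This was largely due to FP8 and other engineering optimisations.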
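FP8 rounding can be simulated in plain Python to see what the format does to values. The sketch assumes the E4M3 variant (3 mantissa bits, largest finite value 448). It rescales each block so its largest magnitude maps to 448, in the spirit of fine-grained block-wise scaling. The block contents are arbitrary, and real training does this in hardware on GPU tensor cores.

```python
import math

E4M3_MAX = 448.0     # largest finite E4M3 value
MANTISSA_BITS = 3    # explicit mantissa bits in E4M3
MIN_NORMAL_EXP = -6  # smallest normal exponent; below this values are subnormal

def round_to_e4m3(v):
    """Round v to the nearest E4M3 value (round-half-to-even, saturating at +-448)."""
    if v == 0.0:
        return 0.0
    mag = min(abs(v), E4M3_MAX)
    _, e = math.frexp(mag)                  # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, MIN_NORMAL_EXP)        # floor(log2(mag)), clamped for subnormals
    quantum = 2.0 ** (exp - MANTISSA_BITS)  # gap between neighbouring E4M3 values
    return math.copysign(round(mag / quantum) * quantum, v)

def quantize_block(xs):
    """Scale a block so its largest |x| maps to E4M3_MAX, then round each value."""
    amax = max(abs(x) for x in xs)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(x * scale) for x in xs], scale

def dequantize_block(codes, scale):
    return [c / scale for c in codes]

block = [0.013, -0.2, 0.5, 1.7, -3.3]
codes, scale = quantize_block(block)
restored = dequantize_block(codes, scale)
for x, r in zip(block, restored):
    print(f"{x:8.4f} -> {r:8.4f}  (relative error {abs(r - x) / abs(x):.2%})")
```

With 3 mantissa bits, every normal value is reproduced within a relative error of $2^{-4}$. This is why per-block scaling matters: it keeps small values out of the coarse subnormal range.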
Coding Tasks (use Colab or a notebook)
- Implement a simple retrieval-augmented generation pipeline from scratch. Index a set of documents using TF-IDF (file 02), retrieve the most relevant passage for a query, and prepend it to a prompt.
```python
import re
import math
from collections import Counter
import jax.numpy as jnp

# Knowledge base: a set of short passages
knowledge_base = [
    "The Eiffel Tower is a wrought-iron lattice tower in Paris, France. It was constructed from 1887 to 1889 as the centerpiece of the 1889 World's Fair.",
    "The Great Wall of China is a series of fortifications built along the northern borders of China. Construction began in the 7th century BC.",
    "Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen using chlorophyll.",
    "The theory of general relativity, published by Albert Einstein in 1915, describes gravity as the curvature of spacetime caused by mass and energy.",
    "Python is a high-level programming language known for its simple syntax and readability. It was created by Guido van Rossum and released in 1991.",
    "The mitochondria are organelles found in eukaryotic cells. They generate most of the cell's supply of ATP, used as a source of chemical energy.",
]

# Build TF-IDF index (reusing concepts from file 02)
def tokenise(text):
    # Lowercase and keep alphanumeric runs, so "Tower?" and "tower" match
    return re.findall(r"[a-z0-9]+", text.lower())

vocab = sorted(set(w for doc in knowledge_base for w in tokenise(doc)))
word2idx = {w: i for i, w in enumerate(vocab)}
V = len(vocab)
N = len(knowledge_base)

# Document frequencies: number of passages containing each word
doc_freq = Counter()
for doc in knowledge_base:
    for w in set(tokenise(doc)):
        doc_freq[w] += 1

def tfidf_vector(text):
    counts = Counter(tokenise(text))
    vec = jnp.zeros(V)
    for w, c in counts.items():
        if w in word2idx:  # ignore out-of-vocabulary query words
            tf = 1 + math.log(c)                          # sublinear term frequency
            idf = math.log(N / (doc_freq.get(w, 0) + 1))  # smoothed inverse document frequency
            vec = vec.at[word2idx[w]].set(tf * idf)
    return vec

# Index all documents: one TF-IDF row per passage
doc_vectors = jnp.stack([tfidf_vector(doc) for doc in knowledge_base])

def cosine_sim(a, b):
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8)

def retrieve(query, top_k=2):
    """Retrieve top-k most relevant passages for a query."""
    q_vec = tfidf_vector(query)
    sims = jnp.array([cosine_sim(q_vec, doc_vectors[i]) for i in range(N)])
    top_indices = jnp.argsort(-sims)[:top_k]
    return [(int(i), float(sims[i]), knowledge_base[int(i)]) for i in top_indices]

# Test retrieval
queries = [
    "Who built the Eiffel Tower?",
    "How do plants make food?",
    "What did Einstein discover?",
]
for query in queries:
    results = retrieve(query, top_k=1)
    print(f"\nQuery: '{query}'")
    for idx, sim, passage in results:
        print(f"  Retrieved (sim={sim:.3f}): '{passage[:80]}...'")
    # RAG-style prompt construction: prepend the best passage as context
    context = results[0][2]
    rag_prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    print(f"  RAG prompt:\n  {rag_prompt[:120]}...")
```
- Implement speculative decoding with a toy draft and target model. Show that the accepted output matches the target model's distribution.
```python
import jax
import jax.numpy as jnp

# Toy draft model (fast, less accurate) and target model (slow, accurate).
# Both are fixed bigram tables: next-token logits depend only on the previous token,
# so each model defines a well-defined distribution.
vocab_size = 8
k_target, k_draft = jax.random.split(jax.random.PRNGKey(42))
# Both models favour token (prev + 1) % vocab_size; the draft is noisier and less peaked
next_tok = jax.nn.one_hot((jnp.arange(vocab_size) + 1) % vocab_size, vocab_size)
target_table = jax.random.normal(k_target, (vocab_size, vocab_size)) * 2 + 3.0 * next_tok
draft_table = jax.random.normal(k_draft, (vocab_size, vocab_size)) + 2.0 * next_tok

def target_model(seq):
    """Next-token logits at every position (one 'expensive' pass over the sequence)."""
    # In practice this would be a large Transformer forward pass
    return target_table[seq]

def draft_model(seq):
    """Cheap approximation of the target model."""
    return draft_table[seq]

def speculative_decode(prefix, draft_steps=3, rounds=4, key=jax.random.PRNGKey(0)):
    """Speculative decoding: the draft proposes, the target verifies."""
    seq = list(prefix)
    total_accepted = 0
    total_proposed = 0
    for _ in range(rounds):
        # 1. Draft model proposes draft_steps tokens autoregressively
        draft_tokens, draft_probs = [], []
        draft_seq = list(seq)
        for _ in range(draft_steps):
            key, sample_key = jax.random.split(key)
            d_logits = draft_model(jnp.array(draft_seq))[-1]
            tok = int(jax.random.categorical(sample_key, d_logits))
            draft_tokens.append(tok)
            draft_probs.append(jax.nn.softmax(d_logits))
            draft_seq.append(tok)
        # 2. Target model scores all draft positions in one pass
        target_logits = target_model(jnp.array(draft_seq))
        start = len(seq) - 1  # logits at position t predict token t + 1
        # 3. Accept/reject each draft token left to right
        accepted = 0
        for i in range(draft_steps):
            t_probs = jax.nn.softmax(target_logits[start + i])
            p = t_probs[draft_tokens[i]]         # target probability of the draft token
            q = draft_probs[i][draft_tokens[i]]  # draft probability of the draft token
            key, accept_key = jax.random.split(key)
            # Accept with probability min(1, p / q)
            if jax.random.uniform(accept_key) < jnp.minimum(1.0, p / q):
                seq.append(draft_tokens[i])
                accepted += 1
            else:
                # Reject: resample from the normalised residual max(0, p - q)
                key, resample_key = jax.random.split(key)
                residual = jnp.maximum(0.0, t_probs - draft_probs[i])
                residual = residual / (residual.sum() + 1e-10)
                new_tok = jax.random.categorical(resample_key, jnp.log(residual + 1e-10))
                seq.append(int(new_tok))
                break
        else:
            # All drafts accepted: the same target pass yields one bonus token
            key, bonus_key = jax.random.split(key)
            seq.append(int(jax.random.categorical(bonus_key, target_logits[-1])))
        total_accepted += accepted
        total_proposed += draft_steps
    return seq, total_accepted, total_proposed

# Run speculative decoding
prefix = [0, 1]
rounds = 4
result_seq, accepted, proposed = speculative_decode(prefix, rounds=rounds)
acceptance_rate = accepted / proposed if proposed > 0 else 0
print(f"Prefix: {prefix}")
print(f"Generated sequence: {result_seq}")
print(f"Draft proposals: {proposed}")
print(f"Accepted: {accepted}")
print(f"Acceptance rate: {acceptance_rate:.1%}")
# One target pass per round; each round emits (accepted drafts + 1) tokens.
# This ignores the (small) cost of running the draft model.
tokens_per_pass = (len(result_seq) - len(prefix)) / rounds
print(f"Tokens per target pass: {tokens_per_pass:.2f} (plain decoding: 1.00)")
```
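The task's claim can be checked empirically by isolating a single draft token with fixed toy distributions:

1. Sample a token from the draft $q$.
2. Accept it with probability $\min(1, p/q)$.
3. Otherwise, resample from the normalised residual $\max(0, p - q)$.
4. Compare the resulting token frequencies with the target $p$.

This sketch uses Python's `random` module rather than JAX to keep the loop fast, and `p` and `q` are made-up distributions.

```python
import random

def sample(dist, rng):
    """Draw an index from a discrete distribution given as a list of probabilities."""
    u, cum = rng.random(), 0.0
    for i, prob in enumerate(dist):
        cum += prob
        if u < cum:
            return i
    return len(dist) - 1  # guard against rounding in the cumulative sum

def speculative_step(p, q, rng):
    """One draft-then-verify step; the returned token is distributed exactly as p."""
    x = sample(q, rng)
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    return sample([r / total for r in residual], rng)

p = [0.5, 0.3, 0.15, 0.05]    # target model's next-token distribution (toy)
q = [0.25, 0.25, 0.25, 0.25]  # draft model's distribution (toy)
rng = random.Random(0)
n = 100_000
counts = [0] * len(p)
for _ in range(n):
    counts[speculative_step(p, q, rng)] += 1
empirical = [c / n for c in counts]
tv = 0.5 * sum(abs(a - b) for a, b in zip(empirical, p))
print("empirical:", [round(e, 3) for e in empirical])
print(f"total variation distance to p: {tv:.4f}")
```

The total variation distance shrinks toward zero as `n` grows. The acceptance rate equals $\sum_x \min(p(x), q(x))$, which is 0.7 for these toy distributions.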
- Build a simple DPO training loop. Given pairs of preferred and dispreferred completions, update a small model using the DPO loss.
```python
import jax
import jax.numpy as jnp

# Tiny language model: a bigram table (row = previous token, columns = next-token logits)
vocab_size = 10
key = jax.random.PRNGKey(42)
k1, _ = jax.random.split(key)
# Current policy parameters (trainable)
theta = jax.random.normal(k1, (vocab_size, vocab_size)) * 0.1
# Reference policy parameters (frozen copy of initial theta)
theta_ref = theta.copy()

def log_prob_sequence(params, sequence):
    """Compute log P(sequence) under a simple autoregressive (bigram) model."""
    total = 0.0
    for t in range(1, len(sequence)):
        # Logits at position t depend only on the token at t-1
        log_probs = jax.nn.log_softmax(params[sequence[t - 1]])
        total += log_probs[sequence[t]]
    return total

def dpo_loss(theta, theta_ref, preferred, dispreferred, beta=0.1):
    """Direct Preference Optimisation loss for one (preferred, dispreferred) pair."""
    log_pi_w = log_prob_sequence(theta, preferred)
    log_pi_l = log_prob_sequence(theta, dispreferred)
    log_ref_w = log_prob_sequence(theta_ref, preferred)
    log_ref_l = log_prob_sequence(theta_ref, dispreferred)
    # DPO objective: -log sigmoid(beta * (policy margin - reference margin))
    return -jax.nn.log_sigmoid(
        beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))
    )

# Preference dataset: (preferred_sequence, dispreferred_sequence), sharing a 3-token prompt
preferences = [
    (jnp.array([1, 3, 5, 7]), jnp.array([1, 3, 5, 2])),  # prefer 7 over 2 at end
    (jnp.array([0, 2, 4, 6]), jnp.array([0, 2, 4, 9])),  # prefer 6 over 9
    (jnp.array([3, 3, 3, 3]), jnp.array([3, 3, 3, 0])),  # prefer repeating over 0
    (jnp.array([5, 6, 7, 8]), jnp.array([5, 6, 7, 1])),  # prefer 8 over 1
]

# value_and_grad returns the loss and its gradient w.r.t. theta in one pass
loss_and_grad = jax.jit(jax.value_and_grad(dpo_loss))
lr = 0.05
print("Training DPO...")
for epoch in range(100):
    total_loss = 0.0
    for preferred, dispreferred in preferences:
        loss, grads = loss_and_grad(theta, theta_ref, preferred, dispreferred)
        theta = theta - lr * grads
        total_loss += float(loss)
    if (epoch + 1) % 20 == 0:
        avg_loss = total_loss / len(preferences)
        print(f"  Epoch {epoch+1}: avg DPO loss = {avg_loss:.4f}")

# Check: the model should now prefer the preferred completions
print("\nPreference check after DPO training:")
for preferred, dispreferred in preferences:
    lp_w = float(log_prob_sequence(theta, preferred))
    lp_l = float(log_prob_sequence(theta, dispreferred))
    print(f"  Preferred {preferred.tolist()}: logP={lp_w:.3f}  "
          f"Dispreferred {dispreferred.tolist()}: logP={lp_l:.3f}  "
          f"{'correct' if lp_w > lp_l else 'WRONG'}")
```
Image Fundamentals
- Digital images: pixels, colour spaces (RGB, HSV, YCbCr, LAB), bit depth
- Image formation: pinhole camera model, lens distortion, intrinsic and extrinsic parameters
- Spatial filtering: convolution (2D), kernels, edge detection (Sobel, Canny), blurring (Gaussian, median)
- Frequency domain: Fourier transform of images, low-pass and high-pass filtering
- Histograms, histogram equalisation, thresholding (Otsu)
- Feature extraction: corners (Harris, Shi-Tomasi), blob and keypoint detectors/descriptors (SIFT, SURF, ORB), HOG descriptors
- Image pyramids and scale space
Convolutional Networks
- Convolution operation: filters, stride, padding, receptive field
- Pooling: max pooling, average pooling, global average pooling
- Batch normalisation, dropout, data augmentation
- Landmark architectures: LeNet, AlexNet, VGG, GoogLeNet/Inception, ResNet (skip connections), DenseNet
- Efficient architectures: MobileNet (depthwise separable convolutions), EfficientNet (compound scaling), ShuffleNet
- Transfer learning: feature extraction, fine-tuning pretrained backbones
- Visualising CNNs: activation maps, Grad-CAM, feature inversion
Object Detection and Segmentation
- Object detection problem: bounding boxes, IoU, mAP
- Two-stage detectors: R-CNN, Fast R-CNN, Faster R-CNN (region proposal networks), anchor boxes
- One-stage detectors: YOLO family, SSD, RetinaNet (focal loss)
- Anchor-free detection: FCOS, CenterNet, CornerNet
- Semantic segmentation: FCN, U-Net (skip connections), DeepLab (atrous/dilated convolutions, CRF)
- Instance segmentation: Mask R-CNN
- Panoptic segmentation: unifying semantic and instance segmentation
- Real-time segmentation: BiSeNet, DDRNet
Vision Transformers and Generation
- Vision Transformer (ViT): patch embedding, class token, position embeddings
- ViT variants and hierarchical backbones: DeiT, Swin Transformer (shifted windows, hierarchical), PVT
- Self-supervised visual learning: contrastive (SimCLR, MoCo, BYOL, DINO), masked image modelling (MAE, BEiT)
- Image generation: GANs (generator, discriminator, mode collapse, training tricks, StyleGAN), VAEs
- Diffusion models: forward/reverse process, DDPM, DDIM, score-based models, classifier-free guidance, latent diffusion (Stable Diffusion)
- Flow matching: continuous normalising flows, optimal transport, rectified flows
Video and 3D Vision
- Video understanding: temporal modelling, optical flow (Lucas-Kanade, Farneback), two-stream networks
- Video architectures: 3D CNNs (C3D, I3D), TimeSformer, VideoMAE, SlowFast networks
- Action recognition and temporal action detection
- Video object tracking: SORT, DeepSORT, ByteTrack
- 3D vision: depth estimation (monocular, stereo), point clouds, NeRFs, 3D Gaussian splatting
- SLAM: visual odometry, feature-based SLAM, ORB-SLAM, LiDAR SLAM
- VR/AR: pose estimation, scene reconstruction, real-time rendering considerations
Digital Signal Processing
- Sound as a signal: waveforms, amplitude, frequency, phase
- Sampling: Nyquist theorem, sample rate, aliasing, quantisation
- Time-domain analysis: energy, zero-crossing rate, autocorrelation
- Frequency-domain analysis: DFT, FFT, spectrograms, mel scale
- Mel-frequency cepstral coefficients (MFCCs): derivation and intuition
- Filtering: FIR and IIR filters, bandpass, low-pass, high-pass
- Windowing: Hamming, Hanning, rectangular, overlap-add
- Short-time Fourier transform (STFT) and time-frequency tradeoff
Automatic Speech Recognition
- ASR pipeline: audio → features → acoustic model → decoder → text
- Traditional ASR: GMM-HMM, WFST decoding
- End-to-end ASR: CTC loss and decoding, RNN-Transducer (RNN-T)
- Attention-based ASR: Listen, Attend and Spell (LAS)
- Modern architectures: Whisper, Conformer (convolution + attention), wav2vec 2.0 (self-supervised pretraining)
- Language model integration: shallow fusion, deep fusion, rescoring
- Streaming vs offline ASR: chunked attention, lookahead, latency constraints
- Evaluation: word error rate (WER), character error rate (CER)
Text to Speech and Voice
- TTS pipeline: text normalisation → phoneme conversion → acoustic model → vocoder
- Vocoders: WaveNet, WaveRNN, WaveGlow, HiFi-GAN, neural source-filter models
- Acoustic models: Tacotron 1/2, FastSpeech 1/2 (non-autoregressive, duration prediction)
- Modern TTS: VITS (end-to-end), VALL-E (codec language model), StyleTTS
- Prosody modelling: pitch, duration, energy, style embeddings
- Voice conversion: speaker embeddings, disentangled representations
- Voice cloning: few-shot and zero-shot approaches
- Voice activity detection (VAD) and acoustic activity detection
Speaker and Audio Analysis
- Speaker recognition: verification vs identification
- Speaker embeddings: i-vectors, d-vectors, x-vectors (TDNN-based), ECAPA-TDNN
- Speaker diarisation: who spoke when, clustering-based, end-to-end neural diarisation
- Audio classification: environmental sounds, music genre, audio tagging
- Audio event detection: Sound Event Detection (SED), AudioSet
- Acoustic scene classification
- Audio embeddings: VGGish, PANNs, audio spectrogram transformer (AST)
- Music information retrieval: beat tracking, chord recognition, source separation basics
Source Separation and Noise Cancellation
- Cocktail party problem: separating overlapping sources
- Classical methods: ICA, NMF, beamforming (delay-and-sum, MVDR)
- Deep learning methods: deep clustering, Conv-TasNet, DPRNN, SepFormer
- Permutation invariant training (PIT)
- Music source separation: Demucs, Open-Unmix
- Active noise cancellation (ANC): feedforward vs feedback, adaptive filtering (LMS, NLMS)
- Noise reduction and speech enhancement: spectral subtraction, Wiener filtering, neural speech enhancement
- Echo cancellation: acoustic echo cancellation (AEC), double-talk detection
Multimodal Representations
- What is multimodal learning: combining vision, language, audio, and other modalities
- Early vs late fusion: feature-level vs decision-level combination
- Joint embedding spaces: learning shared representations across modalities
- Contrastive learning: CLIP (image-text contrastive), ALIGN, SigLIP
- Loss functions: InfoNCE, NT-Xent, contrastive loss with temperature
- Image-text retrieval: zero-shot classification via embeddings
- Audio-visual correspondence: learning from paired audio and video
- Evaluation: zero-shot benchmarks, retrieval metrics (recall@k)
Vision Language Models
- Visual question answering (VQA): task formulation, datasets
- Image captioning: show-and-tell, attention-based captioning
- Architecture patterns: dual encoder, fusion encoder, encoder-decoder
- Flamingo: interleaving visual and text tokens, few-shot multimodal learning
- LLaVA and visual instruction tuning: projecting vision features into LLM space
- PaLI, Qwen-VL, InternVL: scaling vision-language models
- Grounding and referring: pointing, bounding box prediction from language
- OCR-free document understanding: Donut, Pix2Struct
Image and Video Tokenisation
- Why tokenise images: bridging continuous pixels and discrete language model vocabularies
- VQ-VAE: vector quantisation, codebook learning, commitment loss
- VQ-GAN: combining VQ-VAE with adversarial training for higher fidelity
- Residual quantisation and multi-scale codebooks
- Image tokenisers: DALL-E tokeniser, LlamaGen, Cosmos tokeniser
- Video tokenisation: temporal compression, 3D VQ-VAE, causal video tokenisers
- Continuous vs discrete tokens: when to quantise and when to project
- Applications: autoregressive image generation, unified vision-language tokens
Cross-Modal Generation
- Text-to-image: DALL-E (autoregressive), Stable Diffusion (latent diffusion with a CLIP text encoder), Imagen, Parti
- Text-to-video: Make-A-Video, VideoPoet, Sora-style temporal diffusion, Wan
- Text-to-audio: AudioLM, MusicLM, MusicGen
- Image-to-text generation: captioning as conditional generation
- Video-audio co-generation: joint temporal modelling
- Instruction-following generation: InstructPix2Pix, editing by description
- Consistency and alignment: measuring text-image alignment (CLIPScore), FID, IS
- Ethical considerations: deepfakes, bias in generation, content filtering
Unified Multimodal Architectures
- The case for unification: one model, many modalities, shared weights
- Any-to-any models: CoDi, NExT-GPT, Gemini, GPT-4o architecture patterns
- Modality-specific encoders and decoders with shared transformer backbone
- Multimodal tokenisation: interleaving text, image, audio tokens in one sequence
- Training recipes: staged pretraining, modality-specific warm-up, joint fine-tuning
- Multimodal chain-of-thought reasoning
- Multimodal agents: tool use, grounding actions in visual context
- Benchmarks: MMMU, MMBench, SEED-Bench, multimodal evaluation suites
Perception
- Sensor modalities: cameras (mono, stereo, fisheye), LiDAR (spinning, solid-state), radar, ultrasonic, IMU
- Sensor fusion: early fusion (raw data), late fusion (decision level), multi-sensor calibration
- 3D object detection: PointPillars, CenterPoint, BEVFusion, bird's-eye-view representations
- Depth estimation: stereo matching, monocular depth networks, LiDAR-camera projection
- Occupancy networks: 3D occupancy prediction, voxel representations
- Lane detection and road topology: curve fitting, polynomial models, graph-based topology
- Semantic mapping: building environmental representations from sensor streams
Robot Learning
- Robot kinematics: forward and inverse kinematics, DH parameters, joint spaces
- Dynamics and control: PID control, model predictive control (MPC), impedance control
- Imitation learning: behavioural cloning, DAgger, learning from demonstrations
- Sim-to-real transfer: domain randomisation, system identification, reality gap
- Reward shaping and curriculum learning for robotics
- Manipulation: grasping (analytical, data-driven), dexterous manipulation, contact-rich tasks
- Locomotion: legged robots, quadrupeds, humanoid balance, CPG-based control
- Safety: safe exploration, constrained RL, risk-aware planning
Vision-Language-Action Models
- From vision-language to action: grounding language instructions in physical actions
- VLAs: architecture (vision encoder + LLM + action head), RT-2, Octo, OpenVLA
- Action tokenisation: discretising continuous actions, action chunking
- Pretraining recipes: web-scale vision-language data → robot manipulation data
- Generalisation: unseen objects, environments, instructions
- Co-training with internet data and robot data
- Embodiment-agnostic models: one model for multiple robot form factors
- Benchmarks: SIMPLER, real-world evaluation protocols
Self-Driving Cars
- Autonomous driving stack: perception → prediction → planning → control
- HD maps vs mapless driving: pre-built maps, online map construction
- Motion prediction: trajectory forecasting, social forces, graph neural networks for agent interaction
- Planning: rule-based planners, optimisation-based (trajectory optimisation), learning-based (neural planners)
- End-to-end driving: UniAD, from sensor inputs directly to control outputs
- Simulation: CARLA, nuPlan, closed-loop vs open-loop evaluation
- Safety: functional safety (ISO 26262), SOTIF, operational design domain (ODD)
- Levels of autonomy: SAE L1–L5, current industry state
Space and Extreme Robotics
- Space robotics: orbital servicing, planetary rovers (Mars rover autonomy), satellite inspection
- Communication constraints: high latency, limited bandwidth, onboard autonomy
- Radiation-hardened computing: constraints on hardware, model compression for space
- Autonomous navigation in unstructured terrain: visual-inertial odometry, hazard avoidance
- Underwater robotics: AUVs, ROVs, acoustic communication, SLAM in low-visibility
- Search and rescue robotics: disaster response, multi-robot coordination
- Swarm robotics: decentralised control, emergent behaviour, consensus algorithms
- Human-robot interaction: shared autonomy, teleoperation, trust calibration
Discrete Maths
- Logic: propositional logic, truth tables, logical equivalences, predicate logic, quantifiers
- Proofs: direct proof, proof by contradiction, proof by induction, pigeonhole principle
- Sets: operations (union, intersection, complement, Cartesian product), power sets, cardinality
- Relations: equivalence relations, partial orders, total orders
- Functions: injective, surjective, bijective, composition, inverse
- Combinatorics: permutations, combinations, binomial theorem, inclusion-exclusion
- Graph theory: vertices, edges, paths, cycles, trees, planarity, colouring, Euler and Hamiltonian paths
- Recurrence relations and generating functions
Computer Architecture
- Number systems: binary, hexadecimal, two's complement, IEEE 754 floating point
- Logic gates: AND, OR, NOT, NAND, XOR, multiplexers, adders
- CPU architecture: ALU, registers, program counter, instruction cycle (fetch-decode-execute)
- Instruction set architectures: CISC vs RISC, x86, ARM, RISC-V
- Pipelining: stages, hazards (data, control, structural), forwarding, branch prediction
- Memory hierarchy: registers → L1/L2/L3 cache → RAM → disk, cache associativity, cache coherence
- Virtual memory: page tables, TLB, page faults, address translation
- Bus architecture, I/O, interrupts, DMA
Operating Systems
- What an OS does: abstraction, resource management, isolation
- Processes: creation (fork/exec), process states, PCB, context switching
- Threads: kernel threads vs user threads, pthreads, thread pools
- Scheduling: FCFS, SJF, round robin, priority scheduling, multilevel feedback queues, CFS (Linux)
- Memory management: paging, segmentation, demand paging, page replacement (LRU, clock)
- File systems: inodes, FAT, ext4, journaling, B-tree based file systems
- I/O subsystem: buffering, spooling, device drivers
- System calls, user mode vs kernel mode, interrupts and traps
Concurrency and Parallelism
- Concurrency vs parallelism: interleaving vs simultaneous execution
- Synchronisation primitives: mutexes, semaphores, condition variables, monitors
- Classic problems: producer-consumer, readers-writers, dining philosophers
- Deadlock: conditions (mutual exclusion, hold-and-wait, no preemption, circular wait), detection, prevention, avoidance (banker's algorithm)
- Lock-free and wait-free data structures: CAS operations, atomic variables
- Parallel programming models: shared memory (OpenMP), message passing (MPI)
- Async and event-driven: event loops, coroutines, async/await
- Amdahl's law, Gustafson's law, scalability limits
Programming Languages
- Language paradigms: imperative, object-oriented, functional, logic
- Type systems: static vs dynamic, strong vs weak, type inference, generics
- Memory management: stack vs heap, manual (C/C++), garbage collection (tracing, reference counting), ownership (Rust borrow checker)
- Compilation: lexing, parsing (ASTs), semantic analysis, code generation, LLVM
- Interpretation: bytecode VMs (JVM, CPython), JIT compilation
- Key language features: closures, pattern matching, algebraic data types, traits/interfaces
- Domain-specific languages: SQL, regex, shader languages
- Language design tradeoffs: performance vs safety vs expressiveness
Arrays and Hashing
- Arrays: contiguous memory, indexing, dynamic arrays (amortised doubling), cache locality
- Strings: encoding (ASCII, UTF-8), string matching (KMP, Rabin-Karp, Boyer-Moore)
- Hash tables: hash functions, collision resolution (chaining, open addressing, linear probing, robin hood hashing)
- Hash maps and hash sets: average vs worst-case complexity, load factor, rehashing
- Bloom filters: probabilistic membership, false positive rate, applications
- Two pointers technique, sliding window
- Prefix sums and difference arrays
Linked Lists, Stacks, and Queues
- Singly linked lists: insertion, deletion, traversal, reversal
- Doubly linked lists: bidirectional traversal, sentinel nodes
- Circular linked lists
- Skip lists: probabilistic balancing, expected O(log n) search
- Stacks: LIFO, array-based and linked-list-based implementations
- Applications of stacks: function call stack, expression evaluation, parenthesis matching, monotonic stack
- Queues: FIFO, circular buffer, deque (double-ended queue)
- Priority queues and binary heaps: insert, extract-min, heapify, heap sort
Trees
- Binary trees: traversals (inorder, preorder, postorder, level-order), height, depth
- Binary search trees (BST): search, insert, delete, successor/predecessor
- Balanced BSTs: AVL trees (rotations), red-black trees (colour invariants)
- B-trees and B+ trees: disk-friendly, database indexing, order and fill factor
- Tries: prefix trees, autocomplete, word search
- Segment trees: range queries, lazy propagation
- Fenwick trees (Binary Indexed Trees): prefix sums, point updates
- Union-Find (Disjoint Set Union): path compression, union by rank
Graphs
- Representations: adjacency matrix, adjacency list, edge list, incidence matrix
- Traversals: BFS (shortest path in unweighted graphs), DFS (cycle detection, topological sort)
- Shortest paths: Dijkstra (non-negative weights), Bellman-Ford (negative weights, cycle detection), Floyd-Warshall (all pairs)
- Minimum spanning trees: Kruskal (union-find), Prim (priority queue)
- Topological sort: Kahn's algorithm (BFS), DFS-based
- Strongly connected components: Tarjan's algorithm, Kosaraju's algorithm
- Network flow: Ford-Fulkerson, Edmonds-Karp, max-flow min-cut theorem
- Bipartite matching: Hungarian algorithm, Hopcroft-Karp
Sorting and Search
- Comparison sorts: bubble sort, insertion sort, merge sort, quicksort (pivot strategies, Hoare vs Lomuto partition), heapsort
- Lower bound for comparison sorting: Ω(n log n) via decision trees
- Non-comparison sorts: counting sort, radix sort, bucket sort
- Binary search: standard, lower/upper bound, search on answer (monotonic functions)
- Divide and conquer: master theorem, merge sort analysis, closest pair of points
- Greedy algorithms: activity selection, Huffman coding, interval scheduling
- Dynamic programming: overlapping subproblems, optimal substructure, memoisation vs tabulation, classic problems (knapsack, LCS, edit distance, coin change)
- Backtracking: N-queens, Sudoku, constraint satisfaction
Hardware Fundamentals
- Moore's law and the end of frequency scaling: why parallelism matters
- CPU architecture recap: superscalar execution, out-of-order execution, branch prediction, speculative execution
- SIMD concept: single instruction, multiple data, data-level parallelism
- Vector registers and vector width: 128-bit, 256-bit, 512-bit
- Memory bandwidth vs compute: roofline model, arithmetic intensity
- Latency vs throughput: pipelining, instruction-level parallelism
- Chip families overview: x86 (Intel, AMD), ARM, RISC-V, Apple Silicon
- Thermal and power constraints: TDP, power efficiency, dark silicon
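The roofline model in the list above reduces to one formula: attainable performance = min(peak compute, arithmetic intensity × memory bandwidth). The sketch below evaluates it for two kernels on a hypothetical chip (100 TFLOP/s peak, 2 TB/s bandwidth; both numbers are made up for illustration). The matmul byte count assumes ideal reuse, with each FP32 matrix moved to or from memory exactly once, so it is a best-case lower bound on traffic.

```python
def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte moved to or from main memory."""
    return flops / bytes_moved


def attainable_flops(ai: float, peak_flops: float, bandwidth: float) -> float:
    """Roofline: performance is capped by compute or by memory bandwidth."""
    return min(peak_flops, ai * bandwidth)


PEAK = 100e12        # hypothetical peak compute, FLOP/s
BW = 2e12            # hypothetical memory bandwidth, bytes/s
ridge = PEAK / BW    # intensity (FLOP/byte) above which a kernel is compute-bound

n = 1 << 20
# c = a + b in FP32: 1 FLOP per element, 12 bytes (two reads, one write).
vec_add_ai = arithmetic_intensity(n, 3 * 4 * n)

m = 4096
# C = A @ B in FP32: 2*m^3 FLOPs; assumes each m x m matrix is moved exactly once.
matmul_ai = arithmetic_intensity(2 * m**3, 3 * 4 * m**2)

for name, ai in [("vector add", vec_add_ai), ("matmul", matmul_ai)]:
    perf = attainable_flops(ai, PEAK, BW)
    bound = "memory" if ai < ridge else "compute"
    print(f"{name}: AI={ai:.3f} FLOP/B, {perf / 1e12:.2f} TFLOP/s, {bound}-bound")
```

Under these assumptions the vector add reaches well under 1% of peak however it is vectorised, whereas a large matmul can in principle saturate the compute units; this is why SIMD width alone does not determine speed.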
ARM and NEON
- ARM architecture: load-store ISA, register file, condition codes, Thumb mode
- ARM NEON: 128-bit SIMD, data types (int8, int16, float16, float32), register layout
- NEON intrinsics: load/store (vld1, vst1), arithmetic (vadd, vmul, vmla), shuffle and permute
- SVE and SVE2: scalable vector extensions, predicate registers, vector-length agnostic programming
- Apple Silicon specifics: AMX (Apple Matrix eXtensions), performance cores vs efficiency cores
- Practical examples: vectorised dot product, matrix multiply, image processing kernels
- Auto-vectorisation: compiler flags, pragmas, loop patterns that help/hinder vectorisation
x86 and AVX
- x86 SIMD evolution: MMX → SSE → SSE2/3/4 → AVX → AVX2 → AVX-512 → AMX
- AVX/AVX2 programming: 256-bit YMM registers, intrinsics (`_mm256_*`), FMA instructions
- AVX-512: 512-bit ZMM registers, mask registers, gather/scatter, conflict detection
- Intel AMX: tile registers, TMUL (tile matrix multiply), BF16/INT8 acceleration
- Memory alignment: aligned vs unaligned loads, cache line considerations
- Performance pitfalls: AVX frequency throttling, register pressure, lane-crossing penalties
- Benchmarking and profiling: RDTSC, perf, VTune, likwid
GPU Architecture and CUDA
- GPU vs CPU: throughput-oriented design, thousands of cores, SIMT execution model
- GPU memory hierarchy: global memory, shared memory, registers, L1/L2 cache, constant memory
- CUDA programming model: grids, blocks, threads, warps (32 threads), warp divergence
- Kernel launch: grid/block dimensions, occupancy, register usage
- Memory access patterns: coalesced access, bank conflicts in shared memory, memory fences
- Synchronisation: __syncthreads, atomic operations, cooperative groups
- Streams and concurrency: overlapping compute and data transfer, multi-stream execution
- Profiling: Nsight Compute, Nsight Systems, occupancy calculator
- NVIDIA GPU generations: Volta (tensor cores), Ampere (TF32, structured sparsity), Hopper (Transformer Engine, FP8), Blackwell (FP4)
Triton, TPUs, and Vulkan
- Triton: Python-based GPU kernel programming, block-level abstraction, auto-tuning
- Writing Triton kernels: tl.load, tl.store, tl.dot, masking, grid/block programs
- Triton vs CUDA: productivity vs control tradeoff, when to use each
- Flash Attention as a case study: memory-efficient attention via tiling, online softmax
- TPU architecture: systolic arrays, MXU (matrix multiply unit), HBM, ICI interconnect
- TPU programming: XLA compiler, JAX/pjit, GSPMD, sharding annotations
- Vulkan compute: compute shaders, SPIR-V, descriptor sets, command buffers
- Comparison: GPU (CUDA/Triton) vs TPU (JAX/XLA) vs Vulkan, choosing the right tool
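The online-softmax idea behind the Flash Attention case study can be shown without a GPU. The sketch below computes a softmax-weighted sum of values one block at a time, keeping only a running max, a running normaliser and a running accumulator, and rescaling the partial sums whenever the max grows. It is a simplified, single-query illustration with scalar values standing in for value vectors, not FlashAttention's actual kernel; the function names are hypothetical.

```python
import math
from typing import Sequence


def softmax_weighted_sum(scores: Sequence[float], values: Sequence[float],
                         block: int = 2) -> float:
    """sum_i softmax(scores)_i * values_i, computed in one pass over blocks."""
    if not scores:
        raise ValueError("need at least one score")
    m = float("-inf")  # running max of the scores seen so far
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running sum of exp(score - m) * value
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescales old partial sums (0.0 on the first block)
        l = l * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / l


def reference(scores: Sequence[float], values: Sequence[float]) -> float:
    """Standard two-pass softmax for comparison."""
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


scores = [1.0, 3.0, -2.0, 0.5, 2.5]
values = [10.0, -1.0, 4.0, 2.0, 7.0]
print(softmax_weighted_sum(scores, values), reference(scores, values))
```

Because each block only needs the running statistics, a kernel can stream key/value tiles through fast on-chip memory and never materialise the full attention matrix.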
Systems Design Fundamentals
- Client-server architecture, request-response model
- Networking basics: TCP/IP, UDP, HTTP/HTTPS, WebSockets, gRPC, protocol buffers
- DNS: resolution, caching, load balancing via DNS
- Proxies: forward proxy, reverse proxy (Nginx, HAProxy), API gateways
- Load balancing: round robin, least connections, consistent hashing, L4 vs L7
- Caching: cache-aside, write-through, write-back, eviction policies (LRU, LFU, TTL), CDNs
- Databases: SQL vs NoSQL, ACID, CAP theorem, sharding, replication (leader-follower, multi-leader)
- Message queues: Kafka, RabbitMQ, pub/sub, event-driven architecture, exactly-once delivery
- Consistency models: strong, eventual, causal, read-your-writes
- Rate limiting, circuit breakers, backpressure
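Consistent hashing, listed under load balancing above, is easiest to see in code. The sketch below places many virtual nodes per server on a hash ring and routes each key to the next virtual node clockwise, so adding a server only moves the keys that land on its new arcs. The node names, the 100 virtual nodes per server and the choice of MD5 as a stable hash are all illustrative assumptions.

```python
import bisect
import hashlib
from typing import Dict, List


def stable_hash(key: str) -> int:
    # MD5 is used only as a stable, well-spread hash (not for security);
    # Python's built-in hash() for str is randomised per process.
    return int.from_bytes(hashlib.md5(key.encode()).digest()[:8], "big")


class ConsistentHashRing:
    """Hash ring with virtual nodes: each physical node owns many ring points."""

    def __init__(self, nodes: List[str], vnodes: int = 100) -> None:
        self.vnodes = vnodes
        self.ring: List[int] = []        # sorted positions of all virtual nodes
        self.owner: Dict[int, str] = {}  # position -> physical node
        for node in nodes:
            self.add(node)

    def add(self, node: str) -> None:
        for i in range(self.vnodes):
            pos = stable_hash(f"{node}#{i}")
            bisect.insort(self.ring, pos)
            self.owner[pos] = node

    def remove(self, node: str) -> None:
        for i in range(self.vnodes):
            pos = stable_hash(f"{node}#{i}")
            self.ring.remove(pos)
            del self.owner[pos]

    def get(self, key: str) -> str:
        # First virtual node after the key's position, wrapping around the ring.
        idx = bisect.bisect(self.ring, stable_hash(key)) % len(self.ring)
        return self.owner[self.ring[idx]]


ring = ConsistentHashRing(["cache-a", "cache-b", "cache-c"])
keys = [f"user:{i}" for i in range(1000)]
before = {k: ring.get(k) for k in keys}
ring.add("cache-d")
moved = sum(before[k] != ring.get(k) for k in keys)
print(f"{moved} of {len(keys)} keys moved")  # roughly a quarter, not all of them
```

With naive `hash(key) % num_servers` routing, going from three to four servers would remap most keys; here every key that moves goes to the new server, and roughly 1/4 of them move.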
Cloud Computing
- Cloud service models: IaaS, PaaS, SaaS, FaaS (serverless)
- Major providers overview: AWS, GCP, Azure — compute, storage, networking primitives
- Virtualisation: hypervisors (Type 1, Type 2), VMs vs containers
- Containers: Docker (images, layers, Dockerfile), container registries
- Orchestration: Kubernetes (pods, services, deployments, StatefulSets, DaemonSets), Helm
- Storage: block (EBS), object (S3/GCS), file (EFS/NFS), data lakes
- Networking in cloud: VPCs, subnets, security groups, load balancers, service mesh (Istio, Envoy)
- Serverless: Lambda/Cloud Functions, cold starts, event triggers
- Cost management: spot/preemptible instances, reserved capacity, autoscaling policies
- Infrastructure as code: Terraform, CloudFormation, Pulumi
Large Scale Infrastructure
- Scalability: vertical vs horizontal scaling, stateless services
- Distributed systems: consensus (Paxos, Raft), leader election, distributed locks
- Microservices: service decomposition, API contracts, service discovery, saga pattern
- Data pipelines: batch processing (MapReduce, Spark), stream processing (Flink, Kafka Streams)
- Database scaling: read replicas, partitioning strategies (range, hash, directory), cross-shard queries
- Search systems: inverted indices, Elasticsearch, vector search (FAISS, Milvus, Pinecone)
- Observability: logging (ELK), metrics (Prometheus, Grafana), tracing (Jaeger, OpenTelemetry)
- Reliability: SLOs, SLIs, SLAs, error budgets, chaos engineering
- CI/CD: build pipelines, blue-green deployments, canary releases, feature flags
ML Systems Design
- ML system lifecycle: problem framing → data → training → evaluation → deployment → monitoring
- Data management: feature stores, data versioning (DVC), labelling pipelines, data quality checks
- Training infrastructure: distributed training (data parallel, model parallel), experiment tracking (MLflow, W&B)
- Model evaluation: offline metrics, A/B testing, shadow deployment, interleaving experiments
- Model serving: batch vs real-time inference, model registry, model versioning
- Feature engineering: online vs offline features, feature freshness, feature serving latency
- ML pipelines: orchestration (Airflow, Kubeflow, Metaflow), reproducibility
- Monitoring: data drift, concept drift, model degradation, alerting
ML Design Examples
- Recommendation system: candidate generation → ranking → re-ranking, collaborative filtering, content-based, embeddings, two-tower model
- Search ranking: query understanding, retrieval (BM25, dense retrieval), learning to rank (pointwise, pairwise, listwise)
- Ads click prediction: feature engineering (user, ad, context), real-time bidding, calibration, explore-exploit
- Fraud detection system: real-time streaming, feature pipelines, imbalanced classification, human-in-the-loop
- Content moderation: multi-modal classification (text + image), policy-as-code, escalation workflows
- Conversational AI system: intent detection, dialogue management, retrieval-augmented generation, guardrails
- Large-scale image search: embedding extraction, approximate nearest neighbour (ANN), indexing, serving
Quantisation
- Why quantise: memory reduction, throughput gains, energy savings
- Number formats: FP32, FP16, BF16, FP8 (E4M3, E5M2), INT8, INT4, binary/ternary
- Post-training quantisation (PTQ): calibration, min-max, percentile, MSE-optimal scaling
- Quantisation-aware training (QAT): fake quantisation, straight-through estimator
- Weight-only quantisation: GPTQ, AWQ, QuIP, squeeze-and-multiply
- Activation quantisation: dynamic vs static, per-tensor vs per-channel vs per-token
- Mixed-precision: choosing precision per layer, sensitivity analysis
- KV-cache quantisation: reducing memory for long sequences
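To make the min-max calibration item above concrete, here is a minimal sketch of symmetric per-tensor INT8 post-training quantisation with an absmax scale, written in plain Python for clarity. The weights are made-up example values; per-channel quantisation would compute one scale per output channel instead of one for the whole tensor.

```python
from typing import List, Tuple


def quantise_int8(x: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor INT8 quantisation using an absmax (min-max) scale."""
    amax = max(abs(v) for v in x)
    scale = amax / 127 if amax > 0 else 1.0  # maps [-amax, amax] onto [-127, 127]
    q = [max(-127, min(127, round(v / scale))) for v in x]
    return q, scale


def dequantise(q: List[int], scale: float) -> List[float]:
    return [v * scale for v in q]


w = [0.12, -0.5, 0.33, 0.01, -0.27]  # example weights (made up)
q, s = quantise_int8(w)
w_hat = dequantise(q, s)
err = max(abs(a - b) for a, b in zip(w, w_hat))
print(q, s, err)  # rounding error is at most scale / 2 per element
```

A single outlier inflates `amax` and therefore the scale for every element, which is the motivation for the percentile and MSE-optimal calibration variants listed above.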
Efficient Architectures
- StreamingLLM: attention sinks, rolling KV-cache, infinite-length generation
- Sparse attention: local attention, sliding window (Mistral), dilated, BigBird, Longformer
- Linear attention and recurrent alternatives: kernel approximation, RWKV, RetNet, Mamba (state-space models)
- Multi-query attention (MQA) and grouped-query attention (GQA): reducing KV-cache size
- Mixture of Experts at inference: expert caching, routing efficiency
- Knowledge distillation: teacher-student, task-specific vs general distillation
- Pruning: unstructured (magnitude), structured (channel/head pruning), lottery ticket hypothesis
- Neural architecture search (NAS) for efficient models
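The KV-cache saving from MQA and GQA listed above follows from a simple size formula: 2 (keys and values) × layers × KV heads × head dimension × sequence length × batch × bytes per element. The sketch below evaluates it for a hypothetical 32-layer model with head dimension 128 and an FP16 cache; the configuration is illustrative and not tied to a specific released model.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, seq_len: int,
                   batch: int, bytes_per_elem: int = 2) -> int:
    # Factor 2 for keys and values; one entry per layer, KV head, position and sequence.
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical model: 32 layers, 32 query heads of dim 128, FP16 cache.
cfg = dict(layers=32, head_dim=128, seq_len=4096, batch=8)
for name, kv_heads in [("MHA", 32), ("GQA (8 groups)", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(kv_heads=kv_heads, **cfg) / 2**30
    print(f"{name}: {gib:.1f} GiB")
# MHA: 16.0 GiB
# GQA (8 groups): 4.0 GiB
# MQA: 0.5 GiB
```

The cache scales linearly with the number of KV heads, so sharing each KV head across 4 query heads (GQA with 8 groups here) cuts it by 4×, while the query side of attention is unchanged.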
Serving and Batching
- LLM serving fundamentals: prefill vs decode phases, time to first token (TTFT) vs tokens per second
- Continuous batching: dynamic request scheduling, iteration-level batching
- PagedAttention: virtual memory for KV-cache, vLLM architecture
- Batching strategies: static batching, dynamic batching, sequence bucketing
- Scheduling: first-come-first-served, shortest-job-first, preemption
- Disaggregated serving: separating prefill and decode stages
- Multi-model serving: model multiplexing, LoRA serving (S-LoRA, Punica)
- Metrics: throughput (tokens/s), latency (p50/p99), SLO compliance, cost per token
Edge Inference
- Edge constraints: limited memory, power budget, no network dependency
- Model compression pipeline: pruning → quantisation → compilation
- On-device runtimes: TensorFlow Lite, ONNX Runtime, Core ML, TensorRT, ExecuTorch
- Compiler stack: graph optimisation, operator fusion, memory planning, tiling
- Hardware targets: mobile GPUs (Adreno, Mali), NPUs (Qualcomm Hexagon, Apple Neural Engine, Google Edge TPU)
- On-device LLMs: Phi, Gemma, Llama at 1-3B parameter scale, 4-bit inference
- Federated learning: on-device training, privacy-preserving aggregation, communication efficiency
- Latency optimisation: model partitioning, early exit, caching strategies
Scaling and Deployment
- Model parallelism: tensor parallelism (Megatron-style column/row splitting), pipeline parallelism (GPipe, microbatching), sequence parallelism
- Data parallelism at inference: replicating models across GPUs
- Distributed KV-cache: sharding across nodes, communication overhead
- Speculative decoding: draft model + verification, Medusa heads, EAGLE, self-speculative decoding
- Prefix caching: sharing KV-cache across requests with common prefixes
- Inference frameworks: vLLM, TensorRT-LLM, SGLang, llama.cpp, TGI
- Cost optimisation: spot instances, autoscaling, right-sizing GPU selection
- Monitoring: token-level logging, latency histograms, degradation detection
Quantum Machine Learning
- Quantum computing basics: qubits, superposition, entanglement, measurement
- Quantum gates: Pauli (X, Y, Z), Hadamard, CNOT, Toffoli, rotation gates
- Quantum circuits: circuit model, parameterised circuits, depth and width
- Variational quantum algorithms: VQE, QAOA, variational classifiers
- Quantum kernel methods: quantum feature maps, quantum support vector machines
- Quantum neural networks: parameterised quantum circuits as neural layers
- Barren plateaus: vanishing gradients in quantum circuits, expressibility vs trainability
- Quantum advantage debate: NISQ era limitations, fault-tolerant quantum computing timeline
- Hybrid classical-quantum architectures: quantum layers in classical pipelines
Neuromorphic Computing
- Biological inspiration: spiking neurons, synaptic plasticity, temporal coding
- Spiking neural networks (SNNs): integrate-and-fire models (LIF, IF), spike timing
- Learning in SNNs: STDP (spike-timing-dependent plasticity), surrogate gradient methods, conversion from ANNs
- Neuromorphic hardware: Intel Loihi 2, IBM TrueNorth, SpiNNaker, BrainScaleS
- Event-driven computation: asynchronous processing, energy efficiency
- Event cameras (DVS): neuromorphic vision sensors, sparse temporal data
- Applications: low-power edge inference, robotics, always-on sensing
- Comparison with conventional deep learning: latency, power, accuracy tradeoffs
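The integrate-and-fire model listed above fits in a few lines. The sketch below simulates one leaky integrate-and-fire (LIF) neuron with forward-Euler steps of the membrane equation dv/dt = (-v + I(t)) / τ, emitting a spike and resetting whenever v crosses the threshold. The time constant, threshold and input values are illustrative assumptions, not parameters of any neuromorphic chip or SNN framework.

```python
from typing import List


def simulate_lif(inputs: List[float], tau: float = 20.0, v_th: float = 1.0,
                 v_reset: float = 0.0, dt: float = 1.0) -> List[int]:
    """Leaky integrate-and-fire neuron; returns a 0/1 spike train, one entry per step."""
    v = v_reset
    spikes = []
    for current in inputs:
        v += dt * (-v + current) / tau  # leak towards 0 while integrating the input
        if v >= v_th:
            spikes.append(1)
            v = v_reset                 # fire and reset
        else:
            spikes.append(0)
    return spikes


weak = simulate_lif([0.9] * 200)    # steady state 0.9 stays below threshold
strong = simulate_lif([1.5] * 200)  # steady state 1.5 crosses it repeatedly
print(sum(weak), sum(strong), strong.index(1))
```

The output carries information in spike timing and rate rather than in continuous activations, and a neuron that receives no input does no work, which is the source of the energy savings claimed for event-driven hardware.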
AI for Finance
- Time series forecasting: ARIMA, exponential smoothing, Prophet, neural approaches (LSTM, Temporal Fusion Transformer, PatchTST)
- Algorithmic trading: signal generation, execution algorithms (TWAP, VWAP), market microstructure
- Portfolio optimisation: mean-variance (Markowitz), Black-Litterman, RL-based portfolio management
- Risk modelling: Value at Risk (VaR), Expected Shortfall, Monte Carlo simulation, credit scoring
- Fraud detection: anomaly detection, graph-based approaches, real-time streaming
- NLP in finance: sentiment analysis of news/earnings calls, financial document understanding
- Alternative data: satellite imagery, social media, web scraping
- Regulatory and ethical: model explainability (SHAP, LIME), fairness in credit, regulatory compliance
AI for Biology
- Protein structure prediction: AlphaFold 1/2/3, ESMFold, co-evolutionary analysis, MSA transformers
- Protein design: inverse folding (ProteinMPNN), diffusion for protein generation (RFdiffusion), hallucination
- Drug discovery: molecular representations (SMILES, graphs), molecular property prediction, virtual screening, docking
- Generative chemistry: molecular generation (VAE, GAN, diffusion), retrosynthesis prediction
- Genomics: DNA sequence modelling (Enformer, HyenaDNA), variant effect prediction, CRISPR guide design
- Single-cell analysis: scRNA-seq, cell type clustering, trajectory inference
- Medical imaging: radiology (CheXNet), pathology (whole-slide images), segmentation (nnU-Net)
- Clinical NLP: medical entity extraction, clinical trial matching, electronic health records
Emerging Intersections
- AI for climate: weather forecasting (GraphCast, Pangu-Weather, GenCast), carbon footprint of AI, energy grid optimisation
- AI for materials science: crystal structure prediction, property screening, generative materials design
- AI for mathematics: automated theorem proving (Lean, Isabelle), conjecture generation, symbolic regression
- AI for code: code generation (Codex, StarCoder), program synthesis, formal verification, code review agents
- AI for education: intelligent tutoring systems, personalised learning, automated grading
- AI for law: contract analysis, legal document retrieval, case outcome prediction
- AI safety and alignment: RLHF, constitutional AI, interpretability (mechanistic, probing), deceptive alignment
- Societal impact: labour markets, intellectual property, deepfakes, governance frameworks
[Experiment Title]
Author: [Your Name] Affiliation: [Your Institution/Company] Date: [YYYY-MM-DD]
Motivation
- Why did you run this experiment? What question were you trying to answer?
Setup
- Model, dataset, hardware, hyperparameters, or any relevant configuration
Results
- Key findings, tables, plots (use SVG in ../images/)
Takeaways
- What did you learn? What would you do differently?