Statistical Inference
-
Hypothesis testing gives you a yes/no decision: reject or fail to reject. But often you want something more informative: a range of plausible values for the parameter you are estimating. That is what confidence intervals provide.
-
A point estimate is a single number computed from your sample, like the sample mean $\bar{x}$. It is your best guess for the population parameter, but on its own it gives no sense of how precise the estimate is.
-
A confidence interval wraps that point estimate with a range that reflects uncertainty. It takes the form:
$$\text{CI} = \bar{x} \pm \text{ME}$$
- The margin of error (ME) depends on three things: how confident you want to be, how much variability is in the data, and how large your sample is:
$$\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}$$
- Here $z^\ast$ is the critical value from the normal distribution that matches your desired confidence level. For 95% confidence, $z^\ast = 1.96$. For 99% confidence, $z^\ast = 2.576$.
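These critical values can be computed rather than memorised. A quick check with SciPy's inverse normal CDF (assuming SciPy is available, as it is in Colab):

```python
from scipy.stats import norm

# Two-tailed: put alpha/2 in each tail, so z* = Phi^{-1}(1 - alpha/2)
for conf in (0.90, 0.95, 0.99):
    alpha = 1 - conf
    z_star = norm.ppf(1 - alpha / 2)
    print(f"{conf:.0%} confidence -> z* = {z_star:.3f}")
```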
-
A 95% confidence interval means: if you repeated the experiment many times and built an interval each time, about 95% of those intervals would contain the true population parameter. It does not mean there is a 95% probability the parameter is in this specific interval. The parameter is fixed; the intervals are what vary.
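The repeated-experiment interpretation is easy to verify by simulation. A minimal sketch in JAX, using hypothetical values ($\mu = 170$, $\sigma = 8$, $n = 50$): draw many samples, build an interval from each, and count how often the true mean is covered.

```python
import jax
import jax.numpy as jnp

mu, sigma, n = 170.0, 8.0, 50   # hypothetical true mean, std, sample size
z_star = 1.96
n_experiments = 10_000

key = jax.random.PRNGKey(0)
# Each row is one repeated experiment of n observations
samples = jax.random.normal(key, shape=(n_experiments, n)) * sigma + mu
means = samples.mean(axis=1)
me = z_star * sigma / jnp.sqrt(n)
covered = (means - me <= mu) & (mu <= means + me)
print(f"Coverage: {covered.mean():.3f}")   # close to 0.95
```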
-
Worked example: You measure the heights of 50 people and find $\bar{x} = 170$ cm with $\sigma = 8$ cm. Construct a 95% confidence interval.
$$\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}$$
$$\text{CI} = [170 - 2.22,\; 170 + 2.22] = [167.78,\; 172.22]$$
-
You can say with 95% confidence that the true mean height lies between 167.78 and 172.22 cm.
-
When $\sigma$ is unknown (the usual case), use the sample standard deviation $s$ and the t-distribution instead:
$$\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}$$
-
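A minimal sketch of the t-based interval, using a small hypothetical sample of heights and SciPy's t quantile function (both the data and the SciPy dependency are assumptions here):

```python
import jax.numpy as jnp
from scipy.stats import t

# Hypothetical small sample of heights (cm); sigma is unknown
heights = jnp.array([168.0, 172.0, 165.0, 174.0, 169.0, 171.0, 166.0, 173.0])
n = heights.shape[0]
x_bar = heights.mean()
s = jnp.std(heights, ddof=1)        # sample std (n-1 denominator)
t_star = t.ppf(0.975, df=n - 1)     # 95% two-tailed critical value
me = t_star * s / jnp.sqrt(n)
print(f"95% t-interval: [{x_bar - me:.2f}, {x_bar + me:.2f}]")
```

With only 8 observations, $t^\ast_7 \approx 2.36$ is noticeably larger than $z^\ast = 1.96$, so the interval is wider than a z-interval would be; the two converge as $n$ grows.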
Wider intervals are more confident but less precise. Narrower intervals are more precise but less confident. You can narrow an interval without losing confidence by increasing the sample size.
-
Power analysis helps you plan an experiment before you run it. The question is: how large a sample do I need to detect an effect of a given size with a specified power?
-
Recall from the previous file that power = $1 - \beta$, the probability of correctly rejecting a false $H_0$. A common target is 80% power.
-
The required sample size for a one-sample z-test detecting a shift $\delta$ from the null mean, with significance $\alpha$ and power $1-\beta$, is:
$$n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2$$
- For example, to detect a 2 cm difference in mean height ($\sigma = 8$) with $\alpha = 0.05$ and 80% power ($z_{0.025} = 1.96$, $z_{0.20} = 0.84$):
$$n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 \approx 126$$
-
You would need a sample of about 126 people. (For a two-sample comparison of two group means, the required size per group is twice this, about 251.)
-
Power analysis prevents two common mistakes: running an experiment too small to detect a real effect (underpowered), or wasting resources on an experiment far larger than necessary (overpowered).
-
Monte Carlo methods use random sampling to solve problems that are difficult or impossible to solve analytically. The core idea: if you cannot compute something exactly, simulate it many times and use the results as an approximation.
-
The name comes from the Monte Carlo casino, a nod to the role of randomness. These methods are workhorses in ML for tasks like estimating integrals, evaluating model uncertainty, and approximating complex distributions.
-
The general Monte Carlo recipe:
- Define a domain of possible inputs
- Generate random inputs from that domain
- Evaluate a function on each input
- Aggregate the results (average, count, etc.)
-
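The recipe in miniature: estimating the integral $\int_0^1 x^2\,dx$ (exact value $1/3$) by averaging the function over uniform random inputs.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Steps 1-2: the domain is [0, 1]; draw uniform random inputs from it
x = jax.random.uniform(key, shape=(100_000,))
# Step 3: evaluate the function on each input
fx = x ** 2
# Step 4: aggregate -- the mean of f(X) estimates the integral
estimate = fx.mean()
print(f"Estimate: {estimate:.4f}  (exact: {1/3:.4f})")
```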
A classic example is estimating $\pi$. Imagine a square with side length 2, centred at the origin, with a circle of radius 1 inscribed inside it. The area of the square is 4, and the area of the circle is $\pi$.
- Drop random points uniformly in the square. The fraction that land inside the circle approximates $\pi/4$:
$$\pi \approx 4 \times \frac{\text{points inside circle}}{\text{total points}}$$
-
A point $(x, y)$ is inside the circle if $x^2 + y^2 \le 1$. The more points you throw, the closer your estimate gets to the true value of $\pi$.
-
In ML, Monte Carlo methods appear in:
- Monte Carlo dropout: run inference multiple times with dropout enabled to estimate prediction uncertainty
- MCMC (Markov Chain Monte Carlo): sample from complex posterior distributions in Bayesian models
- Policy gradient methods: estimate gradients in reinforcement learning by sampling trajectories
-
Factor analysis is a technique for discovering hidden (latent) variables that explain the correlations among observed variables. If 10 personality survey questions can be explained by 3 underlying traits (extraversion, agreeableness, conscientiousness), factor analysis finds those traits.
-
The model assumes each observed variable $x_i$ is a linear combination of a few latent factors $f_j$ plus noise:
$$x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i$$
-
The $\lambda$ values are called factor loadings and tell you how strongly each observed variable relates to each factor. This connects directly to the matrix decompositions from Chapter 2; factor analysis is closely related to eigenvalue decomposition and SVD.
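A sketch of fitting this model with scikit-learn's `FactorAnalysis` on synthetic data; the data-generating process (2 latent factors driving 6 observed variables) and the scikit-learn dependency are assumptions here.

```python
import numpy as np
from sklearn.decomposition import FactorAnalysis

rng = np.random.default_rng(0)
n_samples, n_factors, n_observed = 500, 2, 6

# Synthetic data: 6 observed variables as linear combinations of
# 2 latent factors plus noise, matching the model above
f = rng.normal(size=(n_samples, n_factors))            # latent factors
loadings = rng.normal(size=(n_factors, n_observed))    # true lambda matrix
x = f @ loadings + 0.3 * rng.normal(size=(n_samples, n_observed))

fa = FactorAnalysis(n_components=n_factors, random_state=0)
fa.fit(x)
print("Estimated loadings (rows = factors):")
print(np.round(fa.components_, 2))
```

Note that factor loadings are only identified up to rotation, so the estimated matrix need not match the true one entry by entry.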
-
Experimental design is the art of structuring an experiment so that you can draw valid conclusions. Poor design can make even a large dataset useless.
-
Key components of a well-designed experiment:
- Independent variable (IV): what you manipulate (e.g. drug dose, model architecture)
- Dependent variable (DV): what you measure (e.g. recovery time, accuracy)
- Control group: receives no treatment (or a placebo), providing a baseline for comparison
- Random assignment: participants are assigned to groups randomly, which balances out confounding variables you did not measure
-
Common experimental designs:
- Completely randomised design: subjects are randomly assigned to treatment groups. Simple and effective when groups are comparable.
- Randomised block design: subjects are first grouped into blocks (e.g. by age), then randomly assigned to treatments within each block. This reduces variability from the blocking factor, similar in spirit to stratified sampling.
- Factorial design: tests multiple IVs simultaneously. A $2 \times 3$ factorial design has 2 levels of one variable and 3 of another, giving 6 treatment combinations. This lets you detect interactions, where the effect of one variable depends on the level of another.
- Crossover design: each subject receives all treatments in sequence (with washout periods in between). Every subject serves as their own control, reducing the effect of individual differences.
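As a sketch, randomised block assignment comes down to a per-block shuffle followed by dealing treatments out evenly; the subjects, blocks, and treatment names below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical subjects with an age-based blocking factor
subjects = [f"s{i:02d}" for i in range(12)]
blocks = {"young": subjects[:6], "old": subjects[6:]}
treatments = ["control", "drug_a", "drug_b"]

# Within each block, shuffle subjects and deal treatments out in turn,
# so every treatment appears equally often in every block
assignment = {}
for block_name, members in blocks.items():
    shuffled = rng.permutation(members)
    for i, subject in enumerate(shuffled):
        assignment[subject] = treatments[i % len(treatments)]

for subject in subjects:
    print(subject, assignment[subject])
```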
-
In ML experiments, these principles are critical. When comparing models, you should control for random seed, dataset split, and hardware. Cross-validation is a form of crossover design. Ablation studies, where you remove one component at a time, follow the logic of factorial designs.
Coding Tasks (use Colab or a notebook)
- Construct a 95% confidence interval for the height example, then experiment with different confidence levels and sample sizes.
import jax.numpy as jnp

x_bar = 170.0   # sample mean
sigma = 8.0     # population std (known)
n = 50          # sample size

# Critical values for common confidence levels
z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
for conf, z_star in z_stars.items():
    me = z_star * (sigma / jnp.sqrt(n))
    lower, upper = x_bar - me, x_bar + me
    print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}] (ME = {me:.2f})")
- Estimate $\pi$ using Monte Carlo simulation. Plot how the estimate converges as you increase the number of points.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
key = jax.random.PRNGKey(42)
# Generate random points in [-1, 1] x [-1, 1]
n_points = 100_000
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)
y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)
# Check which points are inside the unit circle
inside = (x**2 + y**2) <= 1.0
cumulative_inside = jnp.cumsum(inside)
counts = jnp.arange(1, n_points + 1)
pi_estimates = 4.0 * cumulative_inside / counts
plt.figure(figsize=(10, 4))
plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5)
plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}")
plt.xlabel("Number of points")
plt.ylabel("Estimate of π")
plt.title("Monte Carlo estimation of π")
plt.legend()
plt.ylim(2.8, 3.5)
plt.show()
print(f"Final estimate: {pi_estimates[-1]:.6f}")
print(f"True value: {jnp.pi:.6f}")
print(f"Error: {abs(pi_estimates[-1] - jnp.pi):.6f}")
- Perform a simple power analysis: for a given effect size and standard deviation, compute the required sample size and verify it by simulation.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Parameters
delta = 2.0          # effect size (shift from the null mean)
sigma = 8.0          # population std
alpha = 0.05
power_target = 0.80

# Analytical sample size (one-sample z-test)
z_alpha = 1.96       # two-tailed, alpha = 0.05
z_beta = 0.84        # power = 0.80
n_required = ((z_alpha + z_beta) * sigma / delta) ** 2
print(f"Required n: {n_required:.0f}")

# Verify by simulation: sample from the shifted distribution
# and test H0: mu = 50 with a one-sample z-test
key = jax.random.PRNGKey(7)
n = int(jnp.ceil(n_required))
n_sims = 5000
rejections = 0
for _ in range(n_sims):
    key, subkey = jax.random.split(key)
    sample = jax.random.normal(subkey, shape=(n,)) * sigma + 50 + delta
    se = sigma / jnp.sqrt(n)
    z = (sample.mean() - 50) / se
    p = 2 * (1 - norm.cdf(jnp.abs(z)))
    if p <= alpha:
        rejections += 1
print(f"Simulated power: {rejections/n_sims:.3f}")
print(f"Target power:    {power_target:.3f}")
- Visualise how confidence interval width changes with sample size. This shows why collecting more data gives more precise estimates.
import jax.numpy as jnp
import matplotlib.pyplot as plt
sigma = 8.0
z_star = 1.96 # 95% confidence
sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)
margins = z_star * sigma / jnp.sqrt(sample_sizes)
plt.figure(figsize=(8, 4))
plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7)
plt.xlabel("Sample size")
plt.ylabel("Margin of error (cm)")
plt.title("95% CI margin of error shrinks with larger samples")
plt.show()