This reference covers the algorithms PyMC provides for posterior inference: MCMC sampling, variational inference, Sequential Monte Carlo, predictive sampling, and MAP estimation.
The main interface for MCMC sampling in PyMC is `pm.sample`:

```python
pm.sample(draws=1000, tune=1000, chains=4, **kwargs)
```

Key parameters:

- `draws`: number of samples to draw per chain (default: 1000)
- `tune`: number of tuning (warmup) samples per chain, discarded afterwards (default: 1000)
- `chains`: number of chains (default: `max(2, cores)`, i.e. 4 on a machine with at least 4 CPUs)
- `cores`: number of chains run in parallel (default: the number of CPUs, capped at 4)
- `target_accept`: target acceptance rate for NUTS step-size tuning (default: 0.8; increase to 0.9-0.95 for difficult posteriors)
- `random_seed`: random seed for reproducibility
- `return_inferencedata`: return an ArviZ `InferenceData` object (default: True)
- `idata_kwargs`: additional keyword arguments for `InferenceData` creation (e.g. `{"log_likelihood": True}` to store pointwise log-likelihoods for model comparison)

Returns: an `InferenceData` object containing the posterior samples and sampling statistics, including diagnostics such as divergences.
Example:

```python
import pymc as pm

with pm.Model() as model:
    # ... define model ...
    idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)
```
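PyMC automatically selects appropriate samplers based on model structure (NUTS for continuous parameters, Metropolis-family samplers for discrete ones), but you can also specify step methods manually.

NUTS (No-U-Turn Sampler) is the default for continuous parameters: an adaptive variant of Hamiltonian Monte Carlo that tunes its step size and trajectory length automatically.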
Manual specification:

```python
with model:
    idata = pm.sample(step=pm.NUTS(target_accept=0.95))
```
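When to adjust:

- Increase `target_accept` (0.9-0.99) if you see divergences.
- `init` (an argument of `pm.sample`) controls initialization. The default, `init='auto'`, currently uses `'jitter+adapt_diag'`, which adds random jitter to each chain's starting point; `init='adapt_diag'` starts all chains from the model's initial point without jitter, which can help when jittered starting points fail.

Metropolis is a general-purpose Metropolis-Hastings sampler. It needs no gradients, so it also works for discrete parameters, but it usually mixes much more slowly than NUTS on continuous ones.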
Example:

```python
with model:
    idata = pm.sample(step=pm.Metropolis())
```
Slice is a univariate slice sampler: it updates one variable (dimension) at a time and requires no gradients.
Example:

```python
with model:
    idata = pm.sample(step=pm.Slice())
```
Compound steps combine different samplers for different parameters.
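Example:

```python
with model:
    # continuous_var1, continuous_var2 and discrete_var are placeholders
    # for variables defined in `model`.
    # Use NUTS for continuous params, Metropolis for discrete
    step1 = pm.NUTS([continuous_var1, continuous_var2])
    step2 = pm.Metropolis([discrete_var])
    idata = pm.sample(step=[step1, step2])
```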
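PyMC computes convergence diagnostics automatically and warns when they look poor. Check these before trusting results:

- Effective sample size (ESS): how many independent draws the autocorrelated samples are worth. Check with `az.ess(idata)`; a common target is a bulk ESS above 400 per parameter (about 100 per chain with 4 chains).
- R-hat: compares within-chain and between-chain variance to assess convergence across chains. Check with `az.rhat(idata)`; values should be below 1.01.
- Divergences: indicate regions of the posterior where NUTS struggled, so estimates may be biased. Count them with `idata.sample_stats.diverging.sum()`; fix them by increasing `target_accept`, reparameterizing, or using stronger priors.
- Energy plot: `az.plot_energy(idata)` visualizes Hamiltonian Monte Carlo energy transitions. Healthy sampling shows the marginal energy distribution and the energy-transition distribution overlapping closely; a transition distribution much narrower than the marginal one (low BFMI) means the sampler explores the posterior inefficiently.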
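The checks above can be bundled into one function. This is a sketch: `check_convergence` is a hypothetical helper (not part of PyMC or ArviZ), and its default thresholds are the ones suggested above. It only needs ArviZ, so it is demonstrated on fake draws built with `az.from_dict` rather than on a fitted model.

```python
import arviz as az
import numpy as np

def check_convergence(idata, rhat_max=1.01, ess_min=400):
    """Hypothetical helper: report worst-case R-hat, bulk ESS and divergence count."""
    rhat = az.rhat(idata)                    # Dataset: one R-hat per parameter element
    ess_bulk = az.ess(idata, method="bulk")  # Dataset: one bulk ESS per parameter element
    max_rhat = max(float(rhat[v].max()) for v in rhat.data_vars)
    min_ess = min(float(ess_bulk[v].min()) for v in ess_bulk.data_vars)

    # Divergences only exist for gradient-based samplers; treat a missing group as zero
    stats = getattr(idata, "sample_stats", None)
    n_div = int(stats["diverging"].sum()) if stats is not None and "diverging" in stats else 0

    return {
        "max_rhat": max_rhat,
        "min_ess_bulk": min_ess,
        "divergences": n_div,
        "ok": max_rhat < rhat_max and min_ess > ess_min and n_div == 0,
    }

rng = np.random.default_rng(0)
# Fake "good" draws: 4 chains x 1000 independent draws from the same distribution
good = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 1000))},
    sample_stats={"diverging": np.zeros((4, 1000), dtype=bool)},
)
# Fake "bad" draws: each chain is centered somewhere different, so chains disagree
bad = az.from_dict(posterior={"mu": rng.normal(size=(4, 1000)) + 3.0 * np.arange(4)[:, None]})

print(check_convergence(good))
print(check_convergence(bad))
```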
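If you see divergences:

```python
# Increase the target acceptance rate (NUTS then takes smaller steps)
with model:
    idata = pm.sample(target_accept=0.95)

# Or reparameterize latent group effects inside the model definition
# (n_groups is a placeholder for the number of groups)
mu = pm.Normal('mu', 0, 1)
sigma = pm.HalfNormal('sigma', 1)
# Centered (prone to divergences when sigma is small):
# theta = pm.Normal('theta', mu, sigma, shape=n_groups)
# Non-centered (same model, easier geometry for NUTS):
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```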
If sampling is slow (each option below is an alternative, not a sequence):

```python
with model:
    # Use fewer tuning steps if the model is simple
    idata = pm.sample(tune=500)

    # Run more chains in parallel on more cores
    idata = pm.sample(cores=8, chains=8)

    # Initialize NUTS with ADVI (variational inference) instead of jittered starting points
    idata = pm.sample(init='advi+adapt_diag')
```
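If ESS is low (high autocorrelation):

```python
with model:
    # Draw more samples
    idata = pm.sample(draws=5000)
```

Also consider reparameterizing to reduce posterior correlation between parameters; for regression models with correlated predictors, a QR decomposition of the design matrix often helps.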
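Variational inference (VI) offers faster, approximate inference for large models or quick exploration. Its main entry point is `pm.fit`:

```python
pm.fit(n=10000, method='advi', **kwargs)
```

ADVI (automatic differentiation variational inference) approximates the posterior with a simpler distribution, by default a mean-field (fully factorized) Gaussian in the unconstrained parameter space.

Key parameters:

- `n`: number of optimization iterations (default: 10000)
- `method`: VI algorithm (`'advi'`, `'fullrank_advi'`, `'svgd'`)
- `random_seed`: random seed

Returns: an `Approximation` object that you can sample from and analyze.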
Example:

```python
with model:
    approx = pm.fit(n=50000)

    # Draw samples from the approximation (returns InferenceData)
    idata = approx.sample(1000)

    # Or draw one point to use as MCMC starting values: pm.sample(initvals=start)
    start = approx.sample(draws=1, return_inferencedata=False)[0]
```
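Trade-offs:

- Much faster than MCMC, and it scales to large datasets with minibatches.
- Only approximate: mean-field ADVI ignores correlations between parameters and tends to underestimate posterior uncertainty.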
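Full-rank ADVI uses a Gaussian with a full covariance matrix, so it captures correlations between parameters:

```python
with model:
    approx = pm.fit(method='fullrank_advi')
```

It is more accurate than mean-field ADVI but slower, because the number of variational parameters grows quadratically with the number of model parameters.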
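Stein Variational Gradient Descent (SVGD) is a non-parametric, particle-based variational method:

```python
with model:
    approx = pm.fit(method='svgd', n=20000)
```

It can capture multimodal posteriors better than ADVI, but it is more computationally expensive.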
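Prior predictive sampling draws from the prior and the prior predictive distribution (before seeing data):

```python
pm.sample_prior_predictive(draws=500, **kwargs)
```

Older PyMC releases name this argument `samples` instead of `draws`.

Purpose: check whether the priors imply plausible data before fitting, for example catching priors that are far too wide or on the wrong scale.

Example:

```python
import arviz as az

with model:
    prior_pred = pm.sample_prior_predictive(draws=1000)

# Visualize prior predictions against the observed data
az.plot_ppc(prior_pred, group='prior')
```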
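Posterior predictive sampling draws replicated data from the fitted model:

```python
pm.sample_posterior_predictive(trace, **kwargs)
```

Purpose: check how well the model reproduces the observed data (posterior predictive checks) and generate predictions that include parameter uncertainty.

Example:

```python
with model:
    # After sampling
    idata = pm.sample()
    # Add posterior predictive samples to idata
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Posterior predictive check
az.plot_ppc(idata)
```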
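To predict for new inputs, update the data and sample the predictive distribution. This requires that `X` was registered as a data container (`pm.Data`) and that the likelihood's shape follows it (e.g. `shape=X.shape[0]`):

```python
with model:
    # Original model fit
    idata = pm.sample()
    # Update with new predictor values
    pm.set_data({'X': X_new})
    # Sample predictions into a separate "predictions" group
    post_pred_new = pm.sample_posterior_predictive(
        idata,
        var_names=['y_pred'],
        predictions=True,
    )
```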
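MAP estimation finds the posterior mode (a point estimate):

```python
pm.find_MAP(start=None, method='L-BFGS-B', **kwargs)
```

When to use: quick point estimates, debugging a model, or obtaining starting values for MCMC.

Example:

```python
with model:
    map_estimate = pm.find_MAP()
print(map_estimate)
```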
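Limitations: a MAP estimate carries no uncertainty information, and in high-dimensional or hierarchical models the mode can be unrepresentative of the posterior (for example, group-level scales can collapse toward zero).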
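A recommended workflow:

1. Start with ADVI for quick exploration:

   ```python
   with model:
       approx = pm.fit(n=20000)
   ```

2. Run MCMC for full inference:

   ```python
   with model:
       idata = pm.sample(draws=2000, tune=1000)
   ```

3. Check diagnostics (transformed variables, such as log-scale parameters, are already excluded from the posterior by default):

   ```python
   az.summary(idata)  # includes r_hat, ess_bulk and ess_tail columns
   ```

4. Sample the posterior predictive distribution:

   ```python
   with model:
       pm.sample_posterior_predictive(idata, extend_inferencedata=True)
   ```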
Choosing an inference method:

| Scenario | Recommended Method |
|---|---|
| Small-medium models, need full uncertainty | MCMC with NUTS |
| Large models, initial exploration | ADVI |
| Discrete parameters | Metropolis or marginalize |
| Hierarchical models with divergences | Non-centered parameterization + NUTS |
| Very large data | Minibatch ADVI |
| Quick point estimates | MAP or ADVI |
Non-centered parameterization for hierarchical models:

```python
# Centered (can cause divergences):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta = pm.Normal('theta', mu, sigma, shape=n_groups)

# Non-centered (better sampling):
mu = pm.Normal('mu', 0, 10)
sigma = pm.HalfNormal('sigma', 1)
theta_offset = pm.Normal('theta_offset', 0, 1, shape=n_groups)
theta = pm.Deterministic('theta', mu + sigma * theta_offset)
```
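QR decomposition for correlated predictors (assuming a standardized design matrix `X` of shape `(n, p)` and a response `y_obs`):

```python
import numpy as np
import pymc as pm

n, p = X.shape
# Thin QR decomposition, rescaled so beta_tilde is on roughly the same scale as beta
Q, R = np.linalg.qr(X)
Q = Q * np.sqrt(n - 1)
R = R / np.sqrt(n - 1)
R_inv = np.linalg.inv(R)

with pm.Model():
    # Coefficients on the orthogonal basis Q (much less correlated a posteriori)
    beta_tilde = pm.Normal('beta_tilde', 0, 1, shape=p)
    # Transform back to the original scale: X @ beta = Q @ beta_tilde, so beta = R^-1 @ beta_tilde
    beta = pm.Deterministic('beta', pm.math.dot(R_inv, beta_tilde))
    mu = pm.math.dot(Q, beta_tilde)
    sigma = pm.HalfNormal('sigma', 1)
    y = pm.Normal('y', mu, sigma, observed=y_obs)
```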
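The QR trick rests on identities that can be checked numerically without PyMC: if X = Q R, then X @ beta = Q @ beta_tilde with beta = R^-1 @ beta_tilde, and the columns of Q are orthogonal. Under a flat prior the posterior covariance of the coefficients is proportional to the inverse Gram matrix, so comparing (X^T X)^-1 with (Q^T Q)^-1 shows the correlation the reparameterization removes. The sketch uses synthetic, deliberately collinear predictors (an assumption for illustration) and ordinary least squares as a stand-in for the posterior.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 200
# Synthetic predictors: columns 0 and 1 are nearly collinear, column 2 is independent
z = rng.normal(size=n)
X = np.column_stack([
    z + 0.1 * rng.normal(size=n),
    z + 0.1 * rng.normal(size=n),
    rng.normal(size=n),
])
y = X @ np.array([1.0, -0.5, 2.0]) + rng.normal(scale=0.5, size=n)

# Thin QR, rescaled by sqrt(n - 1) so that Q^T Q = (n - 1) I
Q, R = np.linalg.qr(X)
Q = Q * np.sqrt(n - 1)
R = R / np.sqrt(n - 1)

print(np.allclose(Q @ R, X))                       # the factors reproduce X
print(np.allclose(Q.T @ Q, (n - 1) * np.eye(3)))   # orthogonal columns

# Least-squares fit in the Q basis, mapped back with beta = R^-1 @ beta_tilde
beta_tilde_hat = np.linalg.lstsq(Q, y, rcond=None)[0]
beta_hat_qr = np.linalg.solve(R, beta_tilde_hat)
beta_hat_direct = np.linalg.lstsq(X, y, rcond=None)[0]
print(np.allclose(beta_hat_qr, beta_hat_direct))

def corr(cov):
    """Convert a covariance matrix to a correlation matrix."""
    d = np.sqrt(np.diag(cov))
    return cov / np.outer(d, d)

corr_beta = corr(np.linalg.inv(X.T @ X))         # strongly correlated coefficients
corr_beta_tilde = corr(np.linalg.inv(Q.T @ Q))   # uncorrelated in the Q basis
print(round(corr_beta[0, 1], 2), round(corr_beta_tilde[0, 1], 2))
```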
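Sequential Monte Carlo (SMC) is an alternative for complex posteriors or for estimating model evidence (the marginal likelihood):

```python
with model:
    idata = pm.sample_smc(draws=2000, chains=4)
```

It works well for multimodal posteriors or when NUTS struggles. The log marginal likelihood estimate is stored in `idata.sample_stats` (as `log_marginal_likelihood`).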
Provide starting values through `initvals` (called `start` in PyMC 3):

```python
start = {'mu': 0, 'sigma': 1}
with model:
    idata = pm.sample(initvals=start)
```

Or use the MAP estimate:

```python
with model:
    start = pm.find_MAP()
    idata = pm.sample(initvals=start)
```