Neural Parameter Estimation
Our goal is to illustrate how to use NeuralEstimators.jl to perform neural parameter estimation for the Leaky Competing Accumulator (LCA; Usher & McClelland, 2001). Neural parameter estimation uses neural networks to learn the mapping between simulated data and model parameters (for a detailed review, see Zammit-Mangion et al., 2024), making it particularly useful for models with computationally intractable likelihoods, such as the LCA. In the example below, we train a point estimator that, under the absolute error loss used here, targets the marginal posterior median of each parameter. Many neural estimation techniques are amortized, meaning one incurs a large initial computational cost to train the neural estimator, but estimating parameters with the trained network is then fast and computationally efficient. One benefit of amortized approaches is that the trained neural estimator can be saved and reused across multiple datasets, or used for computationally intensive parameter recovery simulations to understand the quality of parameter estimates under ideal conditions.
Full Code
For those who are interested only in the code, you can click on the ▶ icon below to reveal a full copy-and-pastable version of the example.
using Distributions
using Flux
using NeuralEstimators
using Plots
using SequentialSamplingModels
using Random
using Statistics
Random.seed!(123)
# Function to sample parameters from priors
function sample(K::Integer)
ν1 = rand(Gamma(2, 1/1.2f0), K) # Drift rate 1
ν2 = rand(Gamma(2, 1/1.2f0), K) # Drift rate 2
α = rand(Gamma(10, 1/6f0), K) # Threshold
β = rand(Beta(1, 5f0), K) # Lateral inhibition
λ = rand(Beta(1, 5f0), K) # Leak rate
τ = rand(Gamma(1.5, 1/5.0f0), K) # Non-decision time
# Stack parameters into a matrix (d×K)
θ = vcat(ν1', ν2', α', β', λ', τ')
return θ
end
# Function to simulate data from the LCA model
function simulate(θ, n_trials_per_param)
# Simulate data for each parameter vector
simulated_data = map(eachcol(θ)) do param
# Extract parameters for this model
ν1, ν2, α, β, λ, τ = param
ν = [ν1, ν2] # Two-choice LCA
# Create LCA model with SSM
model = LCA(; ν, α, β, λ, τ)
# Generate choices and reaction times
choices, rts = rand(model, n_trials_per_param)
# Return a 2 × n matrix (row 1 = choices, row 2 = RTs), one column per trial
return Float32.([choices rts]')
end
return simulated_data
end
# Create neural network architecture for parameter recovery
function create_neural_estimator(;
ν_bounds = (0.1, 6.0),
α_bounds = (0.3, 4.5),
β_bounds = (0.0, 0.8),
λ_bounds = (0.0, 0.8),
τ_bounds = (0.100, 2.0)
)
# Unpack parameter bounds
ν_min, ν_max = ν_bounds # Drift rates
α_min, α_max = α_bounds # Threshold
β_min, β_max = β_bounds # Lateral inhibition
λ_min, λ_max = λ_bounds # Leak rate
τ_min, τ_max = τ_bounds # Non-decision time
# Input dimension: 2 (choice and RT for each trial)
n = 2
# Output dimension: 6 parameters
d = 6 # ν[1], ν[2], α, β, λ, τ
# Width of hidden layers
w = 128
# Inner network - processes each trial independently
ψ = Chain(
Dense(n, w, relu),
Dense(w, w, relu),
Dense(w, w, relu)
)
# Final layer with parameter constraints
final_layer = Parallel(
vcat,
Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)), # ν1
Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)), # ν2
Dense(w, 1, x -> α_min + (α_max - α_min) * σ(x)), # α
Dense(w, 1, x -> β_min + (β_max - β_min) * σ(x)), # β
Dense(w, 1, x -> λ_min + (λ_max - λ_min) * σ(x)), # λ
Dense(w, 1, x -> τ_min + (τ_max - τ_min) * σ(x)) # τ
)
# Outer network - maps aggregated features to parameters
ϕ = Chain(
Dense(w, w, relu),
Dense(w, w, relu),
final_layer
)
# Combine into a DeepSet
network = DeepSet(ψ, ϕ)
# Initialize neural Bayes estimator
estimator = PointEstimator(network)
return estimator
end
# Create the neural estimator
estimator = create_neural_estimator()
# Train network
trained_estimator = train(
estimator,
sample, # Parameter sampler function
simulate, # Data simulator function
m = 100, # Number of trials per parameter vector
K = 10000, # Number of training parameter vectors
K_val = 2000, # Number of validation parameter vectors
loss = Flux.mae, # Mean absolute error loss
epochs = 60, # Number of training epochs
epochs_per_Z_refresh = 1, # Refresh data every epoch
epochs_per_θ_refresh = 5, # Refresh parameters every 5 epochs
batchsize = 16, # Batch size for training
verbose = true
)
# Generate test data
n_test = 500
θ_test = sample(n_test)
Z_test = simulate(θ_test, 500)
# Assess the estimator
parameter_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
assessment = assess(
trained_estimator,
θ_test,
Z_test;
parameter_names = parameter_names
)
# Calculate performance metrics
bias_results = bias(assessment)
rmse_results = rmse(assessment)
println("Bias: ", bias_results)
println("RMSE: ", rmse_results)
# Extract data from assessment
df = assessment.df
# Create recovery plots for each parameter
params = unique(df.parameter)
p_plots = []
for param in params
param_data = filter(row -> row.parameter == param, df)
# Calculate correlation coefficient
truth = param_data.truth
estimate = param_data.estimate
correlation = cor(truth, estimate)
# Create plot
p = scatter(
truth,
estimate,
xlabel="Ground Truth",
ylabel="Estimated",
title=param,
legend=false
)
# Add diagonal reference line
plot!(p, [minimum(truth), maximum(truth)],
[minimum(truth), maximum(truth)],
line=:dash, color=:black)
# Get current axis limits after plot is created
x_min, x_max = xlims(p)
y_min, y_max = ylims(p)
# Position text at the top-left corner of the plot
annotate!(p, x_min + 0.1, y_max, text("R = $(round(correlation, digits=3))", :left, 10))
push!(p_plots, p)
end
# Combine plots
p_combined = plot(p_plots..., layout=(3,2), size=(800, 600))
display(p_combined)
# Generate "observed" data
ν = [2.5, 2.0]
α = 1.5
β = 0.2
λ = 0.1
τ = 0.3
# Create model and generate data
true_model = LCA(; ν, α, β, λ, τ)
observed_choices, observed_rts = rand(true_model, 100)
# Format the data
observed_data = Float32.([observed_choices observed_rts]')
# Recover parameters
recovered_params = NeuralEstimators.estimate(trained_estimator, [observed_data])
# Compare true and recovered parameters
println("True parameters: ", [ν[1], ν[2], α, β, λ, τ])
println("Recovered parameters: ", recovered_params)
Example
We'll estimate the parameters of the LCA model, which is particularly challenging because of its complex dynamics: parameters such as the leak rate (λ) and lateral inhibition (β) can be difficult to recover (Miletić et al., 2017). This example draws on a more in-depth case study that highlights many of the steps to consider when using amortized inference for cognitive modeling; see Principled Amortized Bayesian Workflow for Cognitive Modeling.
Load Packages
using Distributions
using Flux
using NeuralEstimators
using Plots
using SequentialSamplingModels
using Random
using Statistics
Random.seed!(123)
Define Parameter Sampling
Unlike traditional Bayesian inference methods, neural parameter estimation requires us to define two functions so that the neural network can learn the mapping between simulated data and parameters. One function samples parameters from a prior distribution, and the other generates simulated data based on a sampled parameter vector. While traditional methods like MCMC also sample from the prior, those samples are used directly during inference rather than to create a separate training dataset.
Schematic of neural parameter estimation. Once trained, the neural network provides a direct mapping from observed data (Z) to parameter estimates (θ̂), enabling rapid inference without the computational burden of traditional methods.
In neural parameter estimation, we use the prior to sample a wide range of parameters and simulate corresponding data, which we then use to train a model (e.g., a neural network) to approximate a point estimator or the full posterior. We use the following function to sample parameters for training:
# Function to sample parameters from priors
function sample(K::Integer)
ν1 = rand(Gamma(2, 1/1.2f0), K) # Drift rate 1
ν2 = rand(Gamma(2, 1/1.2f0), K) # Drift rate 2
α = rand(Gamma(10, 1/6f0), K) # Threshold
β = rand(Beta(1, 5f0), K) # Lateral inhibition
λ = rand(Beta(1, 5f0), K) # Leak rate
τ = rand(Gamma(1.5, 1/5.0f0), K) # Non-decision time
# Stack parameters into a matrix (d×K)
θ = vcat(ν1', ν2', α', β', λ', τ')
return θ
end
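Before training, it can help to sanity-check the priors by drawing a large sample from the sampler and plotting the implied marginal distributions. The sketch below is optional and uses the histogram function from Plots.jl; the number of draws (5,000) and the layout are arbitrary choices.
# Optional sanity check on the priors: draw many parameter vectors and plot
# the implied marginal distribution of each parameter
θ_prior = sample(5_000)                        # 6 × 5000 matrix of prior draws
prior_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
prior_plots = [histogram(θ_prior[i, :], title = prior_names[i], legend = false) for i in eachindex(prior_names)]
plot(prior_plots..., layout = (3, 2), size = (800, 600))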
Define Data Simulator
Neural estimators learn the mapping from data to parameters through simulation. Here we define a function that simulates data from the LCA model, using the LCA model object provided by SequentialSamplingModels.jl.
# Function to simulate data from the LCA model
function simulate(θ, n_trials_per_param)
# Simulate data for each parameter vector
simulated_data = map(eachcol(θ)) do param
# Extract parameters for this model
ν1, ν2, α, β, λ, τ = param
ν = [ν1, ν2] # Two-choice LCA
# Create LCA model with SSM
model = LCA(; ν, α, β, λ, τ)
# Generate choices and reaction times
choices, rts = rand(model, n_trials_per_param)
# Return a 2 × n matrix (row 1 = choices, row 2 = RTs), one column per trial
return Float32.([choices rts]')
end
return simulated_data
end
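As a quick, optional check that the simulator returns data in the format expected by the DeepSet architecture defined in the next section (a vector of 2 × m matrices, one per parameter vector), we can run a small simulation; the sizes in the comments are what we expect under these assumptions.
# Quick check of the simulator output format
θ_demo = sample(2)             # two parameter vectors (6 × 2 matrix)
Z_demo = simulate(θ_demo, 10)  # vector of two datasets, one per parameter vector
length(Z_demo)                 # expected: 2
size(Z_demo[1])                # expected: (2, 10); row 1 = choices, row 2 = RTs, columns = trials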
Define Neural Network Architecture
For LCA parameter recovery, we use a DeepSet architecture, which respects the permutation invariance of trial data (trials are exchangeable, so the estimate should not depend on their order). For more details on the method, see the NeuralEstimators.jl documentation. To construct the network architecture, we will use the Flux.jl package.
# Create neural network architecture for parameter recovery
function create_neural_estimator(;
ν_bounds = (0.1, 6.0),
α_bounds = (0.3, 4.5),
β_bounds = (0.0, 0.8),
λ_bounds = (0.0, 0.8),
τ_bounds = (0.100, 2.0)
)
# Unpack parameter bounds
ν_min, ν_max = ν_bounds # Drift rates
α_min, α_max = α_bounds # Threshold
β_min, β_max = β_bounds # Lateral inhibition
λ_min, λ_max = λ_bounds # Leak rate
τ_min, τ_max = τ_bounds # Non-decision time
# Input dimension: 2 (choice and RT for each trial)
n = 2
# Output dimension: 6 parameters
d = 6 # ν[1], ν[2], α, β, λ, τ
# Width of hidden layers
w = 128
# Inner network - processes each trial independently
ψ = Chain(
Dense(n, w, relu),
Dense(w, w, relu),
Dense(w, w, relu)
)
# Final layer with parameter constraints
final_layer = Parallel(
vcat,
Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)), # ν1
Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)), # ν2
Dense(w, 1, x -> α_min + (α_max - α_min) * σ(x)), # α
Dense(w, 1, x -> β_min + (β_max - β_min) * σ(x)), # β
Dense(w, 1, x -> λ_min + (λ_max - λ_min) * σ(x)), # λ
Dense(w, 1, x -> τ_min + (τ_max - τ_min) * σ(x)) # τ
)
# Outer network - maps aggregated features to parameters
ϕ = Chain(
Dense(w, w, relu),
Dense(w, w, relu),
final_layer
)
# Combine into a DeepSet
network = DeepSet(ψ, ϕ)
# Initialize neural Bayes estimator
estimator = PointEstimator(network)
return estimator
end
The constructed network yields a point estimator that approximates a Bayes estimator, i.e., a functional of the posterior distribution determined by the chosen loss function. Under the mean absolute error loss used during training below, the point estimate corresponds to the marginal posterior median of each parameter. For details on the theoretical foundations of neural Bayes estimators, see Sainsbury-Dale et al. (2024).
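As an optional sketch, the (still untrained) estimator can already be applied to simulated data to confirm that it maps a set of trials to a six-dimensional parameter vector; the estimates themselves are meaningless until training is complete. This assumes NeuralEstimators.estimate accepts an untrained PointEstimator in the same way it accepts the trained estimator used later in this example.
# Shape check with the untrained estimator
estimator_untrained = create_neural_estimator()
θ_check = sample(1)             # one parameter vector
Z_check = simulate(θ_check, 50) # vector containing one 2 × 50 data matrix
θ_check_est = NeuralEstimators.estimate(estimator_untrained, Z_check)
size(θ_check_est)               # expected: (6, 1), one (untrained) estimate per parameter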
Training the Neural Estimator
Neural estimators, like all deep learning methods, require a training phase during which they learn the mapping from data to parameters. Here, we train the estimator by simulating data on the fly: the sampler provides new parameter vectors from the prior, and the simulator generates corresponding data conditional on those parameters. Since training is online and the network never sees the same simulated dataset twice, overfitting is unlikely. For details on the available training arguments, see the train function in the NeuralEstimators.jl API documentation.
# Create the neural estimator
estimator = create_neural_estimator()
# Train network
trained_estimator = train(
estimator,
sample, # Parameter sampler function
simulate, # Data simulator function
m = 100, # Number of trials per parameter vector
K = 10000, # Number of training parameter vectors
K_val = 2000, # Number of validation parameter vectors
loss = Flux.mae, # Mean absolute error loss
epochs = 60, # Number of training epochs
epochs_per_Z_refresh = 1, # Refresh data every epoch
epochs_per_θ_refresh = 5, # Refresh parameters every 5 epochs
batchsize = 16, # Batch size for training
verbose = true
)
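Because training is the expensive step, it can be worth saving the trained network so it can be reused across datasets, as noted in the introduction. The following is a minimal sketch using Flux.state together with the JLD2 package (an extra dependency not used elsewhere in this example); it assumes the estimator's weights can be saved and restored like those of any other Flux model.
using JLD2  # assumed to be installed
# Save only the trained weights
model_state = Flux.state(trained_estimator)
jldsave("lca_estimator.jld2"; model_state)
# Later (e.g., in a new session): rebuild the architecture and restore the weights
estimator_reloaded = create_neural_estimator()
Flux.loadmodel!(estimator_reloaded, JLD2.load("lca_estimator.jld2", "model_state"))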
Assessing Estimator Performance
We can assess the performance of our trained estimator on held-out test data:
# Generate test data
n_test = 500
θ_test = sample(n_test)
Z_test = simulate(θ_test, 500)
# Assess the estimator
parameter_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
assessment = assess(
trained_estimator,
θ_test,
Z_test;
parameter_names = parameter_names
)
# Calculate performance metrics
bias_results = bias(assessment)
rmse_results = rmse(assessment)
println("Bias: ", bias_results)
println("RMSE: ", rmse_results)
Visualizing Parameter Recovery
A key advantage of neural estimation is the ability to conduct inference quickly after training. For example, we can visualize parameter recovery. While NeuralEstimators provides built-in visualization capabilities through AlgebraOfGraphics.jl, we demonstrate custom plotting with Plots.jl below:
# Extract data from assessment
df = assessment.df
# Create recovery plots for each parameter
params = unique(df.parameter)
p_plots = []
for param in params
param_data = filter(row -> row.parameter == param, df)
# Calculate correlation coefficient
truth = param_data.truth
estimate = param_data.estimate
correlation = cor(truth, estimate)
# Create plot
p = scatter(
truth,
estimate,
xlabel="Ground Truth",
ylabel="Estimated",
title=param,
legend=false
)
# Add diagonal reference line
plot!(p, [minimum(truth), maximum(truth)],
[minimum(truth), maximum(truth)],
line=:dash, color=:black)
# Get current axis limits after plot is created
x_min, x_max = xlims(p)
y_min, y_max = ylims(p)
# Position text at the top-left corner of the plot
annotate!(p, x_min + 0.1, y_max, text("R = $(round(correlation, digits=3))", :left, 10))
push!(p_plots, p)
end
# Combine plots
p_combined = plot(p_plots..., layout=(3,2), size=(800, 600))
display(p_combined)
Using the Trained Estimator
Once trained, the estimator can instantly recover parameters from new data via a forward pass:
# Generate "observed" data
ν = [2.5, 2.0]
α = 1.5
β = 0.2
λ = 0.1
τ = 0.3
# Create model and generate data
true_model = LCA(; ν, α, β, λ, τ)
observed_choices, observed_rts = rand(true_model, 100)
# Format the data
observed_data = Float32.([observed_choices observed_rts]')
# Recover parameters
recovered_params = NeuralEstimators.estimate(trained_estimator, [observed_data])
# Compare true and recovered parameters
println("True parameters: ", [ν[1], ν[2], α, β, λ, τ])
println("Recovered parameters: ", recovered_params)
Notes on Performance
Neural estimators are particularly effective for models with computationally intractable likelihoods like the LCA model. However, certain parameters (particularly β and λ) can be difficult to recover, even with advanced neural network architectures (Miletić et al., 2017). This is a property of the LCA model rather than a limitation of the estimation technique.
Additional details can be found in the NeuralEstimators.jl documentation.
References
Miletić, S., Turner, B. M., Forstmann, B. U., & van Maanen, L. (2017). Parameter recovery for the leaky competing accumulator model. Journal of Mathematical Psychology, 76, 25–50.
Radev, S. T., Schmitt, M., Schumacher, L., Elsemüller, L., Pratz, V., Schälte, Y., ... & Bürkner, P. C. (2023). BayesFlow: Amortized Bayesian workflows with neural networks. arXiv preprint arXiv:2306.16015.
Sainsbury-Dale, M., Zammit-Mangion, A., & Huser, R. (2024). Likelihood-free parameter estimation with neural Bayes estimators. The American Statistician, 78(1), 1–14.
Usher, M., & McClelland, J. L. (2001). The time course of perceptual choice: The leaky, competing accumulator model. Psychological Review, 108(3), 550–592. https://doi.org/10.1037/0033-295X.108.3.550
Zammit-Mangion, A., Sainsbury-Dale, M., & Huser, R. (2024). Neural methods for amortized inference. Annual Review of Statistics and Its Application, 12.