A Simple Turing Model

It is possible to use Turing.jl to perform Bayesian parameter estimation on models defined in SequentialSamplingModels.jl. Below, we show how to estimate the parameters of the Linear Ballistic Accumulator (LBA) from simulated data.

Note that you can easily swap the LBA in this example for other SSM distributions simply by changing the distribution name and its parameters.
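For example, a swap to the Racing Diffusion Model (RDM) might look like the sketch below. The parameter names here are an assumption based on the LBA's similar parameterization; consult the SequentialSamplingModels.jl documentation for the exact constructor.

# Hypothetical swap: RDM in place of LBA (verify parameter names in the package docs)
dist = RDM(ν = [3.0, 2.0], k = 0.2, A = 0.8, τ = 0.3)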

Load Packages

The first step is to load the required packages; you will need to install each one in your local environment before running the code. We also set a random seed so that the results are reproducible.

using Turing
using SequentialSamplingModels
using Random
using LinearAlgebra
using StatsPlots

Random.seed!(45461)
Random.TaskLocalRNG()

Generate Data

We will use the LBA distribution to simulate data (100 trials) with fixed parameters, which we will then attempt to recover from the data alone using Bayesian modeling.

# Generate some data with known parameters
dist = LBA(ν = [3.0, 2.0], A = 0.8, k = 0.2, τ = 0.3)
data = rand(dist, 100)
(choice = [2, 1, 1, 1, 1, 1, 1, 1, 1, 2  …  2, 1, 1, 1, 1, 1, 2, 2, 1, 1], rt = [0.6290010134562214, 0.507511686077583, 0.4077035786951601, 0.5845853642627061, 0.5518606313501111, 0.357077395905047, 0.4430796909532061, 0.35525992231554643, 0.4858235685699115, 0.37197001383636874  …  0.8595469187781852, 0.5705118298266227, 0.5348699092588236, 0.4397605741138838, 0.48381613889237696, 0.4576847624125857, 0.4835375184039913, 0.5544902383424269, 0.583537266351565, 0.36048725612183097])

The rand() function samples random draws from the distribution and returns them as a NamedTuple of two vectors (one for choice and one for rt). The individual vectors can be accessed by name as data.choice and data.rt.
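For instance, we can inspect the simulated data by field name (exact values depend on the seed):

# Access the components of the NamedTuple by name
data.choice[1:5]   # first five choices
minimum(data.rt)   # fastest response time; used below to bound τ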

Specify Turing Model

The code snippet below defines a model in Turing. The model function accepts a NamedTuple containing a vector of choices and a vector of reaction times. The sampling statements define the prior distribution for each parameter. The non-decision time parameter $\tau$ must be bounded above by the minimum reaction time, min_rt. The last sampling statement defines the likelihood of the data given the sampled parameter values.

# Specify LBA model
@model function model_lba(data; min_rt = minimum(data.rt))
    # Priors
    ν ~ MvNormal(zeros(2), I * 2)
    A ~ truncated(Normal(.8, .4), 0.0, Inf)
    k ~ truncated(Normal(.2, .2), 0.0, Inf)
    τ ~ Uniform(0.0, min_rt)

    # Likelihood
    data ~ LBA(; ν, A, k, τ)
end
model_lba (generic function with 2 methods)

Estimate the Parameters

Finally, we perform parameter estimation with sample(), which takes the model and details of the sampling algorithm:

  1. model_lba(data): the Turing model with the data passed in
  2. NUTS(1000, .85): the No-U-Turn Sampler, configured with 1000 warmup samples and a target acceptance rate of .85
  3. MCMCThreads(): instructs Turing to run each chain on a separate thread
  4. 1000: the number of iterations (samples per chain) after warmup
  5. 4: the number of chains
# Estimate parameters
chain = sample(model_lba(data), NUTS(1000, .85), MCMCThreads(), 1000, 4)
Chains MCMC chain (1000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 35.71 seconds
Compute duration  = 32.75 seconds
parameters        = ν[1], ν[2], A, k, τ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        ν[1]    2.8325    0.4291    0.0130   1138.3596   1090.6557    1.0034   ⋯
        ν[2]    1.6565    0.3851    0.0122   1021.2542   1187.2001    1.0037   ⋯
           A    0.7322    0.1733    0.0052   1106.5762   1202.1972    1.0043   ⋯
           k    0.2433    0.1223    0.0034   1127.2799    804.8034    1.0020   ⋯
           τ    0.2801    0.0293    0.0008   1239.0902    803.8147    1.0014   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        ν[1]    2.0705    2.5294    2.8066    3.1022    3.7558
        ν[2]    0.9511    1.3933    1.6463    1.9074    2.4498
           A    0.4135    0.6150    0.7174    0.8420    1.1242
           k    0.0441    0.1538    0.2284    0.3191    0.5179
           τ    0.2170    0.2606    0.2831    0.3016    0.3294

Posterior Summary

We can compute a description of the posterior distributions.

# Summarize posteriors
summarystats(chain)
Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        ν[1]    2.8325    0.4291    0.0130   1138.3596   1090.6557    1.0034   ⋯
        ν[2]    1.6565    0.3851    0.0122   1021.2542   1187.2001    1.0037   ⋯
           A    0.7322    0.1733    0.0052   1106.5762   1202.1972    1.0043   ⋯
           k    0.2433    0.1223    0.0034   1127.2799    804.8034    1.0020   ⋯
           τ    0.2801    0.0293    0.0008   1239.0902    803.8147    1.0014   ⋯
                                                                1 column omitted

As you can see from the mean values of the posterior distributions, the original parameters (ν = [3.0, 2.0], A = 0.8, k = 0.2, τ = 0.3) are recovered reasonably well from the data; recovery accuracy would improve with more trials.
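A quick programmatic check of recovery (a sketch; indexing the chain by parameter symbol is standard MCMCChains behavior, and the values shown are from the run above):

# Posterior means next to the data-generating values
mean(chain[:A])   # ≈ 0.73 (true value 0.8)
mean(chain[:k])   # ≈ 0.24 (true value 0.2)
mean(chain[:τ])   # ≈ 0.28 (true value 0.3)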

Evaluation

It is important to verify that the chains converged. Here, all parameters satisfy $\hat{r} \leq 1.05$, and the trace plots below look like "hairy caterpillars", which indicates the chains did not get stuck. As expected, the posterior distributions are centered near the data-generating parameter values.

plot(chain)
[Plot: trace plots and posterior densities for ν[1], ν[2], A, k, and τ]
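Beyond visual inspection, numerical diagnostics can be computed directly from the chain object. A minimal sketch using the Gelman-Rubin diagnostic from MCMCChains (which Turing re-exports; check your installed version's API):

# Potential scale reduction factor for each parameter;
# values near 1.0 indicate agreement across the four chains
gelmandiag(chain)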