Skip to content

enzax.mcmc

Code for MCMC-based Bayesian inference on kinetic models.

run_nuts(logdensity_fn, rng_key, init_parameters, num_warmup, num_samples, **adapt_kwargs)

Run the default NUTS algorithm with blackjax.

Source code in `enzax/mcmc.py` (lines 34–63):
def run_nuts(
    logdensity_fn: Callable,
    rng_key: Array,
    init_parameters: PyTree,
    num_warmup: int,
    num_samples: int,
    **adapt_kwargs: Unpack[AdaptationKwargs],
):
    """Run the default NUTS algorithm with blackjax.

    NUTS is first tuned with ``blackjax.window_adaptation`` for ``num_warmup``
    steps. The tuned kernel is then run through ``_inference_loop`` for
    ``num_samples`` steps.

    Args:
        logdensity_fn: Log density function passed to both the window
            adaptation and the tuned NUTS kernel.
        rng_key: JAX PRNG key. It is split to derive one key for warmup and
            one for sampling.
        init_parameters: Starting position handed to ``warmup.run``.
        num_warmup: Number of window-adaptation steps.
        num_samples: Number of sampling steps requested from
            ``_inference_loop``.
        **adapt_kwargs: Extra keyword arguments forwarded to
            ``blackjax.window_adaptation``. A ``progress_bar`` entry here
            overrides the default of ``True``.

    Returns:
        The ``(states, info)`` pair produced by ``_inference_loop`` for the
        sampling phase. Warmup diagnostics are discarded.
    """
    # Merge rather than pass progress_bar separately: a caller-supplied
    # progress_bar would otherwise raise "got multiple values for keyword
    # argument 'progress_bar'".
    window_kwargs = {"progress_bar": True, **adapt_kwargs}
    warmup = blackjax.window_adaptation(
        blackjax.nuts,
        logdensity_fn,
        **window_kwargs,
    )
    # The two-stage split below is deliberately unchanged so that a given
    # rng_key keeps producing the same warmup and sampling keys.
    rng_key, warmup_key = jax.random.split(rng_key)
    (initial_state, tuned_parameters), _ = warmup.run(
        warmup_key,
        init_parameters,
        num_steps=num_warmup,  #  type: ignore
    )
    _, sample_key = jax.random.split(rng_key)
    nuts_kernel = blackjax.nuts(logdensity_fn, **tuned_parameters).step
    states, info = _inference_loop(
        sample_key,
        kernel=nuts_kernel,
        initial_state=initial_state,
        num_samples=num_samples,
    )
    return states, info

ind_prior_from_truth(truth, sd)

Get a set of independent priors centered at the true parameter values.

Note that the standard deviation currently has to be the same for all parameters.

Source code in `enzax/mcmc.py` (lines 66–73):
def ind_prior_from_truth(truth: Float[Array, " _"], sd: ScalarLike):
    """Get a set of independent priors centered at the true parameter values.

    Args:
        truth: True parameter values, used as the prior means. Anything
            ``jnp.asarray`` accepts (e.g. a list of floats) works.
        sd: Prior standard deviation. Annotated as a scalar, which gives every
            parameter the same standard deviation. Because it is passed to
            ``jnp.full``, an array broadcastable to ``truth.shape`` is also
            accepted and yields per-parameter standard deviations.

    Returns:
        A stacked array whose first row is ``truth`` (the means) and whose
        second row holds the standard deviations. For a 1-D ``truth`` of
        length n this has shape ``(2, n)``.
    """
    # Coerce first so plain sequences don't fail on ``.shape``.
    truth = jnp.asarray(truth)
    # jnp.full broadcasts a non-scalar fill value to the requested shape.
    sds = jnp.full(truth.shape, sd)
    return jnp.vstack((truth, sds))