devinterp.slt package

Submodules

devinterp.slt.callback module

class devinterp.slt.callback.SamplerCallback(device: device | str = 'cpu')[source]

Bases: object

Base class for creating callbacks used in devinterp.slt.sampler.sample(). Each instantiated callback has its __call__ invoked for every sample, and its finalize method (if it exists) invoked once at the end of sampling. Each callback method can access parameters in locals(), so there is no need to pass variables along explicitly.

Parameters:

device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.

Raises:

NotImplementedError – if __call__ and sample are not overridden.

Note

  • mps devices might not work for all callbacks.

finalize(*args, **kwargs)[source]

Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.

sample(*args, **kwargs)[source]

Does not get called automatically, but functions as an interface to easily access stats calculated by the callback.
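The division of labour between __call__, finalize and sample can be illustrated without devinterp at all. The following is a minimal sketch, assuming a hypothetical stand-in loop `toy_sample` and a hypothetical `MeanLossCallback`; neither is part of the library, and the real sample() exposes more variables to callbacks than the three shown here.

```python
import random


class MeanLossCallback:
    """Hypothetical callback: records losses per draw, averages them in finalize()."""

    def __init__(self, num_chains, num_draws):
        self.losses = [[0.0] * num_draws for _ in range(num_chains)]
        self.mean_loss = None

    def __call__(self, chain, draw, loss):
        # Invoked by the sampling loop once per draw.
        self.losses[chain][draw] = loss

    def finalize(self):
        # Invoked once after all chains finish; aggregates over chains and draws.
        flat = [loss for chain in self.losses for loss in chain]
        self.mean_loss = sum(flat) / len(flat)

    def sample(self):
        # Never invoked by the loop; the user calls it to read the results.
        return {"loss/mean": self.mean_loss}


def toy_sample(callbacks, num_chains=2, num_draws=5, seed=0):
    """Hypothetical stand-in for a sampling loop (not devinterp's sample())."""
    rng = random.Random(seed)
    for chain in range(num_chains):
        for draw in range(num_draws):
            loss = 1.0 + rng.random()  # placeholder for a real loss evaluation
            for callback in callbacks:
                callback(chain=chain, draw=draw, loss=loss)
    for callback in callbacks:
        if hasattr(callback, "finalize"):
            callback.finalize()


callback = MeanLossCallback(num_chains=2, num_draws=5)
toy_sample([callback])
stats = callback.sample()
```

The results live on the callback object after the loop returns, which mirrors the pattern of reading observables through callback_object.sample().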

devinterp.slt.cov module

class devinterp.slt.cov.BetweenLayerCovarianceAccumulator(model, pairs: Dict[str, Tuple[str, str]], device: device | str = 'cpu', num_evals: int = 3, **accessors: Callable[[Module], Tensor])[source]

Bases: object

A CovarianceAccumulator to compute covariance between arbitrary layers. For use with devinterp.slt.sampler.sample().

Parameters:
  • model (torch.nn.Module) – The model to compute covariances on.

  • pairs (Dict[str, Tuple[str, str]]) – Named pairs of layers to compute covariances on.

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • num_evals (int) – number of eigenvectors / eigenvalues to compute. Default is 3

  • accessors (Callable[[nn.Module], torch.Tensor]) – Functions to access attention head weights.

sample()[source]
Returns:

A dict with named_pairs keys, with corresponding values {"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)).

class devinterp.slt.cov.CovarianceAccumulator(num_weights: int, accessors: List[Callable[[Module], Tensor]], device: device | str = 'cpu', num_evals: int = 3)[source]

Bases: SamplerCallback

A callback to iteratively compute and store the covariance matrix of model weights. For passing along to devinterp.slt.sampler.sample().

Parameters:
  • num_weights (int) – Total number of weights.

  • accessors (List[Callable[[nn.Module], torch.Tensor]]) – Functions to access model weights.

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • num_evals (int, optional) – Number of eigenvalues to compute. Default is 3

sample()[source]
Returns:

A dict {"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...))
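One common way to build a covariance matrix iteratively is to accumulate first and second moments as samples arrive and combine them at the end. The sketch below shows that idea in plain Python; the class name `RunningCovariance` is hypothetical, and whether the library uses exactly this scheme (or this normalisation) is an assumption.

```python
class RunningCovariance:
    """Hypothetical accumulator: covariance from running first and second moments."""

    def __init__(self, num_weights):
        self.num_samples = 0
        self.first_moment = [0.0] * num_weights
        self.second_moment = [[0.0] * num_weights for _ in range(num_weights)]

    def update(self, weights):
        # Add one sampled weight vector to the running sums.
        self.num_samples += 1
        for i, w_i in enumerate(weights):
            self.first_moment[i] += w_i
            for j, w_j in enumerate(weights):
                self.second_moment[i][j] += w_i * w_j

    def covariance(self):
        # Population covariance: E[w w^T] - E[w] E[w]^T.
        n = self.num_samples
        mean = [s / n for s in self.first_moment]
        size = len(mean)
        return [
            [self.second_moment[i][j] / n - mean[i] * mean[j] for j in range(size)]
            for i in range(size)
        ]


accumulator = RunningCovariance(num_weights=2)
for sampled_weights in [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]:
    accumulator.update(sampled_weights)
cov = accumulator.covariance()
```

Only the running sums are stored, so memory does not grow with the number of draws; the eigenvalues and eigenvectors would then be computed from the final matrix.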

class devinterp.slt.cov.WithinHeadCovarianceAccumulator(num_heads: int, num_weights_per_head: int, accessors: List[Callable[[Module], Tuple[Tensor, ...]]], device: device | str = 'cpu', num_evals: int = 3)[source]

Bases: object

A CovarianceAccumulator to compute covariance within attention heads. For use with devinterp.slt.sampler.sample().

Parameters:
  • num_heads (int) – The number of attention heads.

  • num_weights_per_head (int) – The number of weights per attention head.

  • accessors (List[Callable[[nn.Module], Tuple[torch.Tensor, ...]]]) – Functions to access attention head weights.

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • num_evals (int, optional) – number of eigenvectors / eigenvalues to compute. Default is 3

sample()[source]
Returns:

A dict {"evals": array_of_eigenvalues_of_cov_matrix_layer_idx_head_idx, "evecs": array_of_eigenvectors_of_cov_matrix_layer_idx_head_idx, "matrix": array_of_cov_matrices_layer_idx_head_idx}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)).

devinterp.slt.gradient module

class devinterp.slt.gradient.GradientDistribution(num_chains: int, num_draws: int, min_bins: int = 20, param_names: List[str] | None = None, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for plotting the distribution of gradients as a function of draws. The histogram bins are adjusted automatically as more draws are taken. For use with devinterp.slt.sampler.sample().

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • min_bins (int, optional) – Minimum number of bins for histogram approximation. Default is 20

  • param_names (List[str], optional) – List of parameter names to track. If None, all parameters are tracked. Default is None

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.

Raises:

ValueError – If gradients are not computed before calling this callback.

plot(param_name: str, color='blue', plot_zero=True, chain: int | None = None)[source]

Plots the gradient distribution for a specific parameter.

Parameters:
  • param_name (str) – The name of the parameter to plot gradients for.

  • color (str, optional) – The color to plot gradient bins in. Default is blue

  • plot_zero (bool, optional) – Whether to plot the line through y=0. Default is True

  • chain (int, optional) – Index of the chain whose gradients are plotted. Default is None

Returns:

None, but shows the density of the gradient bins over sampling steps.

sample()[source]
Returns:

A dict {"gradient/distributions": grad_dists}. (Only after running devinterp.slt.sampler.sample(..., [gradient_dist_instance], ...))

devinterp.slt.llc module

class devinterp.slt.llc.LLCEstimator(num_chains: int, num_draws: int, temperature: float, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for estimating the Local Learning Coefficient (LLC) in a rolling fashion during a sampling process. It calculates the LLC based on the average loss across draws for each chain:

\[LLC = \textrm{T} * (\textrm{avg\_loss} - \textrm{init\_loss})\]

For use with devinterp.slt.sampler.sample().

Note

init_loss gets set inside devinterp.slt.sampler.sample(). It can be passed as an argument to that function; if it is not passed, it will be the average loss of the supplied model over num_chains batches.

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • temperature (float) – Temperature (default: 1., set by sample() to utils.optimal_temperature(dataloader) = batch_size / np.log(batch_size))

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.

sample()[source]
Returns:

A dict {"llc/mean": llc_mean, "llc/std": llc_std, "llc-chain/{i}": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)).
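The LLC formula above is simple enough to check by hand. The sketch below applies it to made-up loss traces, assuming the temperature expression is read as batch_size / log(batch_size); the helper names and the numbers are hypothetical, and the choice of sample standard deviation across chains is an assumption.

```python
import math
import statistics


def toy_optimal_temperature(batch_size):
    # Assumed reading of the temperature expression: batch_size / log(batch_size).
    return batch_size / math.log(batch_size)


def llc_for_chain(loss_trace, init_loss, temperature):
    # LLC = T * (avg_loss - init_loss), with avg_loss taken over the chain's draws.
    avg_loss = sum(loss_trace) / len(loss_trace)
    return temperature * (avg_loss - init_loss)


# Made-up losses: two chains with three draws each.
loss_traces = [[0.52, 0.55, 0.53], [0.50, 0.56, 0.56]]
init_loss = 0.50
temperature = toy_optimal_temperature(batch_size=512)

llcs = [llc_for_chain(trace, init_loss, temperature) for trace in loss_traces]
result = {
    "llc/mean": statistics.mean(llcs),
    "llc/std": statistics.stdev(llcs),  # sample std across chains (assumption)
}
```

A sampled loss that sits above init_loss on average gives a positive estimate; a negative estimate usually signals that the chain has drifted to a lower-loss region than the starting point.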

class devinterp.slt.llc.OnlineLLCEstimator(num_chains: int, num_draws: int, temperature: float, device='cpu')[source]

Bases: SamplerCallback

Callback for estimating the Local Learning Coefficient (LLC) in an online fashion during a sampling process. It calculates LLCs using the same formula as devinterp.slt.llc.LLCEstimator(), but continuously and including means and std across draws (as opposed to just across chains). For use with devinterp.slt.sampler.sample().

Note

init_loss gets set inside devinterp.slt.sampler.sample(). It can be passed as an argument to that function; if it is not passed, it will be the average loss of the supplied model over num_chains batches.

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • temperature (float) – Temperature (default: 1., set by sample() to utils.optimal_temperature(dataloader) = batch_size / np.log(batch_size))

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

sample()[source]
Returns:

A dict {"llc/means": llc_means, "llc/stds": llc_stds, "llc/trace": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)).
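An online estimate can be produced by updating a running mean of the loss after every draw, so an LLC value is available at each step rather than only at the end. The sketch below shows one way to do this for a single chain; the function name is hypothetical, and whether the library keeps a running mean or stores per-draw values in some other form is an assumption.

```python
def online_llc_trace(loss_trace, init_loss, temperature):
    """Hypothetical helper: LLC estimate after each draw of a single chain."""
    trace = []
    running_mean = 0.0
    for draw_count, loss in enumerate(loss_trace, start=1):
        # Incremental mean update; avoids re-summing all previous losses.
        running_mean += (loss - running_mean) / draw_count
        trace.append(temperature * (running_mean - init_loss))
    return trace


llc_trace = online_llc_trace([0.6, 0.8, 0.7], init_loss=0.5, temperature=10.0)
```

The final entry of the trace agrees with the non-online formula applied to the whole chain, so the two estimators differ only in what they expose along the way.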

devinterp.slt.loss module

class devinterp.slt.loss.OnlineLossStatistics(base_callback: OnlineLLCEstimator)[source]

Bases: SamplerCallback

Derivative callback that computes various loss statistics for OnlineLLCEstimator(). Must be called after the base OnlineLLCEstimator() has been called at each draw.

See the diagnostics notebook colab5 for examples on how to use this to diagnose your sample health.

Parameters:

base_callback (OnlineLLCEstimator()) – Base callback that computes original loss metric.

Note

  • Requires losses to be computed first, so call using e.g. devinterp.slt.sampler.sample(..., [llc_estimator_instance, ..., online_loss_stats_instance], ...)

loss_hist_by_draw(draw: int = 0, bins: int = 10)[source]

Plots a histogram of chain losses for a given draw index.

Parameters:
  • draw (int, optional) – Draw index to plot histogram for. Default is 0

  • bins (int, optional) – number of histogram bins. Default is 10

sample()[source]
Returns:

A dict {"loss/percent_neg_steps": percent_neg_steps, "loss/percent_mean_neg_steps": percent_mean_neg_steps, "loss/percent_thresholded_neg_steps": percent_thresholded_neg_steps, "loss/z_scores": z_scores}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance, online_loss_stats_instance], ...))

devinterp.slt.mala module

class devinterp.slt.mala.MalaAcceptanceRate(num_chains: int, num_draws: int, temperature: float, learning_rate: float, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for computing MALA acceptance rate.

num_draws

Number of samples to draw. (should be identical to param passed to sample())

Type:

int

num_chains

Number of chains to run. (should be identical to param passed to sample())

Type:

int

temperature

Temperature used to calculate the LLC.

Type:

float

learning_rate

Learning rate of the model.

Type:

float

device

Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.

Type:

Union[torch.device, str]

devinterp.slt.mala.mala_acceptance_probability(prev_params: Tensor | List[Tensor], prev_grads: Tensor | List[Tensor], prev_loss: Tensor, current_params: Tensor | List[Tensor], current_grads: Tensor | List[Tensor], current_loss: Tensor, learning_rate: float)[source]

Calculate the acceptance probability for a MALA transition. Parameters and gradients can either all be given as a tensor (all of the same shape) or all as lists of tensors (eg the parameters of a Module).

Parameters:
  • prev_params (Tensor | List[Tensor]) – The previous point in parameter space.

  • prev_grads (Tensor | List[Tensor]) – Gradient of the previous point in parameter space.

  • prev_loss (Tensor) – Loss of the previous point in parameter space.

  • current_params (Tensor | List[Tensor]) – The current point in parameter space.

  • current_grads (Tensor | List[Tensor]) – Gradient of the current point in parameter space.

  • current_loss (Tensor) – Loss of the current point in parameter space.

  • learning_rate (float) – Learning rate of the model.

Returns:

The acceptance probability for the proposed transition, as a float.
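For orientation, the textbook MALA acceptance probability for a target density proportional to exp(-U(x)) and a proposal N(x - step_size * grad U(x), 2 * step_size * I) can be written in a few lines. The sketch below implements that textbook form on plain lists of floats; the function names are hypothetical, and how the library maps learning_rate and the loss onto step_size and U (including any scaling) is an assumption not confirmed here.

```python
import math


def log_proposal(to_x, from_x, from_grad, step_size):
    """log q(to_x | from_x), up to a constant, for the Langevin proposal."""
    squared_distance = sum(
        (t - f + step_size * g) ** 2 for t, f, g in zip(to_x, from_x, from_grad)
    )
    return -squared_distance / (4.0 * step_size)


def toy_mala_acceptance(prev_x, prev_grad, prev_u, cur_x, cur_grad, cur_u, step_size):
    """Textbook MALA acceptance probability (hypothetical helper)."""
    log_forward = log_proposal(cur_x, prev_x, prev_grad, step_size)
    log_backward = log_proposal(prev_x, cur_x, cur_grad, step_size)
    log_ratio = (prev_u - cur_u) + (log_backward - log_forward)
    # exp(min(0, r)) equals min(1, exp(r)) and cannot overflow.
    return math.exp(min(0.0, log_ratio))


def potential(x):
    # Toy potential U(x) = 0.5 * |x|^2, whose gradient is x itself.
    return 0.5 * sum(v * v for v in x)


prev_x, cur_x = [1.0], [2.0]
probability = toy_mala_acceptance(
    prev_x, prev_x, potential(prev_x), cur_x, cur_x, potential(cur_x), step_size=0.1
)
stay_probability = toy_mala_acceptance(
    prev_x, prev_x, potential(prev_x), prev_x, prev_x, potential(prev_x), step_size=0.1
)
```

Working in log space and clamping before exponentiating keeps the computation stable when the proposed point has a much higher potential than the previous one.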

devinterp.slt.norms module

class devinterp.slt.norms.GradientNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for computing the norm of the gradients of the optimizer / sampler.

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

sample()[source]
Returns:

A dict {"gradient_norm/trace": gradient_norms}. (Only after running devinterp.slt.sampler.sample(..., [grad_norm_instance], ...))

class devinterp.slt.norms.NoiseNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for computing the norm of the noise added in the optimizer / sampler.

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

sample()[source]
Returns:

A dict {"noise_norm/trace": noise_norms}. (Only after running devinterp.slt.sampler.sample(..., [noise_norm_instance], ...))

class devinterp.slt.norms.WeightNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for computing the norm of the weights over the sampling process.

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

sample()[source]
Returns:

A dict {"weight_norm/trace": weight_norms}. (Only after running devinterp.slt.sampler.sample(..., [weight_norm_instance], ...))
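The p_norm argument shared by the three norm callbacks selects the order of the vector norm. The sketch below computes such a norm over several parameter lists treated as one long vector; the helper names are hypothetical, and whether the library takes one norm over all parameters or combines per-parameter norms is an assumption.

```python
def p_norm(values, p=2):
    # (sum_i |v_i|^p)^(1/p); p=2 gives the Euclidean norm.
    return sum(abs(v) ** p for v in values) ** (1.0 / p)


def total_norm(parameter_lists, p=2):
    """Hypothetical helper: norm of all entries, flattened into one vector."""
    flat = [v for parameter in parameter_lists for v in parameter]
    return p_norm(flat, p)


toy_weights = [[3.0, 0.0], [-4.0]]
euclidean = total_norm(toy_weights, p=2)
manhattan = total_norm(toy_weights, p=1)
```

Recording this value at every draw for every chain yields a trace with one row per chain, which is the shape the "…/trace" entries suggest.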

devinterp.slt.sampler module

devinterp.slt.sampler.sample(model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, callbacks: ~typing.List[~devinterp.slt.callback.SamplerCallback], evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], ~devinterp.utils.Outputs | ~typing.Dict[str, ~torch.Tensor] | ~typing.Tuple[~torch.Tensor, ...] | ~torch.Tensor] | None = None, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict[str, float | ~typing.Literal['adaptive']] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, init_loss: float | None = None, grad_accum_steps: int = 1, cores: int = 1, seed: ~typing.List[int] | int | None = None, device: ~torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: ~typing.Dict[str, ~typing.List[bool]] | None = None)[source]

Sample model weights using a given sampling_method, supporting multiple chains/cores, and calculate the observables (loss, llc, etc.) for each callback passed along. The update, finalize and sample methods of each SamplerCallback() are called during sampling, after sampling, and at sampler_callback_object.sample() respectively.

After calling this function, the stats of interest live in the callback object.

Parameters:
  • model (torch.nn.Module) – The neural network model.

  • loader (DataLoader) – DataLoader for input data.

  • evaluate (EvaluateFn) – Maps a model and batch of data to an object with a loss attribute.

  • callbacks (list[SamplerCallback]) – list of callbacks, each of type SamplerCallback

  • sampling_method (torch.optim.Optimizer, optional) – Sampling method to use (a PyTorch optimizer under the hood). Default is SGLD

  • optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class)

  • num_draws (int, optional) – Number of samples to draw. Default is 100

  • num_chains (int, optional) – Number of chains to run. Default is 10

  • num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0

  • num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1

  • init_loss (float, optional) – Initial loss for use in LLCEstimator and OnlineLLCEstimator

  • cores (int, optional) – Number of cores for parallel execution. Default is 1

  • seed (int | List[int], optional) – Random seed(s) for sampling. Each chain gets a different (deterministic) seed if this is passed. Default is None

  • device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • verbose (bool, optional) – whether to print sample chain progress. Default is True

Raises:
  • ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before base callbacks (e.g. OnlineLLCEstimator())

  • Warning – if num_burnin_steps < num_draws

  • Warning – if num_draws > len(loader)

  • Warning – if using seeded runs

Returns:

None (access LLCs or other observables through callback_object.sample())
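The interaction between num_draws, num_burnin_steps, num_steps_bw_draws and seed is easier to see with concrete numbers. The sketch below lists which sampler steps would be recorded as draws and derives one seed per chain; both helpers are hypothetical, and the exact step at which a draw is recorded, as well as the seed + chain_index rule, are assumptions rather than the library's documented behaviour.

```python
def draw_steps(num_draws, num_burnin_steps=0, num_steps_bw_draws=1):
    """Hypothetical helper: indices of the sampler steps recorded as draws."""
    total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
    return [
        step
        for step in range(num_burnin_steps, total_steps)
        if (step - num_burnin_steps) % num_steps_bw_draws == 0
    ]


def chain_seeds(seed, num_chains):
    """Hypothetical helper: one deterministic seed per chain."""
    if seed is None:
        return [None] * num_chains
    if isinstance(seed, int):
        # Assumed rule: offset the base seed by the chain index.
        return [seed + chain for chain in range(num_chains)]
    if len(seed) != num_chains:
        raise ValueError("need exactly one seed per chain")
    return list(seed)


steps = draw_steps(num_draws=3, num_burnin_steps=4, num_steps_bw_draws=2)
seeds = chain_seeds(seed=100, num_chains=3)
```

Under these assumptions each chain takes num_burnin_steps + num_draws * num_steps_bw_draws optimizer steps in total, which is the quantity to budget for when choosing the arguments.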

devinterp.slt.trace module

class devinterp.slt.trace.OnlineTraceStatistics(base_callback: SamplerCallback, attribute: str, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Derivative callback that computes mean/std statistics of a specified trace online. Must be called after the base callback has been called at each draw.

See the diagnostics notebook colab6 for examples on how to use this to diagnose your sample health.

Parameters:
  • base_callback (SamplerCallback()) – Base callback that computes some metric.

  • attribute (str) – Name of the attribute to compute mean/std statistics of.

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

Raises:

ValueError – if the underlying trace does not have the requested attribute, num_chains or num_draws.

Note

  • Requires base trace stats to be computed first, so call using e.g. devinterp.slt.sampler.sample(..., [weight_norm_instance, online_trace_stats_instance], ...)

sample()[source]
Returns:

A dict {"{self.attribute}/chain/mean": mean_attribute_by_chain, "{self.attribute}/chain/std": std_attribute_by_chain, "{self.attribute}/draw/mean": mean_attribute_by_draw, "{self.attribute}/draw/std": std_attribute_by_draw}. (Only after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)).

sample_at_draw(draw=-1)[source]
Parameters:

draw (int, optional) – draw index to return stats at. Default is -1

Returns:

A dict {"{self.attribute}/chain/mean": mean_attribute_of_draw_by_chain, "{self.attribute}/chain/std": std_attribute_of_draw_by_chain, "{self.attribute}/draw/mean": mean_attribute_of_draw, "{self.attribute}/draw/std": std_attribute_of_draw}. (Only after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)).
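The four keys suggest statistics taken along the two axes of a chains-by-draws trace. The sketch below computes them for a small made-up trace, assuming "chain" means one value per chain (reduced over draws) and "draw" means one value per draw (reduced over chains); that reading, the helper name, and the use of the population standard deviation are all assumptions.

```python
import statistics


def trace_statistics(trace, attribute):
    """Hypothetical helper: trace[chain][draw] -> per-chain and per-draw mean/std."""
    by_draw = list(zip(*trace))  # transpose: one tuple of chain values per draw
    return {
        f"{attribute}/chain/mean": [statistics.mean(chain) for chain in trace],
        f"{attribute}/chain/std": [statistics.pstdev(chain) for chain in trace],
        f"{attribute}/draw/mean": [statistics.mean(draw) for draw in by_draw],
        f"{attribute}/draw/std": [statistics.pstdev(draw) for draw in by_draw],
    }


# Made-up trace: two chains, three draws each.
toy_trace = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
trace_stats = trace_statistics(toy_trace, "weight_norm")
```

A large spread across chains at a fixed draw, compared with the spread within a chain, is the kind of signal such diagnostics are typically used to spot.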

devinterp.slt.wbic module

class devinterp.slt.wbic.OnlineWBICEstimator(num_chains: int, num_draws: int, n: int, device: device | str = 'cpu')[source]

Bases: SamplerCallback

Callback for estimating the Widely Applicable Bayesian Information Criterion (WBIC) in an online fashion. The WBIC used here is just \(n * (\textrm{average sampled loss})\). (Watanabe, 2013)

Parameters:
  • num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)

  • num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)

  • n (int) – Number of samples used to calculate the WBIC.

  • device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

sample()[source]
Returns:

A dict {"wbic/means": wbic_means, "wbic/stds": wbic_stds, "wbic/trace": wbic_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [wbic_estimator_instance], ...)).
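Since the WBIC used here is n times the average sampled loss, an online version only needs a running mean of the loss. The sketch below shows this for a single chain with made-up losses; the function name is hypothetical and the per-draw running-mean form is an assumption about how "online" is realised.

```python
def online_wbic_trace(loss_trace, n):
    """Hypothetical helper: n * (running mean of sampled loss) after each draw."""
    running_mean = 0.0
    trace = []
    for draw_count, loss in enumerate(loss_trace, start=1):
        running_mean += (loss - running_mean) / draw_count
        trace.append(n * running_mean)
    return trace


wbic_trace = online_wbic_trace([0.5, 0.7, 0.6], n=1000)
```

Means and standard deviations across chains would then be taken entry by entry over several such traces.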

Module contents