devinterp.slt package
Submodules
devinterp.slt.callback module
- class devinterp.slt.callback.SamplerCallback(device: device | str = 'cpu')[source]
Bases:
object
Base class for creating callbacks used in devinterp.slt.sampler.sample(). Each instantiated callback has its __call__ method called at every draw, and its finalize method called at the end of sampling (if it exists). Each callback method can access the sampler's local variables (locals()), so there is no need to pass variables along explicitly.
- Parameters:
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.
- Raises:
NotImplementedError – if __call__ and sample are not overridden.
Note
mps devices might not work with all callbacks.
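A minimal sketch of the callback protocol described above, written without importing devinterp. `SamplerCallbackSketch`, `LossTraceCallback` and `toy_driver` are hypothetical stand-ins, and the assumption that the sampler forwards its loop variables as keyword arguments (`callback(**locals())`) with names such as `chain`, `draw` and `loss` is inferred from the description, not taken from the library source.

```python
from typing import Any, Dict, List


class SamplerCallbackSketch:
    """Hypothetical stand-in for devinterp.slt.callback.SamplerCallback."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def __call__(self, **kwargs: Any) -> None:
        raise NotImplementedError("subclasses must override __call__")

    def sample(self) -> Dict[str, Any]:
        raise NotImplementedError("subclasses must override sample")


class LossTraceCallback(SamplerCallbackSketch):
    """Toy callback that stores the loss at every draw of every chain."""

    def __init__(self, num_chains: int, num_draws: int, device: str = "cpu"):
        super().__init__(device)
        self.losses: List[List[float]] = [[0.0] * num_draws for _ in range(num_chains)]
        self.mean_loss = float("nan")

    def __call__(self, chain: int, draw: int, loss: float, **unused: Any) -> None:
        # Pick out the locals this callback needs; everything else lands in `unused`.
        self.losses[chain][draw] = loss

    def finalize(self) -> None:
        # Called once after sampling, only because this callback defines it.
        values = [x for row in self.losses for x in row]
        self.mean_loss = sum(values) / len(values)

    def sample(self) -> Dict[str, Any]:
        return {"loss/trace": self.losses, "loss/mean": self.mean_loss}


def toy_driver(callbacks: List[SamplerCallbackSketch], num_chains: int, num_draws: int) -> None:
    """Hypothetical sampling loop that forwards its local variables to every callback."""
    for chain in range(num_chains):
        for draw in range(num_draws):
            loss = 1.0 / (1 + chain + draw)  # fake loss value for illustration
            for callback in callbacks:
                callback(**locals())
    for callback in callbacks:
        if hasattr(callback, "finalize"):
            callback.finalize()
```

Accepting `**unused` is what lets a callback ignore the locals it does not need, so adding a new variable to the loop does not break existing callbacks.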
devinterp.slt.cov module
- class devinterp.slt.cov.BetweenLayerCovarianceAccumulator(model, pairs: Dict[str, Tuple[str, str]], device: device | str = 'cpu', num_evals: int = 3, **accessors: Callable[[Module], Tensor])[source]
Bases:
object
A CovarianceAccumulator to compute covariance between arbitrary layers. For use with devinterp.slt.sampler.sample().
- Parameters:
model (torch.nn.Module) – The model to compute covariances on.
pairs (Dict[str, Tuple[str, str]]) – Named pairs of layers to compute covariances on.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int, optional) – Number of eigenvectors / eigenvalues to compute. Default is 3
accessors (Callable[[nn.Module], torch.Tensor]) – Named functions to access layer weights; the names are the ones referenced in pairs.
- class devinterp.slt.cov.CovarianceAccumulator(num_weights: int, accessors: List[Callable[[Module], Tensor]], device: device | str = 'cpu', num_evals: int = 3)[source]
Bases:
SamplerCallback
A callback to iteratively compute and store the covariance matrix of model weights. For passing along to devinterp.slt.sampler.sample().
- Parameters:
num_weights (int) – Total number of weights.
accessors (List[Callable[[nn.Module], torch.Tensor]]) – Functions to access model weights.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int, optional) – Number of eigenvalues to compute. Default is 3
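Iteratively accumulating a covariance matrix can be sketched without torch: keep running sums of the weights and of their outer products, and only form the covariance when it is needed. `RunningCovariance` is a hypothetical helper, not the library class; whether the library normalises by n or n - 1 is not stated here, so the sketch uses the population (1/n) form.

```python
from typing import List, Sequence


class RunningCovariance:
    """Hypothetical accumulator for the covariance of flattened weight vectors."""

    def __init__(self, num_weights: int):
        self.n = 0
        self.first = [0.0] * num_weights  # running sum of w
        self.second = [[0.0] * num_weights for _ in range(num_weights)]  # running sum of w w^T

    def update(self, weights: Sequence[float]) -> None:
        self.n += 1
        for i, wi in enumerate(weights):
            self.first[i] += wi
            row = self.second[i]
            for j, wj in enumerate(weights):
                row[j] += wi * wj

    def covariance(self) -> List[List[float]]:
        # Cov = E[w w^T] - E[w] E[w]^T (population covariance, divides by n).
        mean = [s / self.n for s in self.first]
        size = len(mean)
        return [
            [self.second[i][j] / self.n - mean[i] * mean[j] for j in range(size)]
            for i in range(size)
        ]
```

This one-pass form is simple but can lose precision when the weights' mean is large compared with their spread; a Welford-style update is the more numerically stable alternative.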
- class devinterp.slt.cov.WithinHeadCovarianceAccumulator(num_heads: int, num_weights_per_head: int, accessors: List[Callable[[Module], Tuple[Tensor, ...]]], device: device | str = 'cpu', num_evals: int = 3)[source]
Bases:
object
A CovarianceAccumulator to compute covariance within attention heads. For use with devinterp.slt.sampler.sample().
- Parameters:
num_heads (int) – The number of attention heads.
num_weights_per_head (int) – The number of weights per attention head.
accessors (List[Callable[[nn.Module], Tuple[torch.Tensor, ...]]]) – Functions to access attention head weights.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int, optional) – Number of eigenvectors / eigenvalues to compute. Default is 3
- sample()[source]
- Returns:
A dict
{"evals": array_of_eigenvalues_of_cov_matrix_layer_idx_head_idx, "evecs": array_of_eigenvectors_of_cov_matrix_layer_idx_head_idx, "matrix": array_of_cov_matrices_layer_idx_head_idx}
(Only available after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...).)
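sample() returns the eigenvalues and eigenvectors of each covariance matrix, with num_evals controlling how many are kept. The sketch below shows what those quantities are for a small symmetric matrix, using power iteration with deflation so that it needs no dependencies. In practice a symmetric eigensolver such as numpy.linalg.eigh or torch.linalg.eigh is the usual tool, and which method the library uses is not stated here. `top_eigenpairs` is a hypothetical helper.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def _matvec(a: Matrix, v: List[float]) -> List[float]:
    return [sum(aij * vj for aij, vj in zip(row, v)) for row in a]


def top_eigenpairs(cov: Matrix, num_evals: int = 3, iters: int = 1000,
                   seed: int = 0) -> Tuple[List[float], List[List[float]]]:
    """Largest eigenvalues/eigenvectors of a symmetric PSD matrix.

    Power iteration finds the dominant eigenpair; deflation (A <- A - lam v v^T)
    removes it so the next iteration finds the next-largest one.
    """
    n = len(cov)
    a = [row[:] for row in cov]  # copy so deflation does not modify the input
    rng = random.Random(seed)
    evals: List[float] = []
    evecs: List[List[float]] = []
    for _ in range(min(num_evals, n)):
        v = [rng.uniform(-1.0, 1.0) for _ in range(n)]
        for _step in range(iters):
            w = _matvec(a, v)
            norm = math.sqrt(sum(x * x for x in w))
            if norm == 0.0:  # remaining matrix is (numerically) zero
                break
            v = [x / norm for x in w]
        norm_v = math.sqrt(sum(x * x for x in v))
        v = [x / norm_v for x in v]
        lam = sum(vi * wi for vi, wi in zip(v, _matvec(a, v)))  # Rayleigh quotient
        evals.append(lam)
        evecs.append(v)
        a = [[a[i][j] - lam * v[i] * v[j] for j in range(n)] for i in range(n)]
    return evals, evecs
```

The sign of each returned eigenvector is arbitrary, and power iteration converges slowly when the leading eigenvalues are close together.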
devinterp.slt.gradient module
- class devinterp.slt.gradient.GradientDistribution(num_chains: int, num_draws: int, min_bins: int = 20, param_names: List[str] | None = None, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for plotting the distribution of gradients as a function of draws. Automatically adjusts the histogram bins as more draws are taken. For use with devinterp.slt.sampler.sample().
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
min_bins (int, optional) – Minimum number of bins for histogram approximation. Default is 20
param_names (List[str], optional) – List of parameter names to track. If None, all parameters are tracked. Default is None
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- Raises:
ValueError – If gradients are not computed before calling this callback.
- plot(param_name: str, color='blue', plot_zero=True, chain: int | None = None)[source]
Plots the gradient distribution for a specific parameter.
- Parameters:
param_name (str) – The name of the parameter to plot gradients for.
color (str, optional) – The color to plot gradient bins in. Default is blue
plot_zero (bool, optional) – Whether to plot the line through y=0. Default is True
chain (int, optional) – Index of the chain to plot. Default is None
- Returns:
None, but shows the density of gradient bins over sampling steps.
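The bin-adjustment scheme is not described here. One simple scheme, sketched below as an assumption rather than the library's method, keeps a fixed number of bins on a symmetric range [-limit, limit) and doubles the range whenever a gradient falls outside it. Neighbouring bins are merged pairwise, so earlier counts are preserved without storing raw values. `SymmetricAdaptiveHistogram` is a hypothetical name.

```python
import math
from typing import List


class SymmetricAdaptiveHistogram:
    """Hypothetical fixed-size histogram over [-limit, limit) whose range doubles as needed."""

    def __init__(self, num_bins: int = 20, limit: float = 1e-3):
        if num_bins % 4:
            raise ValueError("num_bins must be divisible by 4 so bins can be merged pairwise")
        self.num_bins = num_bins
        self.limit = limit
        self.counts: List[int] = [0] * num_bins

    @property
    def bin_width(self) -> float:
        return 2 * self.limit / self.num_bins

    def _double_range(self) -> None:
        # The old range [-limit, limit) becomes the middle half of [-2*limit, 2*limit):
        # old bins 2k and 2k+1 merge into new bin num_bins // 4 + k.
        quarter = self.num_bins // 4
        new_counts = [0] * self.num_bins
        for k in range(self.num_bins // 2):
            new_counts[quarter + k] = self.counts[2 * k] + self.counts[2 * k + 1]
        self.counts = new_counts
        self.limit *= 2

    def add(self, value: float) -> None:
        if not math.isfinite(value):
            raise ValueError("value must be finite")
        while not -self.limit <= value < self.limit:
            self._double_range()
        idx = int((value + self.limit) // self.bin_width)
        idx = min(idx, self.num_bins - 1)  # guard against float rounding at the upper edge
        self.counts[idx] += 1
```

Doubling rather than re-binning from scratch means resolution is only lost when a larger gradient actually appears.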
devinterp.slt.llc module
- class devinterp.slt.llc.LLCEstimator(num_chains: int, num_draws: int, temperature: float, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for estimating the Local Learning Coefficient (LLC) in a rolling fashion during a sampling process. It calculates the LLC based on the average loss across draws for each chain:
\[LLC = T \cdot (\textrm{avg\_loss} - \textrm{init\_loss})\]
For use with devinterp.slt.sampler.sample().
Note
init_loss gets set inside devinterp.slt.sample(). It can be passed as an argument to that function; if it is not passed, it is the average loss of the supplied model over num_chains batches.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
temperature (float) – Temperature (default: 1.; set by sample() to utils.optimal_temperature(dataloader) = batch_size / np.log(batch_size))
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
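The LLC formula can be checked by hand on a few fake losses. The sketch below applies T * (avg_loss - init_loss) separately to each chain and then summarises across chains. `llc_from_losses` and `optimal_temperature_sketch` are hypothetical helpers; the latter restates the batch_size / log(batch_size) default described above and is not the library function. Using the population standard deviation across chains is also an assumption.

```python
import math
import statistics
from typing import Dict, List, Sequence, Union


def optimal_temperature_sketch(batch_size: int) -> float:
    """batch_size / log(batch_size), as described for utils.optimal_temperature."""
    if batch_size <= 1:
        raise ValueError("batch_size must be > 1 so that log(batch_size) > 0")
    return batch_size / math.log(batch_size)


def llc_from_losses(losses: Sequence[Sequence[float]], init_loss: float,
                    temperature: float) -> Dict[str, Union[float, List[float]]]:
    """LLC = T * (avg_loss - init_loss) per chain, plus mean/std across chains."""
    per_chain = [temperature * (statistics.fmean(chain) - init_loss) for chain in losses]
    return {
        "per_chain": per_chain,
        "mean": statistics.fmean(per_chain),
        "std": statistics.pstdev(per_chain),
    }
```

For example, with init_loss = 1.0, T = 10 and chain losses [1.1, 1.3] and [1.3, 1.5], the chain averages are 1.2 and 1.4, giving per-chain LLCs of 2.0 and 4.0.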
- class devinterp.slt.llc.OnlineLLCEstimator(num_chains: int, num_draws: int, temperature: float, device='cpu')[source]
Bases:
SamplerCallback
Callback for estimating the Local Learning Coefficient (LLC) in an online fashion during a sampling process. It calculates LLCs using the same formula as devinterp.slt.llc.LLCEstimator(), but continuously, and includes means and standard deviations across draws (as opposed to just across chains). For use with devinterp.slt.sampler.sample().
Note
init_loss gets set inside devinterp.slt.sample(). It can be passed as an argument to that function; if it is not passed, it is the average loss of the supplied model over num_chains batches.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
temperature (float) – Temperature (default: 1.; set by sample() to utils.optimal_temperature(dataloader) = batch_size / np.log(batch_size))
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
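An online estimate applies the same formula after every draw, using only the losses seen so far, which yields a full LLC trace per chain. The sketch below assumes that reading of "continuously"; `online_llc` is a hypothetical helper, and the population standard deviation across chains is an assumption.

```python
import statistics
from typing import List, Sequence, Tuple


def online_llc(losses: Sequence[Sequence[float]], init_loss: float,
               temperature: float) -> Tuple[List[List[float]], List[float], List[float]]:
    """Running LLC per chain and draw, plus mean/std across chains at every draw."""
    llcs: List[List[float]] = []
    for chain_losses in losses:
        running_sum = 0.0
        row: List[float] = []
        for draw, loss in enumerate(chain_losses):
            running_sum += loss
            # Same formula as the rolling estimator, applied to the losses seen so far.
            row.append(temperature * (running_sum / (draw + 1) - init_loss))
        llcs.append(row)
    by_draw = list(zip(*llcs))  # regroup: one tuple of chain values per draw
    means = [statistics.fmean(values) for values in by_draw]
    stds = [statistics.pstdev(values) for values in by_draw]
    return llcs, means, stds
```

The last entry of each chain's trace equals the non-online per-chain estimate, so the two estimators agree once all draws are in.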
devinterp.slt.loss module
- class devinterp.slt.loss.OnlineLossStatistics(base_callback: OnlineLLCEstimator)[source]
Bases:
SamplerCallback
Derivative callback that computes various loss statistics for OnlineLLCEstimator(). Must be called after the base OnlineLLCEstimator() has been called at each draw.
See the diagnostics notebook for examples of how to use this to diagnose your sample health.
- Parameters:
base_callback (OnlineLLCEstimator()) – Base callback that computes the original loss metric.
Note
Requires losses to be computed first, so pass the base callback before this one, e.g.
devinterp.slt.sampler.sample(..., [llc_estimator_instance, ..., online_loss_stats_instance], ...)
- loss_hist_by_draw(draw: int = 0, bins: int = 10)[source]
Plots a histogram of chain losses for a given draw index.
- Parameters:
draw (int, optional) – Draw index to plot histogram for. Default is 0
bins (int, optional) – Number of histogram bins. Default is 10
- sample()[source]
- Returns:
A dict
{"loss/percent_neg_steps": percent_neg_steps, "loss/percent_mean_neg_steps": percent_mean_neg_steps, "loss/percent_thresholded_neg_steps": percent_thresholded_neg_steps, "loss/z_scores": z_scores}
(Only available after running devinterp.slt.sampler.sample(..., [llc_estimator_instance, online_loss_stats_instance], ...).)
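The exact definitions behind percent_neg_steps, percent_mean_neg_steps, percent_thresholded_neg_steps and z_scores are not given here. The sketch below shows one plausible reading of the first and last as an assumption. A "negative step" is taken to be a sampled loss below init_loss, which would pull the LLC estimate negative. Z-scores compare each chain's loss with the other chains at the same draw. Both helpers are hypothetical.

```python
import statistics
from typing import List, Sequence


def percent_negative_steps(losses: Sequence[Sequence[float]], init_loss: float) -> float:
    """Fraction of sampled losses (all chains, all draws) below the initial loss."""
    values = [loss for chain in losses for loss in chain]
    return sum(loss < init_loss for loss in values) / len(values)


def z_scores_by_draw(losses: Sequence[Sequence[float]]) -> List[List[float]]:
    """Z-score of each chain's loss relative to all chains at the same draw."""
    num_chains, num_draws = len(losses), len(losses[0])
    scores = [[0.0] * num_draws for _ in range(num_chains)]
    for draw, values in enumerate(zip(*losses)):
        mean = statistics.fmean(values)
        std = statistics.pstdev(values)
        for chain, value in enumerate(values):
            # Identical losses across chains give std == 0; report a z-score of 0 then.
            scores[chain][draw] = 0.0 if std == 0 else (value - mean) / std
    return scores
```

Under this reading, a high fraction of negative steps suggests the sampler keeps finding lower loss than at the initial point, and a chain with persistently large z-scores has drifted away from the others.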
devinterp.slt.mala module
- class devinterp.slt.mala.MalaAcceptanceRate(num_chains: int, num_draws: int, temperature: float, learning_rate: float, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for computing MALA acceptance rate.
- num_draws
Number of samples to draw (should be identical to num_draws passed to sample()).
- Type:
int
- num_chains
Number of chains to run (should be identical to num_chains passed to sample()).
- Type:
int
- temperature
Temperature used to calculate the LLC.
- Type:
float
- learning_rate
Learning rate (step size) of the sampler.
- Type:
float
- device
Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.
- Type:
Union[torch.device, str]
- devinterp.slt.mala.mala_acceptance_probability(prev_params: Tensor | List[Tensor], prev_grads: Tensor | List[Tensor], prev_loss: Tensor, current_params: Tensor | List[Tensor], current_grads: Tensor | List[Tensor], current_loss: Tensor, learning_rate: float)[source]
Calculate the acceptance probability for a MALA transition. Parameters and gradients can either all be given as tensors (all of the same shape) or all as lists of tensors (e.g. the parameters of a Module).
- Parameters:
prev_params – The previous point in parameter space.
prev_grads – Gradient at the previous point in parameter space.
prev_loss – Loss at the previous point in parameter space.
current_params – The current point in parameter space.
current_grads – Gradient at the current point in parameter space.
current_loss – Loss at the current point in parameter space.
learning_rate (float) – Learning rate (step size) of the sampler.
- Returns:
float: Acceptance probability for the proposed transition.
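The acceptance probability follows the standard Metropolis-adjusted Langevin algorithm (MALA) rule: alpha = min(1, exp(-L(current)) q(prev | current) / (exp(-L(prev)) q(current | prev))), where q(x' | x) is the Gaussian proposal N(x - eps * grad L(x), 2 * eps * I). The sketch assumes a target density proportional to exp(-loss) and eps = learning_rate. How devinterp scales the learning rate inside its samplers (for example any factor of 1/2 or temperature) is not stated here, so the library's constants may differ. Parameters are plain lists of floats rather than tensors, and `mala_acceptance_sketch` is a hypothetical name.

```python
import math
from typing import Sequence


def _log_proposal(to: Sequence[float], frm: Sequence[float],
                  grad_at_frm: Sequence[float], lr: float) -> float:
    """Unnormalised log N(to; frm - lr * grad, 2 * lr * I); the constant cancels in the ratio."""
    sq = sum((t - (f - lr * g)) ** 2 for t, f, g in zip(to, frm, grad_at_frm))
    return -sq / (4.0 * lr)


def mala_acceptance_sketch(prev_params: Sequence[float], prev_grads: Sequence[float],
                           prev_loss: float, current_params: Sequence[float],
                           current_grads: Sequence[float], current_loss: float,
                           learning_rate: float) -> float:
    """min(1, exp(-L_cur) q(prev | cur) / (exp(-L_prev) q(cur | prev)))."""
    log_ratio = (
        (prev_loss - current_loss)
        + _log_proposal(prev_params, current_params, current_grads, learning_rate)
        - _log_proposal(current_params, prev_params, prev_grads, learning_rate)
    )
    # Compare in log space so a large positive ratio cannot overflow math.exp.
    return 1.0 if log_ratio >= 0.0 else math.exp(log_ratio)
```

With zero gradients the two proposal terms cancel and the rule reduces to plain Metropolis, min(1, exp(prev_loss - current_loss)). An acceptance rate far below 1 is a common sign that the step size is too large.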
devinterp.slt.norms module
- class devinterp.slt.norms.GradientNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for computing the norm of the gradients of the optimizer / sampler.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- class devinterp.slt.norms.NoiseNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for computing the norm of the noise added in the optimizer / sampler.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- class devinterp.slt.norms.WeightNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for computing the norm of the weights over the sampling process.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
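GradientNorm, NoiseNorm and WeightNorm each reduce a set of tensors (gradients, injected noise or weights) to a single p-norm per chain and draw. The sketch below shows that reduction over plain Python lists; treating all parameters as one flattened vector is an assumption about the library. For finite p it gives the same result as combining per-parameter p-norms as (sum_i ||t_i||_p^p)^(1/p). `p_norm` is a hypothetical helper.

```python
import math
from typing import Sequence


def p_norm(tensors: Sequence[Sequence[float]], p: float = 2) -> float:
    """p-norm of all entries of a list of flattened tensors, treated as one vector."""
    if p <= 0:
        raise ValueError("p must be positive")
    values = [abs(x) for tensor in tensors for x in tensor]
    if math.isinf(p):
        return max(values, default=0.0)  # the infinity norm is the largest magnitude
    return sum(v ** p for v in values) ** (1.0 / p)
```

For two one-element "tensors" [3.0] and [-4.0], p = 2 gives 5.0, p = 1 gives 7.0 and p = inf gives 4.0.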
devinterp.slt.sampler module
- devinterp.slt.sampler.sample(model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, callbacks: ~typing.List[~devinterp.slt.callback.SamplerCallback], evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], ~devinterp.utils.Outputs | ~typing.Dict[str, ~torch.Tensor] | ~typing.Tuple[~torch.Tensor, ...] | ~torch.Tensor] | None = None, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict[str, float | ~typing.Literal['adaptive']] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, init_loss: float | None = None, grad_accum_steps: int = 1, cores: int = 1, seed: ~typing.List[int] | int | None = None, device: ~torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: ~typing.Dict[str, ~typing.List[bool]] | None = None)[source]
Sample model weights using a given sampling_method, supporting multiple chains/cores, and calculate the observables (loss, LLC, etc.) for each callback passed along. The update, finalize and sample methods of each SamplerCallback() are called during sampling, after sampling, and at sampler_callback_object.sample() respectively.
After calling this function, the stats of interest live in the callback object.
- Parameters:
model (torch.nn.Module) – The neural network model.
loader (DataLoader) – DataLoader for input data.
evaluate (EvaluateFn) – Maps a model and batch of data to an object with a loss attribute.
callbacks (list[SamplerCallback]) – list of callbacks, each of type SamplerCallback
sampling_method (torch.optim.Optimizer, optional) – Sampling method to use (a PyTorch optimizer under the hood). Default is SGLD
optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class)
num_draws (int, optional) – Number of samples to draw. Default is 100
num_chains (int, optional) – Number of chains to run. Default is 10
num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0
num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1
init_loss (float, optional) – Initial loss for use in LLCEstimator and OnlineLLCEstimator. Default is None
cores (int, optional) – Number of cores for parallel execution. Default is 1
seed (int | List[int], optional) – Random seed(s) for sampling. If passed, each chain gets a different (deterministic) seed. Default is None
device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
verbose (bool, optional) – Whether to print sample chain progress. Default is True
- Raises:
ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before base callbacks (e.g. OnlineLLCEstimator())
Warning – if num_burnin_steps < num_draws
Warning – if num_draws > len(loader)
Warning – if using seeded runs
- Returns:
None (access LLCs or other observables through callback_object.sample())
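The schedule arguments (num_burnin_steps, num_steps_bw_draws, num_draws, num_chains, seed) can be illustrated with a toy one-dimensional Langevin loop. The sketch below is hypothetical and is not devinterp's implementation. It assumes that burn-in steps are discarded, that a draw is recorded every num_steps_bw_draws steps afterwards, and that chain i uses seed + i. The update is the plain Langevin step theta <- theta - (eps / 2) * grad + N(0, eps), without the localisation or temperature terms a real SGLD sampler may add.

```python
import math
import random
from typing import Callable, List


def toy_sample(grad: Callable[[float], float], init: float,
               callbacks: List[Callable[..., None]], num_draws: int = 5,
               num_chains: int = 2, num_burnin_steps: int = 3,
               num_steps_bw_draws: int = 2, step_size: float = 1e-2,
               seed: int = 0) -> None:
    """Hypothetical 1-D Langevin sampler showing how the schedule arguments interact."""
    for chain in range(num_chains):
        rng = random.Random(seed + chain)  # a different, deterministic seed per chain
        theta = init  # every chain starts from the same initial point
        draw = 0
        total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
        for step in range(total_steps):
            # Langevin step: half a gradient step plus Gaussian noise of variance step_size.
            noise = rng.gauss(0.0, math.sqrt(step_size))
            theta = theta - 0.5 * step_size * grad(theta) + noise
            steps_after_burnin = step - num_burnin_steps
            if steps_after_burnin >= 0 and steps_after_burnin % num_steps_bw_draws == 0:
                for callback in callbacks:
                    callback(chain=chain, draw=draw, theta=theta)
                draw += 1
```

With the defaults above, each chain takes 3 + 5 * 2 = 13 steps and calls every callback 5 times.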
devinterp.slt.trace module
- class devinterp.slt.trace.OnlineTraceStatistics(base_callback: SamplerCallback, attribute: str, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Derivative callback that computes mean/std statistics of a specified trace online. Must be called after the base callback has been called at each draw.
See the diagnostics notebook for examples on how to use this to diagnose your sample health.
- Parameters:
base_callback (SamplerCallback()) – Base callback that computes some metric.
attribute (str) – Name of the attribute to compute mean/std statistics of.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- Raises:
ValueError – if the underlying trace does not have the requested attribute, num_chains or num_draws.
Note
Requires base trace stats to be computed first, so pass the base callback before this one, e.g.
devinterp.slt.sampler.sample(..., [weight_norm_instance, online_trace_stats_instance], ...)
- sample()[source]
- Returns:
A dict
{"{self.attribute}/chain/mean": mean_attribute_by_chain, "{self.attribute}/chain/std": std_attribute_by_chain, "{self.attribute}/draw/mean": mean_attribute_by_draw, "{self.attribute}/draw/std": std_attribute_by_draw}
(Only available after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...).)
- sample_at_draw(draw=-1)[source]
- Parameters:
draw (int, optional) – Draw index to return stats at. Default is -1
- Returns:
A dict
{"{self.attribute}/chain/mean": mean_attribute_of_draw_by_chain, "{self.attribute}/chain/std": std_attribute_of_draw_by_chain, "{self.attribute}/draw/mean": mean_attribute_of_draw, "{self.attribute}/draw/std": std_attribute_of_draw}
(Only available after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...).)
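The exact meaning of the chain and draw keys is not spelled out above. The sketch below assumes that the chain statistics are running mean/std over the draws seen so far, separately for each chain. It also assumes that the draw statistics are mean/std across chains at each draw. `trace_statistics` is a hypothetical helper operating on a num_chains x num_draws list, and it recomputes prefixes for clarity where an online implementation would use Welford-style updates.

```python
import statistics
from typing import Dict, List, Sequence


def trace_statistics(trace: Sequence[Sequence[float]], attribute: str) -> Dict[str, List]:
    """Mean/std of a (num_chains x num_draws) trace, keyed like the dicts above."""
    chain_mean: List[List[float]] = []
    chain_std: List[List[float]] = []
    for row in trace:
        # Running statistics over the draws seen so far, separately for each chain.
        chain_mean.append([statistics.fmean(row[: d + 1]) for d in range(len(row))])
        chain_std.append([statistics.pstdev(row[: d + 1]) for d in range(len(row))])
    columns = list(zip(*trace))  # one tuple of chain values per draw
    return {
        f"{attribute}/chain/mean": chain_mean,
        f"{attribute}/chain/std": chain_std,
        f"{attribute}/draw/mean": [statistics.fmean(col) for col in columns],
        f"{attribute}/draw/std": [statistics.pstdev(col) for col in columns],
    }
```

Under this reading, sample_at_draw(draw) would pick column draw out of the chain statistics and entry draw out of the draw statistics.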
devinterp.slt.wbic module
- class devinterp.slt.wbic.OnlineWBICEstimator(num_chains: int, num_draws: int, n: int, device: device | str = 'cpu')[source]
Bases:
SamplerCallback
Callback for estimating the Widely Applicable Bayesian Information Criterion (WBIC) in an online fashion. The WBIC used here is just \(n * (\textrm{average sampled loss})\). (Watanabe, 2013)
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
n (int) – Number of samples used to calculate the WBIC.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
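As stated above, the estimate is n times the average sampled loss. In Watanabe (2013) that average is taken over a posterior tempered at inverse temperature beta = 1 / log(n). The sketch therefore assumes the losses come from a sampler run at that temperature; whether sample() arranges this automatically is not stated here. Both helpers are hypothetical.

```python
import math
import statistics
from typing import Sequence


def wbic_estimate(losses: Sequence[Sequence[float]], n: int) -> float:
    """WBIC ~= n * (average sampled loss), averaging over all chains and draws."""
    values = [loss for chain in losses for loss in chain]
    return n * statistics.fmean(values)


def wbic_inverse_temperature(n: int) -> float:
    """Inverse temperature beta = 1 / log(n) prescribed for WBIC sampling."""
    if n <= 1:
        raise ValueError("n must be > 1 so that log(n) > 0")
    return 1.0 / math.log(n)
```

For example, n = 100 and sampled losses [0.5, 0.7] give an estimate of 100 * 0.6 = 60.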