Source code for devinterp.slt.mala

from typing import Union, List
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from devinterp.slt.callback import SamplerCallback


def mala_acceptance_probability(
    prev_params: Union[Tensor, List[Tensor]],
    prev_grads: Union[Tensor, List[Tensor]],
    prev_loss: Tensor,
    current_params: Union[Tensor, List[Tensor]],
    current_grads: Union[Tensor, List[Tensor]],
    current_loss: Tensor,
    learning_rate: float,
):
    """
    Calculate the acceptance probability for a MALA transition.

    Parameters and gradients can either all be given as a tensor (all of the
    same shape) or all as lists of tensors (e.g. the parameters of a Module).

    Args:
        prev_params: The previous point in parameter space.
        prev_grads: Gradient of the previous point in parameter space.
        prev_loss: Loss of the previous point in parameter space.
        current_params: The current point in parameter space.
        current_grads: Gradient of the current point in parameter space.
        current_loss: Loss of the current point in parameter space.
        learning_rate (float): Learning rate of the model.

    Returns:
        float: Acceptance probability for the proposed transition.
    """
    if np.isnan(current_loss):
        return np.nan

    # Convert bare tensors to single-element lists so both input forms
    # (tensor or list of tensors) are handled uniformly below.
    if not isinstance(prev_params, list):
        prev_params = [prev_params]
    if not isinstance(prev_grads, list):
        prev_grads = [prev_grads]
    if not isinstance(current_params, list):
        current_params = [current_params]
    if not isinstance(current_grads, list):
        current_grads = [current_grads]

    log_q_current_to_prev = 0
    log_q_prev_to_current = 0
    for current_point, current_grad, prev_point, prev_grad in zip(
        current_params,
        current_grads,
        prev_params,
        prev_grads,
    ):
        # Compute the log of the proposal probabilities (using the Gaussian proposal distribution)
        log_q_current_to_prev += -torch.sum(
            (prev_point - current_point - (learning_rate * 0.5 * -current_grad)) ** 2
        ) / (2 * learning_rate)
        log_q_prev_to_current += -torch.sum(
            (current_point - prev_point - (learning_rate * 0.5 * -prev_grad)) ** 2
        ) / (2 * learning_rate)

    acceptance_log_prob = (
        log_q_current_to_prev - log_q_prev_to_current + prev_loss - current_loss
    )

    return min(1.0, torch.exp(acceptance_log_prob))
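

# --- Illustrative usage sketch (not part of the library source) ---
# A minimal, hedged example of calling mala_acceptance_probability directly on
# toy tensors. All numeric values below are made up for demonstration, and the
# helper name _example_mala_acceptance is ours, not part of devinterp; in
# practice the MalaAcceptanceRate callback below supplies these arguments from
# the sampler.
def _example_mala_acceptance():
    lr = 0.01
    prev_params = torch.tensor([1.0, 2.0])
    prev_grads = torch.tensor([0.5, -0.5])
    prev_loss = torch.tensor(1.00)
    # Pretend the sampler took one half-gradient step from the previous point.
    current_params = prev_params - 0.5 * lr * prev_grads
    current_grads = torch.tensor([0.45, -0.40])
    current_loss = torch.tensor(0.95)
    p_accept = mala_acceptance_probability(
        prev_params, prev_grads, prev_loss,
        current_params, current_grads, current_loss,
        learning_rate=lr,
    )
    print(f"acceptance probability: {float(p_accept):.3f}")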
class MalaAcceptanceRate(SamplerCallback):
    """
    Callback for computing MALA acceptance rate.

    Attributes:
        num_draws (int): Number of samples to draw. (should be identical to param passed to sample())
        num_chains (int): Number of chains to run. (should be identical to param passed to sample())
        temperature (float): Temperature used to calculate the LLC.
        learning_rate (float): Learning rate of the model.
        device (Union[torch.device, str]): Device to perform computations on, e.g., 'cpu' or 'cuda'.
    """

    def __init__(
        self,
        num_chains: int,
        num_draws: int,
        temperature: float,
        learning_rate: float,
        device: Union[torch.device, str] = "cpu",
    ):
        self.num_chains = num_chains
        self.num_draws = num_draws
        self.learning_rate = learning_rate
        self.temperature = temperature
        self.mala_acceptance_rate = torch.zeros(
            (num_chains, num_draws - 1), dtype=torch.float32
        ).to(device)
        self.device = device
        self.current_params = []
        self.current_grads = []
        self.prev_params = []
        self.prev_grads = []
        self.prev_mala_loss = 0.0

    def __call__(self, chain: int, draw: int, model: nn.Module, loss: float, optimizer):
        self.update(chain, draw, model, loss, optimizer)

    def update(self, chain: int, draw: int, model: nn.Module, loss: float, optimizer):
        # We need the grads & loss from this pass, but the current params are
        # from after the step (so we update those only after the calculation).
        self.current_grads = optimizer.dws
        # The MALA acceptance loss differs from the PyTorch-supplied loss.
        mala_loss = (loss * self.temperature).item() + optimizer.localization_loss
        if draw > 1:
            self.mala_acceptance_rate[chain, draw - 1] = mala_acceptance_probability(
                self.prev_params,
                self.prev_grads,
                self.prev_mala_loss,
                self.current_params,
                self.current_grads,
                mala_loss,
                self.learning_rate,
            )
        # Move new -> old, then update new afterwards.
        self.prev_params = self.current_params
        self.prev_grads = self.current_grads
        self.prev_mala_loss = mala_loss
        # Params update only at the end, as described above.
        self.current_params = [
            param.clone().detach()
            for param in model.parameters()
            if param.requires_grad
        ]

    def sample(self):
        return {
            "mala_accept/trace": self.mala_acceptance_rate.cpu().numpy(),
            "mala_accept/mean": np.mean(self.mala_acceptance_rate.cpu().numpy()),
            "mala_accept/std": np.std(self.mala_acceptance_rate.cpu().numpy()),
        }
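

# --- Illustrative usage sketch (not part of the library source) ---
# A minimal example of constructing the callback and reading its summary. Per
# the docstring above, num_chains and num_draws should match the values passed
# to devinterp's sample(); how the callback is handed to sample() (e.g. via a
# callbacks argument) depends on the devinterp version, so that call is only
# sketched in a comment. The hyperparameter values and the helper name
# _example_mala_acceptance_rate are placeholders of ours.
def _example_mala_acceptance_rate():
    mala_callback = MalaAcceptanceRate(
        num_chains=4,
        num_draws=200,
        temperature=1.0,
        learning_rate=1e-4,
        device="cpu",
    )
    # ... run the sampler with this callback, e.g. something like:
    # sample(model, loader, ..., num_chains=4, num_draws=200, callbacks=[mala_callback])
    results = mala_callback.sample()
    print("mean MALA acceptance rate:", results["mala_accept/mean"])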