from typing import Union, List
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from devinterp.slt.callback import SamplerCallback
def mala_acceptance_probability(
    prev_params: Union[Tensor, List[Tensor]],
    prev_grads: Union[Tensor, List[Tensor]],
    prev_loss: Tensor,
    current_params: Union[Tensor, List[Tensor]],
    current_grads: Union[Tensor, List[Tensor]],
    current_loss: Tensor,
    learning_rate: float,
):
    """
    Calculate the acceptance probability for a MALA transition. Parameters and
    gradients can either all be given as single tensors (all of the same shape)
    or all as lists of tensors (e.g. the parameters of a Module).

    Args:
        prev_params: The previous point in parameter space.
        prev_grads: Gradient at the previous point in parameter space.
        prev_loss: Loss at the previous point in parameter space.
        current_params: The current point in parameter space.
        current_grads: Gradient at the current point in parameter space.
        current_loss: Loss at the current point in parameter space.
        learning_rate (float): Learning rate of the model.

    Returns:
        float: Acceptance probability for the proposed transition.
    """
    if np.isnan(current_loss):
        return np.nan
    # wrap bare tensors in single-element lists so both input forms are handled uniformly
    if not isinstance(prev_params, list):
        prev_params = [prev_params]
    if not isinstance(prev_grads, list):
        prev_grads = [prev_grads]
    if not isinstance(current_params, list):
        current_params = [current_params]
    if not isinstance(current_grads, list):
        current_grads = [current_grads]

    log_q_current_to_prev = 0
    log_q_prev_to_current = 0
    for current_point, current_grad, prev_point, prev_grad in zip(
        current_params,
        current_grads,
        prev_params,
        prev_grads,
    ):
        # Compute the log of the proposal probabilities (using the Gaussian proposal distribution)
        log_q_current_to_prev += -torch.sum(
            (prev_point - current_point - (learning_rate * 0.5 * -current_grad)) ** 2
        ) / (2 * learning_rate)
        log_q_prev_to_current += -torch.sum(
            (current_point - prev_point - (learning_rate * 0.5 * -prev_grad)) ** 2
        ) / (2 * learning_rate)
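    # Metropolis-Hastings log acceptance ratio: the difference of the proposal log densities
    # plus the drop in loss; a lower current loss pushes the acceptance probability toward 1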
    acceptance_log_prob = (
        log_q_current_to_prev - log_q_prev_to_current + prev_loss - current_loss
    )
    return min(1.0, torch.exp(acceptance_log_prob))
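
# Illustrative sketch (not part of the library): a minimal, self-contained call to
# mala_acceptance_probability on a toy quadratic loss, with the proposal taken to be a
# plain half-step of gradient descent. The helper name _example_mala_acceptance is
# hypothetical and only demonstrates the expected argument shapes and types.
def _example_mala_acceptance(learning_rate: float = 0.01):
    prev_point = torch.tensor([1.0, -2.0])
    prev_grad = 2 * prev_point  # gradient of L(w) = ||w||^2
    prev_loss = torch.sum(prev_point**2)
    # propose a new point (deterministic drift only; a real MALA step would also add Gaussian noise)
    current_point = prev_point - learning_rate * 0.5 * prev_grad
    current_grad = 2 * current_point
    current_loss = torch.sum(current_point**2)
    return mala_acceptance_probability(
        prev_point,
        prev_grad,
        prev_loss,
        current_point,
        current_grad,
        current_loss,
        learning_rate,
    )
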
class MalaAcceptanceRate(SamplerCallback):
    """
    Callback for computing the MALA acceptance rate.

    Attributes:
        num_draws (int): Number of samples to draw (should be identical to the value passed to sample()).
        num_chains (int): Number of chains to run (should be identical to the value passed to sample()).
        temperature (float): Temperature used to calculate the LLC.
        learning_rate (float): Learning rate of the model.
        device (Union[torch.device, str]): Device to perform computations on, e.g., 'cpu' or 'cuda'.
    """

    def __init__(
        self,
        num_chains: int,
        num_draws: int,
        temperature: float,
        learning_rate: float,
        device: Union[torch.device, str] = "cpu",
    ):
        self.num_chains = num_chains
        self.num_draws = num_draws
        self.learning_rate = learning_rate
        self.temperature = temperature
        self.mala_acceptance_rate = torch.zeros(
            (num_chains, num_draws - 1), dtype=torch.float32
        ).to(device)
        self.device = device
        self.current_params = []
        self.current_grads = []
        self.prev_params = []
        self.prev_grads = []
        self.prev_mala_loss = 0.0
    def __call__(self, chain: int, draw: int, model: nn.Module, loss: Tensor, optimizer):
        self.update(chain, draw, model, loss, optimizer)

    def update(self, chain: int, draw: int, model: nn.Module, loss: Tensor, optimizer):
        # we need the grads & loss from the current pass, but the current params are only
        # available after the optimizer step (so we record those after the calculation)
        self.current_grads = optimizer.dws
        # the MALA acceptance uses the tempered loss plus the localization term,
        # not the raw PyTorch loss
        mala_loss = (loss * self.temperature).item() + optimizer.localization_loss
        if draw > 1:
            self.mala_acceptance_rate[chain, draw - 1] = mala_acceptance_probability(
                self.prev_params,
                self.prev_grads,
                self.prev_mala_loss,
                self.current_params,
                self.current_grads,
                mala_loss,
                self.learning_rate,
            )
        # move new -> old, then update new afterwards
        self.prev_params = self.current_params
        self.prev_grads = self.current_grads
        self.prev_mala_loss = mala_loss
        # params are recorded only at the end, as described above
        self.current_params = [
            param.clone().detach() for param in model.parameters() if param.requires_grad
        ]
    def sample(self):
        return {
            "mala_accept/trace": self.mala_acceptance_rate.cpu().numpy(),
            "mala_accept/mean": np.mean(self.mala_acceptance_rate.cpu().numpy()),
            "mala_accept/std": np.std(self.mala_acceptance_rate.cpu().numpy()),
        }
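
# Illustrative sketch (not part of the library): driving MalaAcceptanceRate by hand for a
# single chain. _FakeSGLD is a hypothetical stand-in exposing only the `dws` and
# `localization_loss` attributes that update() reads from the real sampling optimizer;
# in normal use the callback is simply passed to the library's sample() loop instead.
def _example_mala_callback():
    model = nn.Linear(2, 1)
    callback = MalaAcceptanceRate(
        num_chains=1, num_draws=3, temperature=1.0, learning_rate=0.01
    )

    class _FakeSGLD:
        def __init__(self, params):
            # per-parameter gradients from the current pass and the localization (prior) term
            self.dws = [torch.zeros_like(p) for p in params]
            self.localization_loss = 0.0

    optimizer = _FakeSGLD([p for p in model.parameters() if p.requires_grad])
    for draw in range(3):
        loss = torch.tensor(1.0)  # stand-in for the minibatch loss at this draw
        callback(chain=0, draw=draw, model=model, loss=loss, optimizer=optimizer)
    return callback.sample()["mala_accept/trace"]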