pyzag.reparametrization
Helper methods for reparameterizing modules, for example to scale parameter values and gradients
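Rescaling a parameter also rescales its gradient. The sketch below illustrates this with plain PyTorch autograd rather than pyzag. It assumes a linear map theta = lb + (ub - lb) * s from a scaled value s to a natural value theta, and the bounds and toy loss are made-up illustrative values. By the chain rule, dL/ds = (ub - lb) * dL/dtheta.

```python
import torch

# Hypothetical bounds for a single parameter (illustrative values only)
lb, ub = 100.0, 500.0

# Scaled parameter: this is the value an optimizer would actually update
s = torch.tensor(0.25, requires_grad=True)

# Assumed linear map from scaled to natural value: theta = lb + (ub - lb) * s
theta = lb + (ub - lb) * s

# Toy loss written in terms of the natural parameter
loss = (theta - 300.0) ** 2
loss.backward()

# Analytic gradient with respect to the natural parameter: dL/dtheta = 2 * (theta - 300)
dL_dtheta = 2.0 * (theta.detach() - 300.0)

# Chain rule: autograd's dL/ds should equal (ub - lb) * dL/dtheta
print(s.grad.item(), ((ub - lb) * dL_dtheta).item())  # both print -80000.0
```

Under this assumed map, working in s changes the magnitude of both the values being optimized and their gradients. This is one way to put parameters with very different natural scales on a comparable footing.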
- class pyzag.reparametrization.RangeRescale(lb, ub, clamp=True)
Scale a parameter so that it lies within the bounds lb (lower) and ub (upper)
- forward(X)
Map scaled parameter values to natural parameter values
- Parameters:
X (torch.Tensor) – scaled parameter values
- forward_std_dev(X)
Convert the standard deviation of a normal distribution on the scaled parameter to the standard deviation of the corresponding normal distribution on the natural parameter
- Parameters:
X (torch.Tensor) – scaled standard deviation
- reverse(X)
Map natural parameter values to scaled parameter values
- Parameters:
X (torch.Tensor) – natural parameter values
- reverse_std_dev(X)
Convert the standard deviation of a normal distribution on the natural parameter to the standard deviation of the corresponding normal distribution on the scaled parameter
- Parameters:
X (torch.Tensor) – natural standard deviation
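The exact form of the map is not stated here. The hypothetical LinearRangeRescale class below is a minimal sketch of the four RangeRescale methods under two assumptions. First, the map is linear: natural = lb + (ub - lb) * scaled, with scaled values nominally in [0, 1]. Second, clamp=True clips the scaled value to [0, 1] before mapping. It is an illustration, not pyzag's implementation.

```python
import torch


class LinearRangeRescale:
    """Hypothetical RangeRescale-style map between a scaled value in [0, 1]
    and a natural value in [lb, ub].

    The linear form and the clamping behavior are assumptions for
    illustration, not pyzag's confirmed implementation.
    """

    def __init__(self, lb, ub, clamp=True):
        self.lb = torch.as_tensor(lb)
        self.ub = torch.as_tensor(ub)
        self.clamp = clamp

    def forward(self, X):
        # Scaled -> natural. With clamp=True the scaled value is first clipped
        # to [0, 1], so the natural value cannot leave [lb, ub]
        if self.clamp:
            X = torch.clamp(X, 0.0, 1.0)
        return self.lb + (self.ub - self.lb) * X

    def reverse(self, X):
        # Natural -> scaled: invert the linear map
        return (X - self.lb) / (self.ub - self.lb)

    def forward_std_dev(self, X):
        # If s ~ N(mu, X^2), then lb + (ub - lb) * s ~ N(..., ((ub - lb) * X)^2),
        # so the standard deviation is multiplied by the width (assumes ub > lb)
        return (self.ub - self.lb) * X

    def reverse_std_dev(self, X):
        # Inverse of forward_std_dev: divide by the width
        return X / (self.ub - self.lb)


scaler = LinearRangeRescale(100.0, 500.0)
natural = scaler.forward(torch.tensor([0.0, 0.5, 1.2]))
print(natural)  # values 100, 300, 500 (1.2 is clamped to 1.0 first)
print(scaler.reverse(natural))  # values 0, 0.5, 1
print(scaler.forward_std_dev(torch.tensor(0.1)))  # value 40
```

Because the assumed map is linear, the standard deviation conversions depend only on the width ub - lb. The lower bound shifts the mean of the distribution but not its spread.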
- class pyzag.reparametrization.Reparameterizer(map_dict, error_not_provided=False)
Reparameterize a torch.nn.Module by attaching the appropriate rescaling function to each of its parameters
- Parameters:
map_dict (dict mapping str to rescaler) – dictionary mapping each parameter name to the rescaler to apply to it
- Keyword Arguments:
error_not_provided (bool) – if True, raise an error when a parameter of the module has no rescaler in map_dict
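The pyzag source is not reproduced here. The following is only a sketch of how a name-to-rescaler mapping like map_dict could be applied with PyTorch's built-in torch.nn.utils.parametrize.register_parametrization. The helper apply_rescalers, the wrapper _Rescale, the Affine rescaler and the Model class are all hypothetical. The sketch assumes each rescaler exposes forward (scaled to natural) and reverse (natural to scaled), as RangeRescale does. It only handles parameters owned directly by the module, not those of nested submodules.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class _Rescale(nn.Module):
    """Hypothetical wrapper turning a rescaler object into a torch parametrization."""

    def __init__(self, rescaler):
        super().__init__()
        self.rescaler = rescaler

    def forward(self, X):
        # Stored (scaled) value -> value the module sees (natural)
        return self.rescaler.forward(X)

    def right_inverse(self, X):
        # Natural value -> stored scaled value; applied when registering
        return self.rescaler.reverse(X)


def apply_rescalers(module, map_dict, error_not_provided=False):
    """Hypothetical helper mirroring the documented Reparameterizer arguments."""
    # Snapshot the names first: registering a parametrization moves the parameter
    names = [name for name, _ in module.named_parameters(recurse=False)]
    for name in names:
        if name in map_dict:
            parametrize.register_parametrization(module, name, _Rescale(map_dict[name]))
        elif error_not_provided:
            raise ValueError(f"No rescaler provided for parameter '{name}'")
    return module


class Affine:
    """Hypothetical rescaler: natural = lb + (ub - lb) * scaled."""

    def __init__(self, lb, ub):
        self.lb, self.ub = lb, ub

    def forward(self, X):
        return self.lb + (self.ub - self.lb) * X

    def reverse(self, X):
        return (X - self.lb) / (self.ub - self.lb)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameters are initialized with natural values
        self.E = nn.Parameter(torch.tensor(200000.0))
        self.n = nn.Parameter(torch.tensor(5.0))


model = apply_rescalers(Model(), {"E": Affine(100000.0, 300000.0), "n": Affine(1.0, 10.0)})
print(model.E)  # natural value seen by the module, 200000
print(model.parametrizations.E.original)  # stored scaled value, 0.5
```

With this arrangement the optimizer trains the stored scaled values (for example model.parametrizations.E.original). The module's own code keeps reading natural values through model.E.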