pyzag.reparametrization

Helper classes for reparameterizing torch modules, for example to scale parameter values and gradients

class pyzag.reparametrization.RangeRescale(lb, ub, clamp=True)

Scale a parameter so that its natural value lies between the lower bound lb and the upper bound ub

forward(X)

Convert scaled parameter values to natural parameter values

Parameters:

X (torch.Tensor) – scaled parameter values
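The map itself is not spelled out above. A minimal sketch, assuming a linear map that sends scaled values in [0, 1] onto natural values in [lb, ub], and assuming that clamp=True clips the scaled value to [0, 1] before mapping. The function name range_forward is hypothetical, and the sketch uses plain floats; the arithmetic carries over elementwise to tensors, with torch.clamp in place of min/max.

```python
def range_forward(x, lb, ub, clamp=True):
    """Hypothetical scaled -> natural map (assumed linear: 0 -> lb, 1 -> ub)."""
    if clamp:
        # Assumption: clamping keeps the scaled value inside [0, 1]
        x = min(max(x, 0.0), 1.0)
    return lb + (ub - lb) * x


print(range_forward(0.5, 10.0, 20.0))               # 15.0, the midpoint
print(range_forward(1.5, 10.0, 20.0))               # 20.0, clamped to ub
print(range_forward(1.5, 10.0, 20.0, clamp=False))  # 25.0, extrapolated past ub
```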

forward_std_dev(X)

Convert the standard deviation of a normal distribution in scaled space to the corresponding standard deviation in natural space

Parameters:

X (torch.Tensor) – scaled standard deviation
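If RangeRescale uses the linear map θ = lb + (ub − lb) x between a scaled value x and a natural value θ (an assumption, since the documentation does not state the map), the standard deviation picks up only the slope of the map, not its offset. Ignoring any clamping, and with ub > lb:

```latex
x \sim \mathcal{N}(\mu_x, \sigma_x^2), \qquad
\theta = \mathrm{lb} + (\mathrm{ub} - \mathrm{lb})\, x
\;\Longrightarrow\;
\theta \sim \mathcal{N}\!\left(\mathrm{lb} + (\mathrm{ub} - \mathrm{lb})\,\mu_x,\; (\mathrm{ub} - \mathrm{lb})^2 \sigma_x^2\right),
\qquad
\sigma_\theta = (\mathrm{ub} - \mathrm{lb})\, \sigma_x
```

Under this assumption, forward_std_dev would amount to multiplying by ub − lb, and reverse_std_dev to dividing by it.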

reverse(X)

Convert natural parameter values to scaled parameter values

Parameters:

X (torch.Tensor) – natural parameter values
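Assuming RangeRescale uses a linear map between [0, 1] and [lb, ub] (the documentation does not state the map), reverse would be its inverse. A minimal float-based sketch with the hypothetical names range_forward and range_reverse, checking that a natural value survives the round trip; clamping is omitted because natural values inside the bounds already map into [0, 1].

```python
def range_forward(x, lb, ub):
    """Assumed scaled -> natural map: 0 -> lb, 1 -> ub (clamping omitted)."""
    return lb + (ub - lb) * x


def range_reverse(theta, lb, ub):
    """Hypothetical natural -> scaled map: the inverse of range_forward."""
    return (theta - lb) / (ub - lb)


lb, ub = 10.0, 20.0
x = range_reverse(12.5, lb, ub)
print(x)                         # 0.25
print(range_forward(x, lb, ub))  # 12.5
```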

reverse_std_dev(X)

Convert the standard deviation of a normal distribution in natural space to the corresponding standard deviation in scaled space

Parameters:

X (torch.Tensor) – natural standard deviation
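A quick numerical check of the standard-deviation conversion, assuming a linear map between [0, 1] and [lb, ub] (not stated in the documentation): convert a natural standard deviation to scaled space, sample a normal distribution there, push the samples through the map, and compare their spread with the original value. std_dev_forward and std_dev_reverse are hypothetical names, and clamping is ignored.

```python
import random
import statistics


def std_dev_forward(sigma_x, lb, ub):
    """Assumed scaled -> natural standard deviation: multiply by the range width."""
    return (ub - lb) * sigma_x


def std_dev_reverse(sigma_theta, lb, ub):
    """Assumed natural -> scaled standard deviation: divide by the range width."""
    return sigma_theta / (ub - lb)


lb, ub = 10.0, 20.0
sigma_theta = 0.5                               # standard deviation in natural units
sigma_x = std_dev_reverse(sigma_theta, lb, ub)  # same spread in scaled units

# Monte Carlo check: sample a scaled normal, apply the assumed linear map,
# and measure the empirical spread in natural units.
random.seed(0)
samples = [lb + (ub - lb) * random.gauss(0.5, sigma_x) for _ in range(100_000)]
print(sigma_x, statistics.stdev(samples))  # 0.05, then a value close to 0.5
```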

class pyzag.reparametrization.Reparameterizer(map_dict, error_not_provided=False)

Reparameterize a torch Module by attaching the corresponding rescaler to each of its parameters

Parameters:

map_dict (dict mapping str to rescaler) – dictionary mapping each parameter name to its rescaler, for example a RangeRescale instance

Keyword Arguments:

error_not_provided (bool) – if True, raise an error when a parameter has no rescaler in map_dict
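The documentation does not say how Reparameterizer attaches the rescalers. As a hypothetical sketch, not pyzag's implementation, the same idea can be expressed in plain PyTorch with torch.nn.utils.parametrize.register_parametrization: the module stores the scaled value, and each access to the parameter returns the natural value. That function requires the rescaler to be an nn.Module, so LinearRescale (assumed linear [0, 1] to [lb, ub] map) stands in for a rescaler, its right_inverse plays the role of reverse, and reparameterize is a hypothetical helper taking the same map_dict and error_not_provided arguments.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class LinearRescale(nn.Module):
    """Hypothetical rescaler: stored (scaled) values in [0, 1] <-> natural values in [lb, ub]."""

    def __init__(self, lb, ub):
        super().__init__()
        self.lb = lb
        self.ub = ub

    def forward(self, X):
        # Runs on every attribute access: scaled (stored) -> natural (used by the model)
        return self.lb + (self.ub - self.lb) * torch.clamp(X, 0.0, 1.0)

    def right_inverse(self, X):
        # Runs once at registration: natural (initial value) -> scaled (stored)
        return (X - self.lb) / (self.ub - self.lb)


def reparameterize(module, map_dict, error_not_provided=False):
    """Hypothetical helper mirroring the Reparameterizer arguments."""
    # recurse=False: only parameters owned directly by `module` are handled here
    for name, _ in list(module.named_parameters(recurse=False)):
        if name in map_dict:
            parametrize.register_parametrization(module, name, map_dict[name])
        elif error_not_provided:
            raise ValueError(f"no rescaler provided for parameter '{name}'")
    return module


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.k = nn.Parameter(torch.tensor(15.0))


model = reparameterize(Toy(), {"k": LinearRescale(10.0, 20.0)})
print(model.k.item())                            # 15.0, the natural value
print(model.parametrizations.k.original.item())  # 0.5, the stored scaled value
```

Because the optimizer updates the stored scaled value, the gradient it sees is the natural-space gradient multiplied by ub − lb through the chain rule.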