NEML2 2.1.0

Detailed Description

Orientation distribution function (ODF) represented by a kernel density estimate

Args:
    X (neml2.tensors.Rot): rotations, must have a single batch dimension
    kernel (Kernel): kernel function
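
To make the role of the `kernel` argument concrete, the following is a minimal pure-Python sketch of a radially symmetric kernel on rotations. It is not the NEML2 implementation. The names `misorientation_angle` and `dlvp_kernel`, the `(w, x, y, z)` quaternion tuples, and the choice of a de la Vallée Poussin kernel parameterized by its half width are all assumptions made for illustration; the page above does not say which `Kernel` subclasses exist. Crystal symmetry is ignored.

```python
import math


def misorientation_angle(q1, q2):
    """Rotation angle in radians between two unit quaternions (w, x, y, z)."""
    # q and -q describe the same rotation, hence the absolute value.
    dot = abs(sum(a * b for a, b in zip(q1, q2)))
    return 2.0 * math.acos(min(1.0, dot))


def dlvp_kernel(omega, halfwidth):
    """de la Vallee Poussin kernel evaluated at misorientation angle omega.

    halfwidth (radians, between 0 and pi) is the angle at which the kernel
    drops to half of its peak value. The kernel integrates to one against
    the Haar measure on SO(3) normalized to unit total mass.
    """
    kappa = 0.5 * math.log(0.5) / math.log(math.cos(halfwidth / 2.0))
    # Normalizer B(3/2, 1/2) / B(3/2, kappa + 1/2), via log-gamma for stability.
    log_c = (math.lgamma(0.5) + math.lgamma(kappa + 2.0)
             - math.lgamma(2.0) - math.lgamma(kappa + 0.5))
    return math.exp(log_c) * max(math.cos(omega / 2.0), 0.0) ** (2.0 * kappa)


h = math.radians(15.0)
# By construction of kappa this ratio is approximately 0.5.
ratio = dlvp_kernel(h, h) / dlvp_kernel(0.0, h)
```

A smaller half width gives a sharper, taller kernel, so each data orientation influences a smaller neighborhood of orientation space.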
Inheritance diagram for KDEODF:

Public Member Functions

 __init__ (self, X, kernel)
 optimize_kernel (self, miter=50, verbose=False, lr=1.0e-2)
 leave_out (self, i)
 forward (self, Y)
Public Member Functions inherited from ODF
 __init__ (self, X)
 n (self)
 texture_index (self, deg=5)

Public Attributes

 kernel = kernel
Public Attributes inherited from ODF
 X = X

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
X,
kernel )

Member Function Documentation

◆ forward()

forward ( self,
Y )
Calculate the probability density at each point in Y

Args:
    Y (neml2.tensors.Rot): rotations with arbitrary batch shape

Returns:
    torch.Tensor with the probability densities
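
As a rough illustration of what evaluating a kernel density estimate at query orientations involves, the sketch below averages a kernel over the data orientations for each query point. It is plain Python, not the NEML2 code: `kde_density`, `misorientation_angle`, `rot_z`, and `bump` are hypothetical names, rotations are assumed to be `(w, x, y, z)` unit quaternions in flat lists rather than batched `Rot` tensors, and the demonstration kernel is deliberately left unnormalized, so the returned values are relative densities only.

```python
import math


def misorientation_angle(q1, q2):
    """Rotation angle in radians between two unit quaternions (w, x, y, z)."""
    dot = abs(sum(a * b for a, b in zip(q1, q2)))
    return 2.0 * math.acos(min(1.0, dot))


def kde_density(X, Y, kernel):
    """For each query y in Y, average kernel(angle(x, y)) over the data X."""
    return [
        sum(kernel(misorientation_angle(x, y)) for x in X) / len(X)
        for y in Y
    ]


def rot_z(deg):
    """Unit quaternion for a rotation of deg degrees about the z axis."""
    half = math.radians(deg) / 2.0
    return (math.cos(half), 0.0, 0.0, math.sin(half))


def bump(omega):
    """Hypothetical unnormalized kernel, peaked at zero misorientation."""
    return math.cos(omega / 2.0) ** 40


X = [rot_z(a) for a in (-5.0, 0.0, 5.0)]   # data orientations
Y = [rot_z(0.0), rot_z(90.0)]              # query orientations
values = kde_density(X, Y, bump)           # high near the data, low far away
```

With a properly normalized kernel the same average is a probability density with respect to the measure the kernel is normalized against.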

◆ leave_out()

leave_out ( self,
i )
Calculate the second term of the cross-validation loss, leaving out the i-th point

Args:
    i (int): index of the point to leave out
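
The page does not spell out the cross-validation loss, so the following is only a sketch of the usual leave-one-out ingredient: the density at the i-th data orientation estimated from all the other data orientations. Whether NEML2 scales or combines this quantity differently is not stated here. `leave_one_out_density`, `misorientation_angle`, `rot_z`, and `cos2` are hypothetical names, and rotations are assumed to be `(w, x, y, z)` unit quaternions.

```python
import math


def misorientation_angle(q1, q2):
    """Rotation angle in radians between two unit quaternions (w, x, y, z)."""
    dot = abs(sum(a * b for a, b in zip(q1, q2)))
    return 2.0 * math.acos(min(1.0, dot))


def leave_one_out_density(X, i, kernel):
    """Kernel average at X[i] using every data point except X[i] itself."""
    others = [x for j, x in enumerate(X) if j != i]
    return sum(kernel(misorientation_angle(X[i], x)) for x in others) / len(others)


def rot_z(deg):
    """Unit quaternion for a rotation of deg degrees about the z axis."""
    half = math.radians(deg) / 2.0
    return (math.cos(half), 0.0, 0.0, math.sin(half))


def cos2(omega):
    """Hypothetical unnormalized kernel used only for this demonstration."""
    return math.cos(omega / 2.0) ** 2


X = [rot_z(0.0), rot_z(0.0), rot_z(180.0)]
# Point 0 sees one coincident neighbor and one neighbor 180 degrees away.
loo_0 = leave_one_out_density(X, 0, cos2)
```

Leaving the point out matters because a kernel always scores its own center highly; including it would reward arbitrarily narrow kernels.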

◆ optimize_kernel()

optimize_kernel ( self,
miter = 50,
verbose = False,
lr = 0.01 )
Optimize the kernel half width by cross-validation

Keyword Args:
    miter (int): optimization iterations
    verbose (bool): if True, print convergence progress
    lr (float): learning rate
    sf (float): fraction of data split out for validation (listed in the docstring but absent from the signature above)
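
The idea of choosing a half width by cross-validation can be sketched without any of the NEML2 machinery. The example below swaps in two simplifications that should not be read as the library's method: it scores each half width by the mean leave-one-out log-likelihood rather than the loss used by `optimize_kernel`, and it picks the best value from a fixed candidate list instead of running `miter` gradient steps with learning rate `lr`. The de la Vallée Poussin kernel, the `(w, x, y, z)` quaternion tuples, and every function name are assumptions for illustration.

```python
import math


def misorientation_angle(q1, q2):
    """Rotation angle in radians between two unit quaternions (w, x, y, z)."""
    dot = abs(sum(a * b for a, b in zip(q1, q2)))
    return 2.0 * math.acos(min(1.0, dot))


def dlvp_kernel(omega, halfwidth):
    """Normalized de la Vallee Poussin kernel with the given half width."""
    kappa = 0.5 * math.log(0.5) / math.log(math.cos(halfwidth / 2.0))
    log_c = (math.lgamma(0.5) + math.lgamma(kappa + 2.0)
             - math.lgamma(2.0) - math.lgamma(kappa + 0.5))
    return math.exp(log_c) * max(math.cos(omega / 2.0), 0.0) ** (2.0 * kappa)


def loo_log_likelihood(X, halfwidth):
    """Mean log of the leave-one-out density over all data orientations."""
    n = len(X)
    total = 0.0
    for i in range(n):
        f = sum(dlvp_kernel(misorientation_angle(X[i], X[j]), halfwidth)
                for j in range(n) if j != i) / (n - 1)
        total += math.log(max(f, 1e-300))  # floor guards against log(0)
    return total / n


def select_halfwidth(X, candidates):
    """Return the candidate half width with the best leave-one-out score."""
    return max(candidates, key=lambda h: loo_log_likelihood(X, h))


def rot_z(deg):
    """Unit quaternion for a rotation of deg degrees about the z axis."""
    half = math.radians(deg) / 2.0
    return (math.cos(half), 0.0, 0.0, math.sin(half))


X = [rot_z(a) for a in (-12, -7, -3, 0, 2, 5, 9, 14)]
candidates = [math.radians(d) for d in (2, 5, 10, 20, 40)]
best = select_halfwidth(X, candidates)
```

A very small half width is penalized because each held-out point then falls outside its neighbors' kernels, while a very large one spreads the density too thinly; the selected value balances the two.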

Member Data Documentation

◆ kernel

kernel = kernel