27#include "neml2/tensors/Tensor.h"
35template <std::
size_t N>
64 std::array<const VariableBase *, N>
args()
const;
67 std::array<TensorShapeRef, N + 1> get_intmd_sizes()
const;
68 std::array<TensorShapeRef, N + 1> get_base_sizes()
const;
77 const std::array<const VariableBase *, N + 1> _var_and_args = {};
86 mutable Tensor _deriv_assembly;
89 const std::string _debug_name =
"<anonymous>";
Derivative wrapper.
Definition VariableBase.h:38
const Tensor & tensor() const
Get the derivative.
Definition Derivative.h:52
std::array< const VariableBase *, N > args() const
Get the args.
Definition Derivative.cxx:76
bool operator==(const Derivative< N > &other) const
Equality operator (for searching in a container)
Definition Derivative.cxx:324
void set(const Tensor &val)
Set the derivative (given in assembly format)
Definition Derivative.cxx:348
Derivative & operator=(const Tensor &val)
Definition Derivative.cxx:122
const Tensor & get() const
Get the derivative in assembly format.
Definition Derivative.cxx:335
const VariableBase * var() const
Get the variable.
Definition Derivative.h:61
Base class of variable.
Definition VariableBase.h:53
Definition DiagnosticsInterface.cxx:30
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
c10::ArrayRef< T > ArrayRef
Definition types.h:59