27#include "neml2/models/map_types_fwd.h"
28#include "neml2/models/utils.h"
29#include "neml2/base/LabeledAxisAccessor.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
37template <std::
size_t N>
40template <
typename,
typename>
182 [[maybe_unused]] std::optional<TracerPrivilege> key = std::nullopt) = 0;
195 std::size_t deriv_intrsc_intmd_dim = 0,
196 std::size_t var_intrsc_intmd_dim = 0,
197 std::size_t arg_intrsc_intmd_dim = 0);
203 std::size_t deriv_intrsc_intmd_dim = 0,
204 std::size_t var_intrsc_intmd_dim = 0,
205 std::size_t arg1_intrsc_intmd_dim = 0,
206 std::size_t arg2_intrsc_intmd_dim = 0);
211 void request_AD(
const std::vector<const VariableBase *> & us);
217 void request_AD(
const std::vector<const VariableBase *> & u1s,
218 const std::vector<const VariableBase *> & u2s);
296#define FWD_VARIABLE_BINARY_OP(op) \
297 template <typename T1, \
299 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
300 std::is_base_of_v<VariableBase, T2>>> \
301 auto op(const T1 & a, const T2 & b) \
303 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
304 return op(a(), b()); \
306 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
309 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
314FWD_VARIABLE_BINARY_OP(
operator+);
315FWD_VARIABLE_BINARY_OP(
operator-);
316FWD_VARIABLE_BINARY_OP(
operator*);
317FWD_VARIABLE_BINARY_OP(
operator/);
319FWD_VARIABLE_BINARY_OP(
operator>);
320FWD_VARIABLE_BINARY_OP(
operator<);
321FWD_VARIABLE_BINARY_OP(
operator>=);
322FWD_VARIABLE_BINARY_OP(
operator<=);
323FWD_VARIABLE_BINARY_OP(
operator&&);
324FWD_VARIABLE_BINARY_OP(
operator||);
325FWD_VARIABLE_BINARY_OP(
operator==);
326FWD_VARIABLE_BINARY_OP(
operator!=);
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition DependencyResolver.h:47
Derivative wrapper.
Definition Derivative.h:66
The base class for all constitutive models.
Definition Model.h:83
void request_AD(const VariableBase &u1, const VariableBase &u2)
const Derivative< 1 > & d(const VariableBase &arg) const
DerivContainer & derivatives()
Definition VariableBase.h:228
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
bool is_old_force() const
Size static_size(Size i) const
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:259
Size intmd_size(Size i) const
virtual void ref(VariableBase &other)=0
Reference another variable.
VariableBase(VariableName name_in, Model *owner, TensorShapeRef base_shape)
The canonical constructor.
bool requires_grad() const
Check if this variable is part of the AD function graph.
bool is_parameter() const
bool _mutable
When referenced by another variable, whether to allow the referencing variable to mutate my value.
Definition VariableBase.h:269
const Model & owner() const
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable).
TensorShapeRef intmd_sizes() const
void clear_derivatives()
Clear only the derivatives.
const TraceableSize & dynamic_size(Size i) const
Tensor zeros(const TensorOptions &options) const
Make zeros tensor with the shape of this variable.
virtual VariableBase * ref()=0
void request_AD(const VariableBase &u)
TraceableSize batch_size(Size i) const
std::vector< DerivTuple > DerivContainer
Definition VariableBase.h:222
TensorShapeRef sizes() const
std::tuple< Derivative< 2 >, const VariableBase *, const VariableBase * > SecDerivTuple
Definition VariableBase.h:223
TensorShapeRef static_sizes() const
const VariableBase & provider(const DependencyResolver< Model, VariableName > &) const
Get the provider in the dependency graph.
virtual TensorOptions options() const =0
Tensor options.
virtual const VariableBase * direct_ref() const =0
Get the direct referencing variable (returns nullptr if this is a storing variable).
virtual Device device() const =0
Device.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
Size base_size(Size i) const
bool is_solve_dependent() const
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
virtual VariableBase * direct_ref()=0
virtual bool defined() const =0
const DerivContainer & derivatives() const
Partial derivatives.
Definition VariableBase.h:227
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:266
VariableBase & operator=(VariableBase &&)=delete
const DerivContainer & total_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total derivatives with respect to leaf variables.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:73
bool is_mutable() const
Whether this variable is mutable when it is referenced by another variable.
virtual void clear()
Clear the variable value and derivatives.
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:263
bool has_derivative(const VariableName &v1name, const VariableName &v2name) const
Whether the variable has a non-zero second derivative with respect to the given pair of variables.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
virtual void assign(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Assignment operator (with TracerPrivilege).
void set_mutable(bool m)
Allow/disable mutation of this variable when it is referenced by another variable.
const VariableName _name
Name of the variable.
Definition VariableBase.h:256
SecDerivContainer & second_derivatives()
Definition VariableBase.h:232
void request_AD(const std::vector< const VariableBase * > &us)
virtual bool owning() const =0
Check if this is an owning variable.
std::vector< SecDerivTuple > SecDerivContainer
Definition VariableBase.h:224
void clear_chain_rule_cache(const DependencyResolver< Model, VariableName > &) const
Clear chain rule cache.
void request_AD(const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
TensorShapeRef base_sizes() const
Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
Wrapper for assigning second partial derivative.
virtual TensorType type() const =0
Variable tensor type.
const SecDerivContainer & total_second_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total second derivatives with respect to leaf variables.
Derivative< 1 > & d(const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
Wrapper for assigning partial derivative.
bool has_derivative(const VariableName &vname) const
Whether the variable has a non-zero derivative with respect to another variable.
bool is_leaf(const DependencyResolver< Model, VariableName > &) const
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
const SecDerivContainer & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:231
std::tuple< Derivative< 1 >, const VariableBase * > DerivTuple
Definition VariableBase.h:221
bool is_old_state() const
const Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2) const
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in the tracing function graph for AD.
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
int64_t Size
Definition types.h:71
TensorType
Definition tensors.h:56
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
c10::ScalarType Dtype
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38