29#include "neml2/models/utils.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
37template <std::
size_t N>
40template <
typename,
typename>
48 const std::string & sep);
139 virtual std::unique_ptr<VariableBase>
clone(std::optional<VariableName>
name = std::nullopt,
179 [[maybe_unused]] std::optional<TracerPrivilege> key = std::nullopt) = 0;
192 std::size_t deriv_intrsc_intmd_dim = 0,
193 std::size_t var_intrsc_intmd_dim = 0,
194 std::size_t arg_intrsc_intmd_dim = 0);
200 std::size_t deriv_intrsc_intmd_dim = 0,
201 std::size_t var_intrsc_intmd_dim = 0,
202 std::size_t arg1_intrsc_intmd_dim = 0,
203 std::size_t arg2_intrsc_intmd_dim = 0);
208 void request_AD(
const std::vector<const VariableBase *> & us);
214 void request_AD(
const std::vector<const VariableBase *> & u1s,
215 const std::vector<const VariableBase *> & u2s);
299#define FWD_VARIABLE_BINARY_OP(op) \
300 template <typename T1, \
302 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
303 std::is_base_of_v<VariableBase, T2>>> \
304 auto op(const T1 & a, const T2 & b) \
306 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
307 return op(a(), b()); \
309 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
312 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
// Instantiate FWD_VARIABLE_BINARY_OP for each supported binary operator. Per the
// macro's enable_if constraint, each generated template only participates in
// overload resolution when at least one operand derives from VariableBase; when
// both operands do, it evaluates op(a(), b()).
// NOTE(review): the branches for mixed variable/non-variable operands are not
// visible in this excerpt. Confirm their forwarding behavior in the full macro.
// Arithmetic operators.
317FWD_VARIABLE_BINARY_OP(
operator+);
318FWD_VARIABLE_BINARY_OP(
operator-);
319FWD_VARIABLE_BINARY_OP(
operator*);
320FWD_VARIABLE_BINARY_OP(
operator/);
// Ordering comparison operators.
322FWD_VARIABLE_BINARY_OP(
operator>);
323FWD_VARIABLE_BINARY_OP(
operator<);
324FWD_VARIABLE_BINARY_OP(
operator>=);
325FWD_VARIABLE_BINARY_OP(
operator<=);
// Logical operators.
326FWD_VARIABLE_BINARY_OP(
operator&&);
327FWD_VARIABLE_BINARY_OP(
operator||);
// Equality operators.
328FWD_VARIABLE_BINARY_OP(
operator==);
329FWD_VARIABLE_BINARY_OP(
operator!=);
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition DependencyResolver.h:47
Derivative wrapper.
Definition Derivative.h:66
The base class for all constitutive models.
Definition Model.h:82
void request_AD(const VariableBase &u1, const VariableBase &u2)
const Derivative< 1 > & d(const VariableBase &arg) const
DerivContainer & derivatives()
Definition VariableBase.h:225
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
const VariableName & base_name() const
Base name without the history suffix (e.g. "stress" for "stress~1").
Definition VariableBase.h:84
Size static_size(Size i) const
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:256
Size intmd_size(Size i) const
virtual void ref(VariableBase &other)=0
Reference another variable.
VariableBase(VariableName name_in, Model *owner, TensorShapeRef base_shape)
The canonical constructor.
bool requires_grad() const
Check if this variable is part of the AD function graph.
bool _mutable
When this variable is referenced by another variable, whether the referencing variable is allowed to mutate this variable's value.
Definition VariableBase.h:272
const Model & owner() const
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable).
TensorShapeRef intmd_sizes() const
std::size_t _history_order
History order parsed from the variable name (0 = current, 1 = "foo~1", etc.).
Definition VariableBase.h:259
void clear_derivatives()
Clear only the derivatives.
const TraceableSize & dynamic_size(Size i) const
Tensor zeros(const TensorOptions &options) const
Make a tensor of zeros with the shape of this variable.
std::size_t history_order() const
History order: 0 for current variables, 1 for "foo~1", 2 for "foo~2", etc.
Definition VariableBase.h:81
virtual std::unique_ptr< VariableBase > clone(std::optional< VariableName > name=std::nullopt, Model *owner=nullptr) const =0
Clone this variable.
virtual VariableBase * ref()=0
void request_AD(const VariableBase &u)
TraceableSize batch_size(Size i) const
std::vector< DerivTuple > DerivContainer
Definition VariableBase.h:219
TensorShapeRef sizes() const
std::tuple< Derivative< 2 >, const VariableBase *, const VariableBase * > SecDerivTuple
Definition VariableBase.h:220
TensorShapeRef static_sizes() const
const VariableBase & provider(const DependencyResolver< Model, VariableName > &) const
Get the provider in the dependency graph.
virtual TensorOptions options() const =0
Tensor options.
virtual const VariableBase * direct_ref() const =0
Get the direct referencing variable (returns nullptr if this is a storing variable).
virtual Device device() const =0
Device.
VariableName _base_name
Base name without history suffix (equals _name when history_order == 0).
Definition VariableBase.h:262
Size base_size(Size i) const
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
virtual VariableBase * direct_ref()=0
virtual bool defined() const =0
const DerivContainer & derivatives() const
Partial derivatives.
Definition VariableBase.h:224
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:269
VariableBase & operator=(VariableBase &&)=delete
const DerivContainer & total_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total derivatives with respect to leaf variables.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:78
bool is_mutable() const
Whether this variable is mutable when it is referenced by another variable.
virtual void clear()
Clear the variable value and derivatives.
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:266
bool has_derivative(const VariableName &v1name, const VariableName &v2name) const
Whether the variable has a non-zero second derivative with respect to the given pair of variables.
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
virtual void assign(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Assign a value (accepts an optional TracerPrivilege key).
void set_mutable(bool m)
Allow/disable mutation of this variable when it is referenced by another variable.
const VariableName _name
Name of the variable.
Definition VariableBase.h:253
SecDerivContainer & second_derivatives()
Definition VariableBase.h:229
void request_AD(const std::vector< const VariableBase * > &us)
virtual bool owning() const =0
Check if this is an owning variable.
std::vector< SecDerivTuple > SecDerivContainer
Definition VariableBase.h:221
void clear_chain_rule_cache(const DependencyResolver< Model, VariableName > &) const
Clear chain rule cache.
void request_AD(const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
TensorShapeRef base_sizes() const
Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
Wrapper for assigning second partial derivative.
virtual TensorType type() const =0
Variable tensor type.
const SecDerivContainer & total_second_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total second derivatives with respect to leaf variables.
Derivative< 1 > & d(const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
Wrapper for assigning partial derivative.
bool has_derivative(const VariableName &vname) const
Whether the variable has a non-zero derivative with respect to another variable.
bool is_leaf(const DependencyResolver< Model, VariableName > &) const
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
const SecDerivContainer & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:228
std::tuple< Derivative< 1 >, const VariableBase * > DerivTuple
Definition VariableBase.h:218
const Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2) const
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in the traced function graph for AD.
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
std::pair< VariableName, std::size_t > parse_history(const VariableName &name, const std::string &sep)
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
std::string name(ElasticConstant p)
int64_t Size
Definition types.h:71
std::string VariableName
Definition types.h:75
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
c10::ScalarType Dtype
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38