NEML2 2.0.0
Loading...
Searching...
No Matches
Variable.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/VariableBase.h"
28
29namespace neml2
30{
35template <typename T>
36class Variable : public VariableBase
37{
38public:
40 : VariableBase(std::move(name_in), owner, T::const_base_sizes, dep_intmd_dims),
41 _ref(nullptr)
42 {
43 }
44
45 TensorType type() const override;
46
48 // These methods mirror TensorBase
51 bool defined() const override;
53 TensorOptions options() const override;
55 Dtype scalar_type() const override;
57 Device device() const override;
59
60 const TraceableTensorShape & dynamic_sizes() const override;
61
62 std::unique_ptr<VariableBase> clone(const VariableName & name = {},
63 Model * owner = nullptr) const override;
64
65 void ref(const VariableBase & var, bool ref_is_mutable = false) override;
66
67 const VariableBase * ref() const override { return _ref ? _ref->ref() : this; }
68
69 bool owning() const override { return !_ref; }
70
71 void zero(const TensorOptions & options) override;
72
73 void set(const Tensor & val, std::optional<TracerPrivilege> key) override;
74
75 Tensor get() const override;
76
77 Tensor tensor() const override;
78
79 void requires_grad_(bool req = true) override;
80
81 void operator=(const Tensor & val) override;
82
84 const T & operator()() const { return owning() ? _value : (*_ref)(); }
85
87 T operator-() const { return -operator()(); }
88
89 void clear() override;
90
91protected:
94
96 bool _ref_is_mutable = false;
97
100};
101} // namespace neml2
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
The base class for all constitutive models.
Definition Model.h:70
Definition Tensor.h:47
Base class of variable.
Definition VariableBase.h:53
const Model & owner() const
Definition VariableBase.cxx:51
VariableBase()=default
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:69
ArrayRef< Size > dep_intmd_dims() const
Get dependent intermediate dimensions for derivative calculation.
Definition VariableBase.cxx:237
Concrete definition of a variable.
Definition VariableStore.h:41
bool defined() const override
Definition Variable.cxx:45
Dtype scalar_type() const override
Scalar type.
Definition Variable.cxx:59
Tensor get() const override
Get the variable value in assembly format.
Definition Variable.cxx:167
const Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable)
Definition Variable.h:93
Tensor tensor() const override
Get the variable value cast to Tensor.
Definition Variable.cxx:178
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in tracing function graph for AD.
Definition Variable.cxx:188
Variable(VariableName name_in, Model *owner, TensorShapeRef dep_intmd_dims={})
Definition Variable.h:39
bool _ref_is_mutable
Whether mutating the referenced variable is allowed.
Definition Variable.h:96
void operator=(const Tensor &val) override
Assignment operator.
Definition Variable.cxx:199
Device device() const override
Device.
Definition Variable.cxx:66
const VariableBase * ref() const override
Get the referencing variable (returns this if this is a storing variable)
Definition Variable.h:67
void zero(const TensorOptions &options) override
Set the variable value to zero.
Definition Variable.cxx:125
T operator-() const
Negation.
Definition Variable.h:87
TensorOptions options() const override
Tensor options.
Definition Variable.cxx:52
TensorType type() const override
Variable tensor type.
Definition Variable.cxx:38
const TraceableTensorShape & dynamic_sizes() const override
Definition Variable.cxx:73
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:69
void set(const Tensor &val, std::optional< TracerPrivilege > key) override
Set the variable value from a Tensor in assembly format.
Definition Variable.cxx:146
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
Definition Variable.cxx:80
T _value
Variable value (undefined if this is a referencing variable)
Definition Variable.h:99
void clear() override
Clear the variable value and derivatives.
Definition Variable.cxx:231
const T & operator()() const
Variable value.
Definition Variable.h:84
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
c10::ScalarType Dtype
Definition types.h:61