// NEML2 2.1.0 -- Variable.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/models/VariableBase.h"

29namespace neml2
30{
35template <typename T>
36class Variable : public VariableBase
37{
38public:
40 : VariableBase(std::move(name_in), owner, T::const_base_sizes),
41 _ref(nullptr)
42 {
43 }
44
45 TensorType type() const override;
46
48 // These methods mirror TensorBase
51 bool defined() const override;
53 TensorOptions options() const override;
55 Dtype scalar_type() const override;
57 Device device() const override;
59
60 const TraceableTensorShape & dynamic_sizes() const override;
61
62 std::unique_ptr<VariableBase> clone(const VariableName & name = {},
63 Model * owner = nullptr) const override;
64
65 void ref(VariableBase & var) override;
66 const VariableBase * ref() const override { return _ref ? _ref->ref() : this; }
67 VariableBase * ref() override { return _ref ? _ref->ref() : this; }
68 const VariableBase * direct_ref() const override { return _ref; }
69 VariableBase * direct_ref() override { return _ref; }
70
71 bool owning() const override { return !_ref; }
72
73 void zero(const TensorOptions & options) override;
74
75 Tensor tensor() const override;
76
77 void requires_grad_(bool req = true) override;
78
79 void assign(const Tensor & val,
80 [[maybe_unused]] std::optional<TracerPrivilege> key = std::nullopt) override;
81
82 void operator=(const Tensor & val) override;
83
85 const T & operator()() const { return owning() ? _value : (*_ref)(); }
86
88 T operator-() const { return -operator()(); }
89
90 void clear() override;
91
92protected:
95
98};
99} // namespace neml2
The base class for all constitutive models.
Definition Model.h:83
Definition Tensor.h:49
const Model & owner() const
VariableBase()=default
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:73
bool defined() const override
Dtype scalar_type() const override
Scalar type.
Tensor tensor() const override
Get the variable value cast to Tensor.
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in tracing function graph for AD.
Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable).
Definition Variable.h:94
void operator=(const Tensor &val) override
Assignment operator.
Device device() const override
Device.
Variable(VariableName name_in, Model *owner)
Definition Variable.h:39
VariableBase * direct_ref() override
Definition Variable.h:69
const VariableBase * ref() const override
Get the referencing variable (returns this if this is a storing variable).
Definition Variable.h:66
VariableBase * ref() override
Definition Variable.h:67
void assign(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt) override
Assignment operator (with TracerPrivilege).
void zero(const TensorOptions &options) override
Set the variable value to zero.
T operator-() const
Negation.
Definition Variable.h:88
TensorOptions options() const override
Tensor options.
TensorType type() const override
Variable tensor type.
const TraceableTensorShape & dynamic_sizes() const override
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:71
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
T _value
Variable value (undefined if this is a referencing variable).
Definition Variable.h:97
const VariableBase * direct_ref() const override
Get the direct referencing variable (returns nullptr if this is a storing variable).
Definition Variable.h:68
void clear() override
Clear the variable value and derivatives.
void ref(VariableBase &var) override
Reference another variable.
const T & operator()() const
Variable value.
Definition Variable.h:85
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
TensorType
Definition tensors.h:56
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ScalarType Dtype
Definition types.h:67
Traceable tensor shape.
Definition TraceableTensorShape.h:38