NEML2 2.0.0
types.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#pragma once
26
27#include <torch/types.h>
28
29namespace neml2
30{
31using Real = double;
32using Integer = int;
33using Size = int64_t;
34using TensorShape = torch::SmallVector<Size>;
35using TensorShapeRef = torch::IntArrayRef;
36
37// Bring in torch::indexing
38namespace indexing
39{
40using namespace torch::indexing;
41using TensorIndices = torch::SmallVector<TensorIndex>;
42using TensorIndicesRef = torch::ArrayRef<TensorIndex>;
43}
44
51struct TraceableSize : public std::variant<Size, torch::Tensor>
52{
53 using std::variant<Size, torch::Tensor>::variant;
54
56 const torch::Tensor * traceable() const noexcept;
57
59 Size concrete() const;
60
62 torch::Tensor as_tensor() const;
63};
64
67bool operator==(const TraceableSize & lhs, const TraceableSize & rhs);
68bool operator!=(const TraceableSize & lhs, const TraceableSize & rhs);
70
72std::ostream & operator<<(std::ostream & os, const TraceableSize & s);
73
80struct TraceableTensorShape : public torch::SmallVector<TraceableSize>
81{
82 using torch::SmallVector<TraceableSize>::SmallVector;
83 using Size = int64_t;
84
88 TraceableTensorShape(const torch::Tensor & shape);
89
92
95
97 TensorShape concrete() const;
98
100 torch::Tensor as_tensor() const;
101};
102
108
/**
 * Role in a function definition.
 *
 * Backed by a single signed byte. NONE is zero, and every other enumerator
 * occupies its own distinct bit (bit-flag layout).
 */
enum class FType : int8_t
{
  NONE = 0x0,
  INPUT = 0x1,     // bit 0
  OUTPUT = 0x2,    // bit 1
  PARAMETER = 0x4, // bit 2
  BUFFER = 0x8     // bit 3
};
126
140torch::TensorOptions & default_tensor_options();
142torch::TensorOptions & default_integer_tensor_options();
144torch::Dtype & default_dtype();
146torch::Dtype & default_integer_dtype();
148torch::Device & default_device();
150
156Real & tolerance();
160
162std::string & buffer_name_separator();
164std::string & parameter_name_separator();
165
173
174} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:31
bool operator!=(const TraceableSize &lhs, const TraceableSize &rhs)
Definition types.cxx:66
torch::Device & default_device()
Default device.
Definition types.cxx:187
torch::TensorOptions & default_integer_tensor_options()
Default integral tensor options.
Definition types.cxx:165
torch::TensorOptions & default_tensor_options()
Definition types.cxx:157
double Real
Definition types.h:31
std::string & buffer_name_separator()
Default nested buffer name separator.
Definition types.cxx:215
torch::Dtype & default_dtype()
Default floating point type.
Definition types.cxx:173
torch::SmallVector< Size > TensorShape
Definition types.h:34
torch::Dtype & default_integer_dtype()
Default integral type.
Definition types.cxx:180
bool & currently_solving_nonlinear_system()
Definition types.cxx:229
Real & machine_precision()
Definition types.cxx:194
int64_t Size
Definition types.h:33
Real & tolerance()
The tolerance used in various algorithms.
Definition types.cxx:201
FType
Role in a function definition.
Definition types.h:119
torch::IntArrayRef TensorShapeRef
Definition types.h:35
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
Real & tighter_tolerance()
A tighter tolerance used in various algorithms.
Definition types.cxx:208
std::string & parameter_name_separator()
Default nested parameter name separator.
Definition types.cxx:222
bool operator==(const TraceableSize &lhs, const TraceableSize &rhs)
Definition types.cxx:60
TraceableSize
Traceable size.
Definition types.h:52
torch::Tensor as_tensor() const
Definition types.cxx:51
Size concrete() const
Definition types.cxx:37
const torch::Tensor * traceable() const noexcept
Definition types.cxx:31
TraceableTensorShape
Traceable tensor shape.
Definition types.h:81
TensorShape concrete() const
Definition types.cxx:124
torch::Tensor as_tensor() const
Definition types.cxx:133
TraceableTensorShape(const TensorShape &shape)
Definition types.cxx:78
int64_t Size
Definition types.h:83
TraceableTensorShape slice(Size start, Size end) const
Slice the shape, semantically the same as ArrayRef::slice, but traceable.
Definition types.cxx:105