NEML2 2.0.0
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
PrimitiveTensor.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <torch/csrc/autograd/variable.h>
28#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29
30#include "neml2/tensors/Tensor.h"
31
32namespace neml2
33{
/// Alias for torch::detail::TensorDataContainer, the nested-container input type accepted by
/// PrimitiveTensor::create().
34using TensorDataContainer = torch::detail::TensorDataContainer;
35
41template <class Derived, Size... S>
42class PrimitiveTensor : public TensorBase<Derived>
43{
44public:
46 static inline const TensorShape const_base_sizes = {S...};
47
49 static constexpr Size const_base_dim = sizeof...(S);
50
52 static inline const Size const_base_storage = utils::storage_size({S...});
53
55 PrimitiveTensor() = default;
56
58 explicit PrimitiveTensor(const ATensor & tensor, Size batch_dim);
59
61 explicit PrimitiveTensor(const ATensor & tensor, const TraceableTensorShape & batch_shape);
62
64 PrimitiveTensor(const Tensor & tensor);
65
67 explicit PrimitiveTensor(const ATensor & tensor);
68
70 operator Tensor() const;
71
73 [[nodiscard]] static Derived create(const TensorDataContainer & data,
74 const TensorOptions & options = default_tensor_options());
76 [[nodiscard]] static Derived empty(const TensorOptions & options = default_tensor_options());
78 [[nodiscard]] static Derived empty(const TraceableTensorShape & batch_shape,
79 const TensorOptions & options = default_tensor_options());
81 [[nodiscard]] static Derived zeros(const TensorOptions & options = default_tensor_options());
83 [[nodiscard]] static Derived zeros(const TraceableTensorShape & batch_shape,
84 const TensorOptions & options = default_tensor_options());
86 [[nodiscard]] static Derived ones(const TensorOptions & options = default_tensor_options());
88 [[nodiscard]] static Derived ones(const TraceableTensorShape & batch_shape,
89 const TensorOptions & options = default_tensor_options());
91 [[nodiscard]] static Derived full(Real init,
92 const TensorOptions & options = default_tensor_options());
94 [[nodiscard]] static Derived full(const TraceableTensorShape & batch_shape,
95 Real init,
96 const TensorOptions & options = default_tensor_options());
97
99 [[nodiscard]] static Tensor identity_map(const TensorOptions &)
100 {
101 throw NEMLException("Not implemented");
102 }
103};
104
106// Implementations
108
109template <class Derived, Size... S>
111 : TensorBase<Derived>(tensor, batch_dim)
112{
113#ifndef NDEBUG
114 if (this->base_sizes() != const_base_sizes)
115 throw NEMLException("Base shape mismatch");
116#endif
117}
118
119template <class Derived, Size... S>
121 const TraceableTensorShape & batch_shape)
122 : TensorBase<Derived>(tensor, batch_shape)
123{
124#ifndef NDEBUG
125 if (this->base_sizes() != const_base_sizes)
126 throw NEMLException("Base shape mismatch");
127#endif
128}
129
130template <class Derived, Size... S>
132 : TensorBase<Derived>(tensor)
133{
134#ifndef NDEBUG
135 if (this->base_sizes() != const_base_sizes)
136 throw NEMLException("Base shape mismatch");
137#endif
138}
139
140template <class Derived, Size... S>
142 : TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
143{
144#ifndef NDEBUG
145 if (this->base_sizes() != const_base_sizes)
146 throw NEMLException("Base shape mismatch");
147#endif
148}
149
150template <class Derived, Size... S>
151PrimitiveTensor<Derived, S...>::operator Tensor() const
152{
153 return Tensor(*this, this->batch_sizes());
154}
155
156template <class Derived, Size... S>
157Derived
159 const TensorOptions & options)
160{
161 return Derived(torch::autograd::make_variable(
162 data.convert_to_tensor(options.requires_grad(false)), options.requires_grad()));
163}
164
165template <class Derived, Size... S>
166Derived
171
172template <class Derived, Size... S>
173Derived
175 const TensorOptions & options)
176{
177 return Tensor::empty(batch_shape, const_base_sizes, options);
178}
179
180template <class Derived, Size... S>
181Derived
186
187template <class Derived, Size... S>
188Derived
190 const TensorOptions & options)
191{
192 return Tensor::zeros(batch_shape, const_base_sizes, options);
193}
194
195template <class Derived, Size... S>
196Derived
201
202template <class Derived, Size... S>
203Derived
205 const TensorOptions & options)
206{
207 return Tensor::ones(batch_shape, const_base_sizes, options);
208}
209
210template <class Derived, Size... S>
211Derived
213{
214 return Tensor::full(const_base_sizes, init, options);
215}
216
217template <class Derived, Size... S>
218Derived
220 Real init,
221 const TensorOptions & options)
222{
223 return Tensor::full(batch_shape, const_base_sizes, init, options);
224}
225} // namespace neml2
Definition errors.h:34
static Derived full(const TraceableTensorShape &batch_shape, Real init, const TensorOptions &options=default_tensor_options())
Tensor filled with a given value, given batch shape.
Definition PrimitiveTensor.h:219
static Derived ones(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:204
PrimitiveTensor(const ATensor &tensor)
Construct from another ATensor and infer batch dimension.
Definition PrimitiveTensor.h:141
static Tensor identity_map(const TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:99
static Derived empty(const TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:167
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another ATensor given batch shape.
Definition PrimitiveTensor.h:120
static Derived create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred batch dimension.
Definition PrimitiveTensor.h:158
PrimitiveTensor(const Tensor &tensor)
Construct from a generic Tensor (converting constructor, not a copy constructor).
Definition PrimitiveTensor.h:131
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:46
static Derived empty(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:174
static Derived ones(const TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:197
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:49
static Derived full(Real init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value; the base shape is fixed by the template parameters.
Definition PrimitiveTensor.h:212
PrimitiveTensor(const ATensor &tensor, Size batch_dim)
Construct from another ATensor given batch dimension.
Definition PrimitiveTensor.h:110
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:52
static Derived zeros(const TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:182
static Derived zeros(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:189
PrimitiveTensor()=default
Special member functions.
TensorBase()=default
Special member functions.
Size batch_dim() const
Definition TensorBaseImpl.h:164
const TraceableTensorShape & batch_sizes() const
Return the batch shape.
Definition TensorBaseImpl.h:178
TensorShapeRef base_sizes() const
Return the base shape.
Definition TensorBaseImpl.h:198
Definition Tensor.h:46
static Tensor full(TensorShapeRef base_shape, Real init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:154
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:91
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:112
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:133
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:30
Definition DiagnosticsInterface.cxx:30
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:71
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:44
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:34
c10::TensorOptions TensorOptions
Definition types.h:63
Traceable tensor shape.
Definition TraceableTensorShape.h:38