NEML2 2.0.0
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
TensorBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <ATen/core/Tensor.h>
28#include "neml2/jit/TraceableTensorShape.h"
29#include "neml2/tensors/shape_utils.h"
30#include "neml2/tensors/functions/operators.h"
31
32namespace neml2
33{
// Forward declarations of the tensor types used by the class template below
class Tensor;

template <class Derived>
class TensorBase;
39
47template <class Derived>
48class TensorBase : public ATensor
49{
50public:
52 TensorBase() = default;
53
55 TensorBase(const ATensor & tensor, Size batch_dim);
56
58 TensorBase(const ATensor & tensor, const TraceableTensorShape & batch_shape);
59
61 TensorBase(const neml2::Tensor & tensor);
62
63 TensorBase(Real) = delete;
64
68 [[nodiscard]] static Derived empty_like(const Derived & other);
70 [[nodiscard]] static Derived zeros_like(const Derived & other);
72 [[nodiscard]] static Derived ones_like(const Derived & other);
75 [[nodiscard]] static Derived full_like(const Derived & other, Real init);
96 [[nodiscard]] static Derived
97 linspace(const Derived & start, const Derived & end, Size nstep, Size dim = 0);
99 [[nodiscard]] static Derived
100 logspace(const Derived & start, const Derived & end, Size nstep, Size dim = 0, Real base = 10);
102
106 Derived clone() const;
108 Derived detach() const;
110 using ATensor::detach_;
112 Derived to(const TensorOptions & options) const;
114 using ATensor::copy_;
116 using ATensor::zero_;
118 using ATensor::requires_grad;
120 using ATensor::requires_grad_;
122 Derived operator-() const;
124
128 using ATensor::options;
130 using ATensor::scalar_type;
132 using ATensor::device;
134 using ATensor::dim;
136 using ATensor::sizes;
138 using ATensor::size;
140 bool batched() const;
142 Size batch_dim() const;
144 Size base_dim() const;
146 const TraceableTensorShape & batch_sizes() const;
148 TraceableSize batch_size(Size index) const;
152 Size base_size(Size index) const;
154 Size base_storage() const;
156
160 using ATensor::index;
161 using ATensor::index_put_;
163 Derived batch_index(indexing::TensorIndicesRef indices) const;
167 Derived batch_slice(Size dim, const indexing::Slice & index) const;
169 neml2::Tensor base_slice(Size dim, const indexing::Slice & index) const;
172 void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
177 void base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
181 Derived variable_data() const;
183
187 Derived batch_expand(const TraceableTensorShape & batch_shape) const;
189 Derived batch_expand(const TraceableSize & batch_size, Size dim) const;
191 neml2::Tensor base_expand(TensorShapeRef base_shape) const;
195 Derived batch_expand_as(const neml2::Tensor & other) const;
197 neml2::Tensor base_expand_as(const neml2::Tensor & other) const;
199 Derived batch_expand_copy(const TraceableTensorShape & batch_shape) const;
203 Derived batch_reshape(const TraceableTensorShape & batch_shape) const;
205 neml2::Tensor base_reshape(TensorShapeRef base_shape) const;
207 Derived batch_unsqueeze(Size d) const;
211 Derived batch_transpose(Size d1, Size d2) const;
217
218private:
220 Size _batch_dim = {};
221
223 TraceableTensorShape _batch_sizes;
224};
225} // namespace neml2
NEML2's enhanced tensor type.
Definition TensorBase.h:49
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:447
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:238
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:385
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:219
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:438
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:248
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:185
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:212
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:157
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:164
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:143
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:457
TensorBase(Real)=delete
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:430
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:357
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:364
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:378
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:422
Derived clone() const
Definition TensorBaseImpl.h:136
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:205
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:150
const TraceableTensorShape & batch_sizes() const
Return the batch shape (sizes of all batch dimensions).
Definition TensorBaseImpl.h:178
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:276
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:258
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:414
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:371
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:171
TensorShapeRef base_sizes() const
Return the base shape (sizes of all base dimensions).
Definition TensorBaseImpl.h:198
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:398
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:301
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:330
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:229
Derived variable_data() const
Definition TensorBaseImpl.h:294
Definition Tensor.h:46
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
log-space equivalent of the linspace named constructor
Definition TensorBaseImpl.h:127
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:81
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:74
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:102
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:95
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:88
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:78
Definition DiagnosticsInterface.cxx:30
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38