NEML2 2.0.0
Loading...
Searching...
No Matches
TensorBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <ATen/core/Tensor.h>
28
29#include "neml2/jit/TraceableTensorShape.h"
30#include "neml2/tensors/functions/operators.h"
31#include "neml2/tensors/indexing.h"
32
33namespace neml2
34{
35// Forward declarations
// The TensorBase class template is declared here and defined further down in this header.
36template <class Derived>
37class TensorBase;
38
// Tensor is only named here (incomplete type). The TensorBase declarations below refer to it
// as neml2::Tensor in constructor parameters and in the return types of the base_* helpers.
39class Tensor;
40
/// NEML2's enhanced tensor type.
///
/// Publicly derives from ATensor (at::Tensor) and declares helpers that address the "batch" and
/// the "base" dimensions of the tensor separately. Operations on batch dimensions return
/// `Derived`; operations on base dimensions return a generic neml2::Tensor.
///
/// @tparam Derived The concrete tensor type returned by the batch-preserving operations.
///                 NOTE(review): presumably the CRTP-derived class -- confirm against Tensor.h.
48template <class Derived>
49class TensorBase : public ATensor
50{
51public:
 /// Special member functions.
53 TensorBase() = default;
54
 /// Construct from an ATensor and a batch dimension count.
56 TensorBase(const ATensor & tensor, Size batch_dim);
57
 /// Construct from an ATensor and a traceable batch shape.
59 TensorBase(const ATensor & tensor, const TraceableTensorShape & batch_shape);
60
 /// Construct from a generic neml2::Tensor (non-explicit, so it also acts as a conversion).
62 TensorBase(const neml2::Tensor & tensor);
63
 // Construction from a bare double/float/int is deleted.
 // NOTE(review): presumably to stop plain numbers from implicitly becoming tensors -- confirm.
64 TensorBase(double) = delete;
65 TensorBase(float) = delete;
66 TensorBase(int) = delete;
67
 /// Named constructor taking another tensor as the template.
 /// NOTE(review): no brief available; by name, an uninitialized tensor like `other` -- confirm.
71 [[nodiscard]] static Derived empty_like(const Derived & other);
 /// Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
73 [[nodiscard]] static Derived zeros_like(const Derived & other);
 /// Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
75 [[nodiscard]] static Derived ones_like(const Derived & other);
 /// Named constructor taking another tensor as the template and a scalar `init`.
 /// NOTE(review): no brief available; by name, a tensor like `other` filled with `init` -- confirm.
78 [[nodiscard]] static Derived full_like(const Derived & other, const CScalar & init);
 /// Create a new tensor by adding a new batch dimension with linear spacing between start and
 /// end. `dim` defaults to 0.
 /// NOTE(review): `nstep` / `dim` presumably give the number of points and the position of the
 /// new batch dimension -- confirm in TensorBaseImpl.h.
99 [[nodiscard]] static Derived
100 linspace(const Derived & start, const Derived & end, Size nstep, Size dim = 0);
 /// Log-space equivalent of the linspace named constructor. `base` defaults to 10.
102 [[nodiscard]] static Derived logspace(const Derived & start,
103 const Derived & end,
104 Size nstep,
105 Size dim = 0,
106 const CScalar & base = 10);
108
 /// Returns a `Derived`; NOTE(review): no brief available, presumably a copy of this tensor.
112 Derived clone() const;
 /// Discard function graph.
114 Derived detach() const;
 // Bring ATensor::detach_ into this class's scope.
116 using ATensor::detach_;
 /// Change tensor options.
118 Derived to(const TensorOptions & options) const;
 // Bring ATensor::copy_ into this class's scope.
120 using ATensor::copy_;
 // Bring ATensor::zero_ into this class's scope.
122 using ATensor::zero_;
 // Bring ATensor::requires_grad into this class's scope.
124 using ATensor::requires_grad;
 // Bring ATensor::requires_grad_ into this class's scope.
126 using ATensor::requires_grad_;
 /// Negation.
128 Derived operator-() const;
130
 // Tensor information: the following queries are taken directly from ATensor.
134 using ATensor::options;
136 using ATensor::scalar_type;
138 using ATensor::device;
140 using ATensor::dim;
142 using ATensor::sizes;
144 using ATensor::size;
 /// Whether the tensor is batched.
146 bool batched() const;
 /// Return the number of batch dimensions.
148 Size batch_dim() const;
 /// Return the number of base dimensions.
150 Size base_dim() const;
 /// Return the batch sizes, i.e. the (traceable) batch shape.
152 const TraceableTensorShape & batch_sizes() const;
 /// Return the size of a batch axis.
154 TraceableSize batch_size(Size index) const;
 /// Return the size of a base axis.
158 Size base_size(Size index) const;
 /// Return the flattened storage needed just for the base indices.
160 Size base_storage() const;
162
 // Indexing: the untyped index / index_put_ are taken directly from ATensor.
166 using ATensor::index;
167 using ATensor::index_put_;
 /// Get a tensor by slicing on the batch dimensions.
169 Derived batch_index(indexing::TensorIndicesRef indices) const;
 /// Get a tensor by slicing along a batch dimension.
173 Derived batch_slice(Size dim, const indexing::Slice & index) const;
 /// Get a tensor by slicing along a base dimension.
175 neml2::Tensor base_slice(Size dim, const indexing::Slice & index) const;
 /// Overloads taking either a tensor or a scalar as the value, addressed by batch indices.
 /// NOTE(review): no brief available; by name, in-place assignment on batch dims -- confirm.
178 void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
179 void batch_index_put_(indexing::TensorIndicesRef indices, const CScalar & v);
 /// Overloads taking either a tensor or a scalar as the value, addressed by base indices.
 /// NOTE(review): no brief available; by name, in-place assignment on base dims -- confirm.
183 void base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
184 void base_index_put_(indexing::TensorIndicesRef indices, const CScalar & v);
 /// Returns a `Derived`; NOTE(review): no brief available -- see TensorBaseImpl.h for semantics.
187 Derived variable_data() const;
189
 /// Expand to the given batch shape.
 /// NOTE(review): no brief available; presumably a view, mirroring base_expand -- confirm.
193 Derived batch_expand(const TraceableTensorShape & batch_shape) const;
 /// Overload taking a single batch size and the dimension it applies to.
195 Derived batch_expand(const TraceableSize & batch_size, Size dim) const;
 /// Return a new view of the tensor with values broadcast along the base dimensions.
197 neml2::Tensor base_expand(TensorShapeRef base_shape) const;
 /// Expand the batch to have the same shape as another tensor.
201 Derived batch_expand_as(const neml2::Tensor & other) const;
 /// Expand the base to have the same shape as another tensor.
203 neml2::Tensor base_expand_as(const neml2::Tensor & other) const;
 /// Return a new tensor with values broadcast along the batch dimensions.
205 Derived batch_expand_copy(const TraceableTensorShape & batch_shape) const;
 /// Reshape batch dimensions.
209 Derived batch_reshape(const TraceableTensorShape & batch_shape) const;
 /// Reshape base dimensions.
211 neml2::Tensor base_reshape(TensorShapeRef base_shape) const;
 /// Unsqueeze a batch dimension.
213 Derived batch_unsqueeze(Size d) const;
 /// Transpose two batch dimensions.
217 Derived batch_transpose(Size d1, Size d2) const;
223
 // NOTE(review): the member index for this header also lists base_sizes, base_index,
 // base_expand_copy, base_unsqueeze, base_transpose, base_flatten and the data member
 // `TraceableTensorShape _batch_sizes` (TensorBase.h:226); none of them appear in this
 // listing, so the listing is presumably incomplete -- verify against the real header.
224protected:
227};
228} // namespace neml2
NEML2's enhanced tensor type.
Definition TensorBase.h:50
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:451
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:242
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:389
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:223
TensorBase(double)=delete
TraceableTensorShape _batch_sizes
Traceable batch sizes.
Definition TensorBase.h:226
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:442
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:252
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:189
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:216
TensorBase(float)=delete
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:161
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:168
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:147
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:461
TensorBase(int)=delete
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:434
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:361
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:368
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:382
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:426
Derived clone() const
Definition TensorBaseImpl.h:140
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:209
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:154
const TraceableTensorShape & batch_sizes() const
Return the batch sizes, i.e. the batch shape.
Definition TensorBaseImpl.h:182
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:280
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:262
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:418
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:375
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Return the base sizes, i.e. the base shape.
Definition TensorBaseImpl.h:202
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:402
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:305
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:334
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:233
Derived variable_data() const
Definition TensorBaseImpl.h:298
Definition Tensor.h:46
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:99
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:85
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:78
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:106
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, const CScalar &base=10)
Log-space equivalent of the linspace named constructor.
Definition TensorBaseImpl.h:131
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:92
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
Definition DiagnosticsInterface.cxx:30
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38