NEML2 2.0.0
TensorBaseImpl.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/tensors/TensorBase.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/assertions.h"
#include "neml2/jit/types.h"
#include "neml2/jit/utils.h"

namespace neml2
{
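// Convention used throughout this file: the leading _batch_dim dimensions of the
// underlying ATensor are batch dimensions, and the remaining trailing dimensions are
// base dimensions (base_sizes() simply slices sizes() at _batch_dim).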
template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor, Size batch_dim)
  : ATensor(tensor),
    _batch_dim(batch_dim),
    _batch_sizes(utils::extract_batch_sizes(tensor, batch_dim))
{
  neml_assert_dbg((Size)sizes().size() >= _batch_dim,
                  "Tensor dimension ",
                  sizes().size(),
                  " is smaller than the requested number of batch dimensions ",
                  _batch_dim);
}

template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor, const TraceableTensorShape & batch_shape)
  : ATensor(tensor),
    _batch_dim(Size(batch_shape.size())),
    _batch_sizes(batch_shape)
{
  neml_assert_dbg(batch_sizes() == batch_shape,
                  "Tensor of shape ",
                  sizes(),
                  " cannot be constructed with batch shape ",
                  batch_shape);
}

template <class Derived>
TensorBase<Derived>::TensorBase(const neml2::Tensor & tensor)
  : TensorBase(tensor, tensor.batch_sizes())
{
}

template <class Derived>
Derived
TensorBase<Derived>::empty_like(const Derived & other)
{
  return Derived(at::empty_like(other), other.batch_sizes());
}

// Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
template <class Derived>
Derived
TensorBase<Derived>::zeros_like(const Derived & other)
{
  return Derived(at::zeros_like(other), other.batch_sizes());
}

// Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
template <class Derived>
Derived
TensorBase<Derived>::ones_like(const Derived & other)
{
  return Derived(at::ones_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::full_like(const Derived & other, Real init)
{
  return Derived(at::full_like(other, init), other.batch_sizes());
}

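// linspace creates a new batch dimension of size nstep at position `dim`, holding
// linearly spaced values between `start` and `end` (inclusive); logspace is its
// log-space counterpart, raising `base` to the linearly spaced exponents. As a
// hypothetical illustration, linspace(a, b, 3) stacks a, (a + b) / 2, and b along
// the new batch dimension.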
template <class Derived>
Derived
TensorBase<Derived>::linspace(const Derived & start, const Derived & end, Size nstep, Size dim)
{
  neml_assert_broadcastable_dbg(start, end);
  neml_assert_dbg(nstep > 0, "nstep must be positive.");

  auto res = start.batch_unsqueeze(dim);

  if (nstep > 1)
  {
    auto Bd = utils::broadcast_batch_dim(start, end);
    auto diff = (end - start).batch_unsqueeze(dim);

    indexing::TensorIndices net(dim, indexing::None);
    net.push_back(indexing::Ellipsis);
    net.insert(net.end(), Bd - dim, indexing::None);
    Scalar steps(at::arange(nstep, diff.options()).index(net) / (nstep - 1));

    res = res + steps * diff;
  }

  return res;
}

template <class Derived>
Derived
TensorBase<Derived>::logspace(
    const Derived & start, const Derived & end, Size nstep, Size dim, Real base)
{
  auto exponent = neml2::Tensor::linspace(start, end, nstep, dim);
  return Derived(at::pow(base, exponent), exponent.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::clone() const
{
  return Derived(ATensor::clone(), batch_sizes());
}

// Discard the function graph
template <class Derived>
Derived
TensorBase<Derived>::detach() const
{
  return Derived(ATensor::detach(), batch_sizes());
}

// Change tensor options
template <class Derived>
Derived
TensorBase<Derived>::to(const TensorOptions & options) const
{
  return Derived(ATensor::to(options), batch_sizes());
}

// Whether the tensor is batched
template <class Derived>
bool
TensorBase<Derived>::batched() const
{
  return _batch_dim;
}

// Return the number of batch dimensions
template <class Derived>
Size
TensorBase<Derived>::batch_dim() const
{
  return _batch_dim;
}

// Return the number of base dimensions
template <class Derived>
Size
TensorBase<Derived>::base_dim() const
{
  return dim() - batch_dim();
}

// Return the batch sizes
template <class Derived>
const TraceableTensorShape &
TensorBase<Derived>::batch_sizes() const
{
  return _batch_sizes;
}

// Return the size of a batch axis
template <class Derived>
TraceableSize
TensorBase<Derived>::batch_size(Size index) const
{
  const auto i = index >= 0 ? index : index + batch_dim();

  // Put the batch size into the traced graph if we are tracing
  if (jit::tracer::isTracing())
    return jit::tracer::getSizeOf(*this, i);

  return size(i);
}

// Return the base sizes
template <class Derived>
TensorShapeRef
TensorBase<Derived>::base_sizes() const
{
  return sizes().slice(_batch_dim);
}

// Return the size of a base axis
template <class Derived>
Size
TensorBase<Derived>::base_size(Size index) const
{
  return base_sizes()[index >= 0 ? index : index + base_dim()];
}

// Return the flattened storage needed just for the base indices
template <class Derived>
Size
TensorBase<Derived>::base_storage() const
{
  return utils::storage_size(base_sizes());
}

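// The *_index / *_slice / *_index_put_ helpers below apply user-provided indices to
// only one side of the batch/base split: the batch_* variants pad the indices with
// full slices over the base dimensions, while the base_* variants pad with full
// slices over the batch dimensions. As a hypothetical example, for a tensor with
// batch shape (5,) and base shape (3, 3), batch_index({0}) selects the first 3x3
// matrix, whereas base_index({0, 0}) selects the (0, 0) entry of every matrix in
// the batch.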
// Get a tensor by slicing on the batch dimensions
template <class Derived>
Derived
TensorBase<Derived>::batch_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  auto res = this->index(indices_vec);
  return Derived(res, res.dim() - base_dim());
}

// Get a tensor by slicing on the base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  return neml2::Tensor(this->index(indices2), batch_sizes());
}

// Get a tensor by slicing along a batch dimension
template <class Derived>
Derived
TensorBase<Derived>::batch_slice(Size dim, const indexing::Slice & index) const
{
  auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
  auto res = this->slice(
      i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, res.dim() - base_dim());
}

// Get a tensor by slicing along a base dimension
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_slice(Size dim, const indexing::Slice & index) const
{
  auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
  auto res = this->slice(
      i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, batch_sizes());
}

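// The *_index_put_ overloads below are the in-place counterparts of the indexing
// helpers above: they assign either another tensor or a single scalar value to the
// selected batch or base sub-view.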
template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, other);
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, Real v)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, v);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, other);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, Real v)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, v);
}

template <class Derived>
Derived
TensorBase<Derived>::variable_data() const
{
  return Derived(ATensor::variable_data(), batch_sizes());
}

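// The expand/reshape family below broadcasts or reshapes one side of the batch/base
// split while leaving the other side untouched (an entry of -1 passed to expand keeps
// that dimension's existing size). batch_expand, batch_reshape, and base_reshape also
// stash traceable batch sizes so the operation is recorded correctly when JIT tracing.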
template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableTensorShape & batch_shape) const
{
  // We don't want to touch the base dimensions, so put -1 for them.
  auto net = batch_shape.concrete();
  net.insert(net.end(), base_dim(), -1);

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);

  return Derived(expand(net), batch_shape);
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableSize & batch_size, Size dim) const
{
  auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
  auto batch_shape = batch_sizes();
  if (batch_shape[i] == batch_size)
    return Derived(*this);

  batch_shape[i] = batch_size;
  return batch_expand(batch_shape);
}

// Return a new view of the tensor with values broadcast along the base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(TensorShapeRef base_shape) const
{
  if (base_sizes() == base_shape)
    return *this;

  // We don't want to touch the batch dimensions, so put -1 for them.
  auto net = base_shape.vec();
  net.insert(net.begin(), batch_dim(), -1);
  return neml2::Tensor(expand(net), batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(Size base_size, Size dim) const
{
  if (this->base_size(dim) == base_size)
    return *this;

  // We don't want to touch the batch dimensions and other base dimensions, so put -1 for them.
  auto net = std::vector<Size>(this->dim(), -1);
  auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
  net[i] = base_size;
  return neml2::Tensor(expand(net), batch_sizes());
}

// Expand the batch to have the same shape as another tensor
template <class Derived>
Derived
TensorBase<Derived>::batch_expand_as(const neml2::Tensor & other) const
{
  return batch_expand(other.batch_sizes());
}

// Expand the base to have the same shape as another tensor
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_as(const neml2::Tensor & other) const
{
  return base_expand(other.base_sizes());
}

// Return a new tensor with values broadcast along the batch dimensions
template <class Derived>
Derived
TensorBase<Derived>::batch_expand_copy(const TraceableTensorShape & batch_shape) const
{
  return Derived(batch_expand(batch_shape).contiguous(), batch_shape);
}

// Return a new tensor with values broadcast along the base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_copy(TensorShapeRef base_shape) const
{
  return neml2::Tensor(base_expand(base_shape).contiguous(), batch_sizes());
}

// Reshape the batch dimensions
template <class Derived>
Derived
TensorBase<Derived>::batch_reshape(const TraceableTensorShape & batch_shape) const
{
  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_dim(), i, *si);

  return Derived(reshape(utils::add_shapes(batch_shape.concrete(), base_sizes())), batch_shape);
}

// Reshape the base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_reshape(TensorShapeRef base_shape) const
{
  auto batch_shape = batch_sizes();

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_shape.size(), i, *si);

  return neml2::Tensor(reshape(utils::add_shapes(batch_shape.concrete(), base_shape)),
                       batch_sizes());
}

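// The unsqueeze/transpose helpers below take indices relative to their own block:
// a non-negative batch index counts from the first batch dimension, a non-negative
// base index counts from the first base dimension, and negative indices count from
// the end of the respective block.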
// Unsqueeze a batch dimension
template <class Derived>
Derived
TensorBase<Derived>::batch_unsqueeze(Size d) const
{
  auto d2 = d >= 0 ? d : d - base_dim();
  return Derived(unsqueeze(d2), _batch_dim + 1);
}

// Unsqueeze a base dimension
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_unsqueeze(Size d) const
{
  auto d2 = d < 0 ? d : d + batch_dim();
  return neml2::Tensor(ATensor::unsqueeze(d2), batch_sizes());
}

// Transpose two batch dimensions
template <class Derived>
Derived
TensorBase<Derived>::batch_transpose(Size d1, Size d2) const
{
  return Derived(ATensor::transpose(d1 < 0 ? d1 - base_dim() : d1, d2 < 0 ? d2 - base_dim() : d2),
                 _batch_dim);
}

// Transpose two base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_transpose(Size d1, Size d2) const
{
  return neml2::Tensor(
      ATensor::transpose(d1 < 0 ? d1 : _batch_dim + d1, d2 < 0 ? d2 : _batch_dim + d2),
      batch_sizes());
}

// Flatten the base dimensions
template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_flatten() const
{
  if (base_dim() == 1)
    return *this;

  return base_reshape({base_storage()});
}

// Negation
template <class Derived>
Derived
TensorBase<Derived>::operator-() const
{
  return Derived(-ATensor(*this), batch_sizes());
}

} // end namespace neml2
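
Below is a minimal usage sketch (not part of TensorBaseImpl.h) illustrating how the batch/base split behaves through the public neml2::Tensor type. The shapes, the use of at::rand to create the raw data, and the main() driver are illustrative assumptions; the member functions exercised are the ones defined above.

// sketch.cxx -- hypothetical example, not shipped with NEML2
#include "neml2/tensors/Tensor.h"

#include <iostream>

int main()
{
  using namespace neml2;

  // Interpret a raw 5x3x3 tensor as a batch of five 3x3 matrices:
  // one (leading) batch dimension and two (trailing) base dimensions.
  auto A = Tensor(at::rand({5, 3, 3}), /*batch_dim=*/1);

  std::cout << "batch dims:   " << A.batch_dim() << "\n"     // 1
            << "base dims:    " << A.base_dim() << "\n"      // 2
            << "base storage: " << A.base_storage() << "\n"; // 3 * 3 = 9

  // Negate every entry, then flatten the base dimensions;
  // the batch dimension is left untouched.
  auto B = (-A).base_flatten(); // batch shape (5,), base shape (9,)

  return 0;
}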