TensorBaseImpl.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include <torch/csrc/jit/frontend/tracer.h>

#include "neml2/tensors/TensorBase.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/assertions.h"
#include "neml2/jit/utils.h"

namespace neml2::jit
{
using namespace torch::jit;
}

namespace neml2
{
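// Note on the dimension convention used throughout this file: every tensor is laid out as
// (batch dimensions..., base dimensions...). The first batch_dim() dimensions are the batch
// dimensions, whose (possibly traceable) sizes are cached in _batch_sizes, and the remaining
// dimensions are the base (intrinsic) dimensions. For example, a tensor of shape (5, 2, 3, 3)
// constructed with batch_dim = 2 has batch shape (5, 2) and base shape (3, 3).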
template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor, Size batch_dim)
  : ATensor(tensor),
    _batch_sizes(utils::extract_batch_sizes(tensor, batch_dim))
{
  neml_assert_dbg((Size)sizes().size() >= batch_dim,
                  "Tensor dimension ",
                  sizes().size(),
                  " is smaller than the requested number of batch dimensions ",
                  batch_dim);
}

template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor, const TraceableTensorShape & batch_shape)
  : ATensor(tensor),
    _batch_sizes(batch_shape)
{
  neml_assert_dbg(batch_sizes() == tensor.sizes().slice(0, batch_dim()),
                  "Tensor of shape ",
                  sizes(),
                  " cannot be constructed with batch shape ",
                  batch_shape);
}

template <class Derived>
TensorBase<Derived>::TensorBase(const neml2::Tensor & tensor)
  : TensorBase(tensor, tensor.batch_sizes())
{
}

template <class Derived>
Derived
TensorBase<Derived>::empty_like(const Derived & other)
{
  return Derived(at::empty_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::zeros_like(const Derived & other)
{
  return Derived(at::zeros_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::ones_like(const Derived & other)
{
  return Derived(at::ones_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::full_like(const Derived & other, const CScalar & init)
{
  return Derived(at::full_like(other, init), other.batch_sizes());
}

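// linspace adds a new batch dimension at position `dim` holding nstep evenly spaced samples
// between start and end. The index list `net` built below reshapes the 1D at::arange result so
// that the step values sit on the new batch dimension and broadcast over the remaining batch
// dimensions of (end - start).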
template <class Derived>
Derived
TensorBase<Derived>::linspace(const Derived & start, const Derived & end, Size nstep, Size dim)
{
  neml_assert_broadcastable_dbg(start, end);
  neml_assert_dbg(nstep > 0, "nstep must be positive.");

  auto res = start.batch_unsqueeze(dim);

  if (nstep > 1)
  {
    auto Bd = utils::broadcast_batch_dim(start, end);
    auto diff = (end - start).batch_unsqueeze(dim);

    indexing::TensorIndices net(dim, indexing::None);
    net.push_back(indexing::Ellipsis);
    net.insert(net.end(), Bd - dim, indexing::None);
    Scalar steps(at::arange(nstep, diff.options()).index(net) / (nstep - 1));

    res = res + steps * diff;
  }

  return res;
}

template <class Derived>
Derived
TensorBase<Derived>::logspace(
    const Derived & start, const Derived & end, Size nstep, Size dim, const CScalar & base)
{
  auto exponent = neml2::Tensor::linspace(start, end, nstep, dim);
  return Derived(at::pow(base, exponent), exponent.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::clone() const
{
  return Derived(ATensor::clone(), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::detach() const
{
  return Derived(ATensor::detach(), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::to(const TensorOptions & options) const
{
  return Derived(ATensor::to(options), batch_sizes());
}

template <class Derived>
bool
TensorBase<Derived>::batched() const
{
  return batch_dim() > 0;
}

template <class Derived>
Size
TensorBase<Derived>::batch_dim() const
{
  return static_cast<Size>(_batch_sizes.size());
}

template <class Derived>
Size
TensorBase<Derived>::base_dim() const
{
  return dim() - batch_dim();
}

template <class Derived>
const TraceableTensorShape &
TensorBase<Derived>::batch_sizes() const
{
  return _batch_sizes;
}

template <class Derived>
TraceableSize
TensorBase<Derived>::batch_size(Size index) const
{
  const auto i = index >= 0 ? index : index + batch_dim();

  // Put the batch size into the traced graph if we are tracing
  if (jit::tracer::isTracing())
    return jit::tracer::getSizeOf(*this, i);

  return size(i);
}

template <class Derived>
TensorShapeRef
TensorBase<Derived>::base_sizes() const
{
  return sizes().slice(batch_dim());
}

template <class Derived>
Size
TensorBase<Derived>::base_size(Size index) const
{
  return base_sizes()[index >= 0 ? index : index + base_dim()];
}

template <class Derived>
Size
TensorBase<Derived>::base_storage() const
{
  return utils::storage_size(base_sizes());
}

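// The indexing helpers below restrict user-provided indices to either the batch or the base
// dimensions: batch_index/batch_index_put_ append one full Slice() per base dimension, while
// base_index/base_index_put_ prepend one full Slice() per batch dimension. As a hypothetical
// example, for a tensor A with batch shape (5, 2) and base shape (3, 3), A.batch_index({i})
// indexes only the first batch dimension and leaves the base shape (3, 3) intact.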
template <class Derived>
Derived
TensorBase<Derived>::batch_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  auto res = this->index(indices_vec);
  return Derived(res, res.dim() - base_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  return neml2::Tensor(this->index(indices2), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_slice(Size dim, const indexing::Slice & index) const
{
  auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
  auto res = this->slice(
      i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, res.dim() - base_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_slice(Size dim, const indexing::Slice & index) const
{
  auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
  auto res = this->slice(
      i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, batch_sizes());
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, other);
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, v);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, other);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, v);
}

template <class Derived>
Derived
TensorBase<Derived>::variable_data() const
{
  return Derived(ATensor::variable_data(), batch_sizes());
}

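// batch_expand relies on ATen's expand() convention where -1 means "keep this dimension as is":
// the requested batch shape is concretized and padded with -1 for every base dimension, so only
// the batch dimensions are broadcast. When tracing, each traceable batch size is stashed so the
// expanded shape stays symbolic in the traced graph.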
template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableTensorShape & batch_shape) const
{
  // We don't want to touch the base dimensions, so put -1 for them.
  auto net = batch_shape.concrete();
  net.insert(net.end(), base_dim(), -1);

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);

  return Derived(expand(net), batch_shape);
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableSize & batch_size, Size dim) const
{
  auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
  auto batch_shape = batch_sizes();
  if (batch_shape[i] == batch_size)
    return Derived(*this);

  batch_shape[i] = batch_size;
  return batch_expand(batch_shape);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(TensorShapeRef base_shape) const
{
  if (base_sizes() == base_shape)
    return *this;

  // We don't want to touch the batch dimensions, so put -1 for them.
  auto net = base_shape.vec();
  net.insert(net.begin(), batch_dim(), -1);
  return neml2::Tensor(expand(net), batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(Size base_size, Size dim) const
{
  if (this->base_size(dim) == base_size)
    return *this;

  // We don't want to touch the batch dimensions and other base dimensions, so put -1 for them.
  auto net = std::vector<Size>(this->dim(), -1);
  auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
  net[i] = base_size;
  return neml2::Tensor(expand(net), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand_as(const neml2::Tensor & other) const
{
  return batch_expand(other.batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_as(const neml2::Tensor & other) const
{
  return base_expand(other.base_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand_copy(const TraceableTensorShape & batch_shape) const
{
  return Derived(batch_expand(batch_shape).contiguous(), batch_shape);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_copy(TensorShapeRef base_shape) const
{
  return neml2::Tensor(base_expand(base_shape).contiguous(), batch_sizes());
}

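// batch_reshape/base_reshape reshape one group of dimensions while keeping the other group
// fixed, by concatenating the requested shape with the existing base (or batch) shape via
// utils::add_shapes before calling reshape(). Traceable batch sizes are again stashed so that
// they remain symbolic in the traced graph.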
template <class Derived>
Derived
TensorBase<Derived>::batch_reshape(const TraceableTensorShape & batch_shape) const
{
  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_dim(), i, *si);

  return Derived(reshape(utils::add_shapes(batch_shape.concrete(), base_sizes())), batch_shape);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_reshape(TensorShapeRef base_shape) const
{
  auto batch_shape = batch_sizes();

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_shape.size(), i, *si);

  return neml2::Tensor(reshape(utils::add_shapes(batch_shape.concrete(), base_shape)),
                       batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_unsqueeze(Size d) const
{
  auto d2 = d >= 0 ? d : d - base_dim();
  return Derived(unsqueeze(d2), batch_dim() + 1);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_unsqueeze(Size d) const
{
  auto d2 = d < 0 ? d : d + batch_dim();
  return neml2::Tensor(ATensor::unsqueeze(d2), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_transpose(Size d1, Size d2) const
{
  return Derived(ATensor::transpose(d1 < 0 ? d1 - base_dim() : d1, d2 < 0 ? d2 - base_dim() : d2),
                 batch_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_transpose(Size d1, Size d2) const
{
  return neml2::Tensor(
      ATensor::transpose(d1 < 0 ? d1 : batch_dim() + d1, d2 < 0 ? d2 : batch_dim() + d2),
      batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_flatten() const
{
  if (base_dim() == 1)
    return *this;

  return base_reshape({base_storage()});
}

template <class Derived>
Derived
TensorBase<Derived>::operator-() const
{
  return Derived(-ATensor(*this), batch_sizes());
}

} // end namespace neml2