NEML2 2.0.0
TensorBaseImpl.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/tensors/TensorBase.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/misc/math.h"
#include "neml2/jit/utils.h"

namespace neml2
{
template <class Derived>
TensorBase<Derived>::TensorBase(const torch::Tensor & tensor, Size batch_dim)
  : torch::Tensor(tensor),
    _batch_dim(batch_dim),
    _batch_sizes(utils::extract_batch_sizes(tensor, batch_dim))
{
  neml_assert_dbg((Size)sizes().size() >= _batch_dim,
                  "Tensor dimension ",
                  sizes().size(),
                  " is smaller than the requested number of batch dimensions ",
                  _batch_dim);
}
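
// A TensorBase partitions its dimensions into two groups: the leading _batch_dim
// dimensions are the batch dimensions, and the remaining trailing dimensions are
// the base dimensions (see batch_sizes() and base_sizes() below).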

template <class Derived>
TensorBase<Derived>::TensorBase(const torch::Tensor & tensor,
                                const TraceableTensorShape & batch_shape)
  : torch::Tensor(tensor),
    _batch_dim(Size(batch_shape.size())),
    _batch_sizes(batch_shape)
{
  neml_assert_dbg(batch_sizes() == batch_shape,
                  "Tensor of shape ",
                  sizes(),
                  " cannot be constructed with batch shape ",
                  batch_shape);
}

template <class Derived>
Derived
TensorBase<Derived>::empty_like(const Derived & other)
{
  return Derived(torch::empty_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::zeros_like(const Derived & other)
{
  return Derived(torch::zeros_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::ones_like(const Derived & other)
{
  return Derived(torch::ones_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::full_like(const Derived & other, Real init)
{
  return Derived(torch::full_like(other, init), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::linspace(const Derived & start, const Derived & end, Size nstep, Size dim)
{
  neml_assert_broadcastable_dbg(start, end);
  neml_assert_dbg(nstep > 0, "nstep must be positive.");

  auto res = start.batch_unsqueeze(dim);

  if (nstep > 1)
  {
    auto Bd = broadcast_batch_dim(start, end);
    auto diff = (end - start).batch_unsqueeze(dim);

    indexing::TensorIndices net(dim, indexing::None);
    net.push_back(indexing::Ellipsis);
    net.insert(net.end(), Bd - dim, indexing::None);
    Scalar steps = torch::arange(nstep, diff.options()).index(net) / (nstep - 1);

    res = res + steps * diff;
  }

  return res;
}
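
// Note: linspace adds one new batch dimension at batch position `dim`; for
// nstep > 1, entry k along that dimension equals start + k / (nstep - 1) * (end - start),
// so the first entry is start and the last is end.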

template <class Derived>
Derived
TensorBase<Derived>::logspace(
    const Derived & start, const Derived & end, Size nstep, Size dim, Real base)
{
  auto exponent = neml2::Tensor::linspace(start, end, nstep, dim);
  return math::pow(base, exponent);
}

template <class Derived>
Derived
TensorBase<Derived>::clone(torch::MemoryFormat memory_format) const
{
  return Derived(torch::Tensor::clone(memory_format), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::detach() const
{
  return Derived(torch::Tensor::detach(), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::to(const torch::TensorOptions & options) const
{
  return Derived(torch::Tensor::to(options), batch_sizes());
}

template <class Derived>
bool
TensorBase<Derived>::batched() const
{
  return _batch_dim;
}

template <class Derived>
Size
TensorBase<Derived>::batch_dim() const
{
  return _batch_dim;
}

template <class Derived>
Size
TensorBase<Derived>::base_dim() const
{
  return dim() - batch_dim();
}

template <class Derived>
const TraceableTensorShape &
TensorBase<Derived>::batch_sizes() const
{
  return _batch_sizes;
}

template <class Derived>
TraceableSize
TensorBase<Derived>::batch_size(Size index) const
{
  const auto i = index >= 0 ? index : index + batch_dim();

  // Put the batch size into the traced graph if we are tracing
  if (torch::jit::tracer::isTracing())
    return torch::jit::tracer::getSizeOf(*this, i);

  return size(i);
}

template <class Derived>
TensorShapeRef
TensorBase<Derived>::base_sizes() const
{
  return sizes().slice(_batch_dim);
}

template <class Derived>
Size
TensorBase<Derived>::base_size(Size index) const
{
  return base_sizes()[index >= 0 ? index : index + base_dim()];
}

template <class Derived>
Size
TensorBase<Derived>::base_storage() const
{
  return utils::storage_size(base_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices_vec(indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
  auto res = this->index(indices_vec);
  return Derived(res, res.dim() - base_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  return neml2::Tensor(this->index(indices2), batch_sizes());
}
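
// Note: batch_index slices only the batch dimensions (the base dimensions are kept
// whole), while base_index slices only the base dimensions (the batch dimensions are
// kept whole). The index_put_ variants below follow the same convention.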

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices,
                                      const torch::Tensor & other)
{
  indexing::TensorIndices indices_vec(indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
  this->index_put_(indices_vec, other);
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, Real v)
{
  indexing::TensorIndices indices_vec(indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
  this->index_put_(indices_vec, v);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices,
                                     const torch::Tensor & other)
{
  indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, other);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, Real v)
{
  indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, v);
}

template <class Derived>
Derived
TensorBase<Derived>::variable_data() const
{
  return Derived(torch::Tensor::variable_data(), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableTensorShape & batch_shape) const
{
  // We don't want to touch the base dimensions, so put -1 for them.
  auto net = batch_shape.concrete();
  net.insert(net.end(), base_dim(), -1);

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      torch::jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);

  return Derived(expand(net), batch_shape);
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableSize & batch_size, Size dim) const
{
  // We don't want to touch other batch dimensions and the base dimensions, so put -1 for them.
  auto net = std::vector<Size>(this->dim(), -1);
  auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
  net[i] = batch_size.concrete();

  // Record the batch sizes in the traced graph if we are tracing
  if (const auto * const s = batch_size.traceable())
    torch::jit::tracer::ArgumentStash::stashIntArrayRefElem("size", this->dim(), i, *s);

  return Derived(expand(net), batch_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(TensorShapeRef base_shape) const
{
  if (base_sizes() == base_shape)
    return *this;

  // We don't want to touch the batch dimensions, so put -1 for them.
  auto net = base_shape.vec();
  net.insert(net.begin(), batch_dim(), -1);
  return neml2::Tensor(expand(net), batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(Size base_size, Size dim) const
{
  if (this->base_size(dim) == base_size)
    return *this;

  // We don't want to touch the batch dimensions and other base dimensions, so put -1 for them.
  auto net = std::vector<Size>(this->dim(), -1);
  auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
  net[i] = base_size;
  return neml2::Tensor(expand(net), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand_copy(const TraceableTensorShape & batch_shape) const
{
  return Derived(batch_expand(batch_shape).contiguous(), batch_shape);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_copy(TensorShapeRef base_shape) const
{
  return neml2::Tensor(base_expand(base_shape).contiguous(), batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_reshape(const TraceableTensorShape & batch_shape) const
{
  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_dim(), i, *si);

  return Derived(reshape(utils::add_shapes(batch_shape.concrete(), base_sizes())), batch_shape);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_reshape(TensorShapeRef base_shape) const
{
  auto batch_shape = batch_sizes();

  // Record the batch sizes in the traced graph if we are tracing
  for (Size i = 0; i < (Size)batch_shape.size(); ++i)
    if (const auto * const si = batch_shape[i].traceable())
      torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
          "shape", batch_shape.size() + base_shape.size(), i, *si);

  return neml2::Tensor(reshape(utils::add_shapes(batch_shape.concrete(), base_shape)),
                       batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_unsqueeze(Size d) const
{
  auto d2 = d >= 0 ? d : d - base_dim();
  return Derived(unsqueeze(d2), _batch_dim + 1);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_unsqueeze(Size d) const
{
  auto d2 = d < 0 ? d : d + batch_dim();
  return neml2::Tensor(torch::Tensor::unsqueeze(d2), batch_sizes());
}
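
// Note: batch_unsqueeze interprets a non-negative d relative to the first batch
// dimension and a negative d relative to the end of the batch dimensions;
// base_unsqueeze interprets a non-negative d relative to the first base dimension
// and a negative d relative to the end of the base dimensions.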

template <class Derived>
Derived
TensorBase<Derived>::batch_transpose(Size d1, Size d2) const
{
  return Derived(
      torch::Tensor::transpose(d1 < 0 ? d1 - base_dim() : d1, d2 < 0 ? d2 - base_dim() : d2),
      _batch_dim);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_transpose(Size d1, Size d2) const
{
  return neml2::Tensor(
      torch::Tensor::transpose(d1 < 0 ? d1 : _batch_dim + d1, d2 < 0 ? d2 : _batch_dim + d2),
      batch_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_flatten() const
{
  if (base_dim() == 1)
    return *this;

  return base_reshape({base_storage()});
}

template <class Derived>
Derived
TensorBase<Derived>::operator-() const
{
  return Derived(-torch::Tensor(*this), batch_sizes());
}

} // end namespace neml2
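
For readers of this implementation file, a minimal usage sketch (not part of the header) may help illustrate the batch/base split. It is only a sketch under the assumption that the concrete neml2::Tensor type from Tensor.h and a working libtorch installation are available:

#include <torch/torch.h>
#include "neml2/tensors/Tensor.h"

void example()
{
  // Interpret a raw (5, 3, 3) torch tensor as a batch of five 3x3 base tensors:
  // batch shape (5), base shape (3, 3).
  auto A = neml2::Tensor(torch::rand({5, 3, 3}), /*batch_dim=*/1);

  auto nb = A.batch_dim();    // 1
  auto bs = A.batch_sizes();  // (5)
  auto bb = A.base_sizes();   // (3, 3)

  // Slice only the base dimensions: first base column of every batch entry.
  auto col = A.base_index({torch::indexing::Slice(), 0});

  // Slice only the batch dimensions: the third batch entry, base shape intact.
  auto third = A.batch_index({2});
}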