32#include <torch/csrc/jit/frontend/tracer.h>
34#include "neml2/tensors/TensorBase.h"
35#include "neml2/tensors/Scalar.h"
36#include "neml2/tensors/assertions.h"
37#include "neml2/jit/utils.h"
41using namespace torch::jit;
46template <
class Derived>
49 _batch_sizes(utils::extract_batch_sizes(tensor, batch_dim))
54 " is smaller than the requested number of batch dimensions ",
58template <
class Derived>
61 _batch_sizes(batch_shape)
66 " cannot be constructed with batch shape ",
70template <
class Derived>
76template <
class Derived>
80 return Derived(at::empty_like(other), other.batch_sizes());
83template <
class Derived>
87 return Derived(at::zeros_like(other), other.batch_sizes());
90template <
class Derived>
94 return Derived(at::ones_like(other), other.batch_sizes());
97template <
class Derived>
101 return Derived(at::full_like(other, init), other.batch_sizes());
104template <
class Derived>
111 auto res = start.batch_unsqueeze(dim);
116 auto diff = (end - start).batch_unsqueeze(dim);
119 net.push_back(indexing::Ellipsis);
120 net.insert(net.end(), Bd - dim, indexing::None);
121 Scalar steps(at::arange(nstep, diff.options()).index(net) / (nstep - 1));
123 res = res + steps * diff;
129template <
class Derived>
132 const Derived & start,
const Derived & end,
Size nstep,
Size dim,
const CScalar & base)
135 return Derived(at::pow(base, exponent), exponent.batch_sizes());
138template <
class Derived>
142 return Derived(ATensor::clone(), batch_sizes());
145template <
class Derived>
149 return Derived(ATensor::detach(), batch_sizes());
152template <
class Derived>
156 return Derived(ATensor::to(options), batch_sizes());
159template <
class Derived>
163 return batch_dim() > 0;
166template <
class Derived>
170 return static_cast<Size>(_batch_sizes.size());
173template <
class Derived>
177 return dim() - batch_dim();
180template <
class Derived>
187template <
class Derived>
191 const auto i = index >= 0 ? index : index + batch_dim();
194 if (jit::tracer::isTracing())
195 return jit::tracer::getSizeOf(*
this, i);
200template <
class Derived>
204 return sizes().slice(batch_dim());
207template <
class Derived>
211 return base_sizes()[index >= 0 ? index : index + base_dim()];
214template <
class Derived>
221template <
class Derived>
226 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
227 auto res = this->index(indices_vec);
228 return Derived(res, res.dim() - base_dim());
231template <
class Derived>
236 indices2.insert(indices2.end(), indices.begin(), indices.end());
240template <
class Derived>
244 auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
245 auto res = this->slice(
246 i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
247 return Derived(res, res.dim() - base_dim());
250template <
class Derived>
254 auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
255 auto res = this->slice(
256 i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
257 return Derived(res, batch_sizes());
260template <
class Derived>
265 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
266 this->index_put_(indices_vec, other);
269template <
class Derived>
274 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
275 this->index_put_(indices_vec, v);
278template <
class Derived>
283 indices2.insert(indices2.end(), indices.begin(), indices.end());
284 this->index_put_(indices2, other);
287template <
class Derived>
292 indices2.insert(indices2.end(), indices.begin(), indices.end());
293 this->index_put_(indices2, v);
296template <
class Derived>
300 return Derived(ATensor::variable_data(), batch_sizes());
303template <
class Derived>
309 net.insert(net.end(), base_dim(), -1);
312 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
313 if (
const auto *
const si = batch_shape[i].traceable())
314 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
316 return Derived(expand(net), batch_shape);
319template <
class Derived>
323 auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
324 auto batch_shape = batch_sizes();
325 if (batch_shape[i] == batch_size)
326 return Derived(*
this);
328 batch_shape[i] = batch_size;
329 return batch_expand(batch_shape);
332template <
class Derived>
336 if (base_sizes() == base_shape)
340 auto net = base_shape.vec();
341 net.insert(net.begin(), batch_dim(), -1);
345template <
class Derived>
349 if (this->base_size(dim) == base_size)
353 auto net = std::vector<Size>(this->dim(), -1);
354 auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
359template <
class Derived>
366template <
class Derived>
373template <
class Derived>
377 return Derived(batch_expand(batch_shape).contiguous(), batch_shape);
380template <
class Derived>
384 return neml2::Tensor(base_expand(base_shape).contiguous(), batch_sizes());
387template <
class Derived>
392 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
393 if (
const auto *
const si = batch_shape[i].traceable())
394 jit::tracer::ArgumentStash::stashIntArrayRefElem(
395 "shape", batch_shape.size() + base_dim(), i, *si);
400template <
class Derived>
404 auto batch_shape = batch_sizes();
407 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
408 if (
const auto *
const si = batch_shape[i].traceable())
409 jit::tracer::ArgumentStash::stashIntArrayRefElem(
410 "shape", batch_shape.size() + base_shape.size(), i, *si);
416template <
class Derived>
420 auto d2 = d >= 0 ? d : d - base_dim();
421 return Derived(unsqueeze(d2), batch_dim() + 1);
424template <
class Derived>
428 auto d2 = d < 0 ? d : d + batch_dim();
432template <
class Derived>
436 return Derived(ATensor::transpose(d1 < 0 ? d1 - base_dim() : d1, d2 < 0 ? d2 - base_dim() : d2),
440template <
class Derived>
445 ATensor::transpose(d1 < 0 ? d1 : batch_dim() + d1, d2 < 0 ? d2 : batch_dim() + d2),
449template <
class Derived>
456 return base_reshape({base_storage()});
459template <
class Derived>
463 return Derived(-
ATensor(*
this), batch_sizes());
Scalar.
Definition Scalar.h:38
NEML2's enhanced tensor type.
Definition TensorBase.h:50
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:451
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:242
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:389
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:223
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:442
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:252
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:189
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:216
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:161
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:168
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:147
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:461
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:434
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:361
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:368
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:382
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:426
Derived clone() const
Definition TensorBaseImpl.h:140
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:209
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:154
const TraceableTensorShape & batch_sizes() const
Return the batch shape, i.e. the sizes of all batch dimensions.
Definition TensorBaseImpl.h:182
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:280
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:262
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:418
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:375
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:202
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:402
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:305
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:334
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:233
Derived variable_data() const
Definition TensorBaseImpl.h:298
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:99
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:85
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:78
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:106
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, const CScalar &base=10)
Log-space equivalent of the linspace named constructor.
Definition TensorBaseImpl.h:131
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:92
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition indexing.h:38
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:32
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
TensorShape add_shapes(const S &...)
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const
Definition TraceableTensorShape.cxx:78