32#include "neml2/tensors/TensorBase.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/misc/math.h"
35#include "neml2/jit/utils.h"
39template <
class Derived>
42 _batch_dim(batch_dim),
43 _batch_sizes(utils::extract_batch_sizes(tensor, batch_dim))
48 " is smaller than the requested number of batch dimensions ",
52template <
class Derived>
62 " cannot be constructed with batch shape ",
66template <
class Derived>
70 return Derived(torch::empty_like(
other),
other.batch_sizes());
73template <
class Derived>
77 return Derived(torch::zeros_like(
other),
other.batch_sizes());
80template <
class Derived>
84 return Derived(torch::ones_like(
other),
other.batch_sizes());
87template <
class Derived>
91 return Derived(torch::full_like(
other,
init),
other.batch_sizes());
94template <
class Derived>
101 auto res =
start.batch_unsqueeze(dim);
106 auto diff = (end -
start).batch_unsqueeze(dim);
109 net.push_back(indexing::Ellipsis);
110 net.insert(
net.end(),
Bd - dim, indexing::None);
119template <
class Derived>
128template <
class Derived>
132 return Derived(torch::Tensor::clone(
memory_format), batch_sizes());
135template <
class Derived>
139 return Derived(torch::Tensor::detach(), batch_sizes());
142template <
class Derived>
146 return Derived(torch::Tensor::to(options), batch_sizes());
149template <
class Derived>
156template <
class Derived>
163template <
class Derived>
167 return dim() - batch_dim();
170template <
class Derived>
177template <
class Derived>
181 const auto i = index >= 0 ? index : index + batch_dim();
184 if (torch::jit::tracer::isTracing())
185 return torch::jit::tracer::getSizeOf(*
this,
i);
190template <
class Derived>
194 return sizes().slice(_batch_dim);
197template <
class Derived>
201 return base_sizes()[index >= 0 ? index : index + base_dim()];
204template <
class Derived>
211template <
class Derived>
218 return Derived(
res,
res.dim() - base_dim());
221template <
class Derived>
230template <
class Derived>
233 const torch::Tensor &
other)
240template <
class Derived>
249template <
class Derived>
252 const torch::Tensor &
other)
259template <
class Derived>
268template <
class Derived>
272 return Derived(torch::Tensor::variable_data(), batch_sizes());
275template <
class Derived>
281 net.insert(
net.end(), base_dim(), -1);
286 torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size",
net.size(),
i, *
si);
291template <
class Derived>
296 auto net = std::vector<Size>(this->dim(), -1);
297 auto i = dim >= 0 ? dim : this->dim() + dim - base_dim();
301 if (
const auto *
const s = batch_size.
traceable())
302 torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", this->dim(),
i, *
s);
304 return Derived(
expand(
net), batch_dim());
307template <
class Derived>
316 net.insert(
net.begin(), batch_dim(), -1);
320template <
class Derived>
324 if (this->base_size(dim) == base_size)
328 auto net = std::vector<Size>(this->dim(), -1);
329 auto i = dim < 0 ? this->dim() + dim : dim + batch_dim();
334template <
class Derived>
341template <
class Derived>
348template <
class Derived>
355 torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
361template <
class Derived>
370 torch::jit::tracer::ArgumentStash::stashIntArrayRefElem(
377template <
class Derived>
381 auto d2 = d >= 0 ? d : d - base_dim();
385template <
class Derived>
389 auto d2 = d < 0 ? d : d + batch_dim();
393template <
class Derived>
398 torch::Tensor::transpose(
d1 < 0 ?
d1 - base_dim() :
d1,
d2 < 0 ?
d2 - base_dim() :
d2),
402template <
class Derived>
407 torch::Tensor::transpose(
d1 < 0 ?
d1 : _batch_dim +
d1,
d2 < 0 ?
d2 : _batch_dim +
d2),
411template <
class Derived>
418 return base_reshape({base_storage()});
421template <
class Derived>
425 return Derived(-torch::Tensor(*
this), batch_sizes());
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Scalar.
Definition Scalar.h:38
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition TensorBaseImpl.h:130
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
log-space equivalent of the linspace named constructor
Definition TensorBaseImpl.h:121
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:413
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:350
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:213
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:404
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:179
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:206
TensorBase()=default
Default constructor.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:151
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:137
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:423
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:395
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:343
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:387
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:144
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:199
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:75
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:172
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:68
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:379
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:96
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:89
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:232
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:336
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:165
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:192
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:363
void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:251
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:277
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:309
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:82
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:223
Derived variable_data() const
Definition TensorBaseImpl.h:270
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Tensor pow(const Real &a, const Tensor &n)
Definition math.cxx:336
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
TensorShape add_shapes(S &&... shape)
Definition utils.h:301
Definition CrossRef.cxx:31
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
Traceable size.
Definition types.h:52
Size concrete() const
Definition types.cxx:37
const torch::Tensor * traceable() const noexcept
Definition types.cxx:31
Traceable tensor shape.
Definition types.h:81