27#include "neml2/misc/utils.h"
32template <
class Derived>
44template <
class Derived>
58 template <
class Derived2>
102 Derived
clone(torch::MemoryFormat
memory_format = torch::MemoryFormat::Contiguous)
const;
106 using torch::Tensor::detach_;
108 Derived
to(
const torch::TensorOptions & options)
const;
110 using torch::Tensor::copy_;
112 using torch::Tensor::zero_;
114 using torch::Tensor::requires_grad;
116 using torch::Tensor::requires_grad_;
124 using torch::Tensor::options;
126 using torch::Tensor::scalar_type;
128 using torch::Tensor::device;
130 using torch::Tensor::dim;
132 using torch::Tensor::sizes;
134 using torch::Tensor::size;
156 using torch::Tensor::index;
157 using torch::Tensor::index_put_;
187 template <
class Derived2>
190 template <
class Derived2>
214 Size _batch_dim = {};
220template <
class Derived>
221template <
class Derived2>
223 : _batch_dim(tensor.batch_dim()),
224 _batch_sizes(tensor.batch_sizes())
226 torch::Tensor::operator=(tensor);
229template <
class Derived>
230template <
class Derived2>
234 return batch_expand(
other.batch_sizes());
237template <
class Derived>
238template <
class Derived2>
242 return base_expand(
other.base_sizes());
245template <
class Derived,
246 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
250 return Derived(torch::operator+(a, b), a.batch_sizes());
253template <
class Derived,
254 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
261template <
class Derived,
262 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
270template <
class Derived,
271 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
275 return Derived(torch::operator-(a, b), a.batch_sizes());
278template <
class Derived,
279 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
286template <
class Derived,
287 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
295template <
class Derived,
296 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
300 return Derived(torch::operator*(a, b), a.batch_sizes());
303template <
class Derived,
304 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
311template <
class Derived,
312 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
316 return Derived(torch::operator/(a, b), a.batch_sizes());
319template <
class Derived,
320 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
324 return Derived(torch::operator/(a, b), b.batch_sizes());
327template <
class Derived,
328 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition TensorBaseImpl.h:130
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
log-space equivalent of the linspace named constructor
Definition TensorBaseImpl.h:121
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:413
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:350
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:213
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:404
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:179
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBase.h:232
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:206
TensorBase()=default
Default constructor.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:151
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:137
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:423
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:395
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:343
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:387
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:144
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:199
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:75
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:172
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:68
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:379
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:96
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:89
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:232
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:336
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:165
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:192
TensorBase(const TensorBase< Derived2 > &tensor)
Copy constructor.
Definition TensorBase.h:222
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:363
void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:251
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:277
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:309
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:82
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:223
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBase.h:240
Derived variable_data() const
Definition TensorBaseImpl.h:270
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:31
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
auto operator/(const T1 &a, const T2 &b)
Definition Variable.h:367
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
auto operator+(const T1 &a, const T2 &b)
Definition Variable.h:364
auto operator-(const T1 &a, const T2 &b)
Definition Variable.h:365
Traceable size.
Definition types.h:52
Traceable tensor shape.
Definition types.h:81