27#include "neml2/tensors/Tensor.h"
36template <
class Derived,
Size... S>
59 template <
class Derived2>
109template <
class Derived,
Size... S>
114 "Base shape mismatch: trying to create a tensor with base shape ",
116 " from a tensor with base shape ",
120template <
class Derived,
Size...
S>
126 "Base shape mismatch: trying to create a tensor with base shape ",
128 " from a tensor with base shape ",
132template <
class Derived,
Size...
S>
133template <
class Derived2>
138 "Base shape mismatch: trying to create a tensor with base shape ",
140 " from a tensor with base shape ",
144template <
class Derived,
Size...
S>
146 :
TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
149 "Base shape mismatch: trying to create a tensor with base shape ",
151 " from a tensor with shape ",
155template <
class Derived,
Size...
S>
158 return Tensor(*
this, this->batch_sizes());
161template <
class Derived,
Size...
S>
168template <
class Derived,
Size...
S>
171 const torch::TensorOptions & options)
176template <
class Derived,
Size...
S>
183template <
class Derived,
Size...
S>
186 const torch::TensorOptions & options)
191template <
class Derived,
Size...
S>
198template <
class Derived,
Size...
S>
201 const torch::TensorOptions & options)
206template <
class Derived,
Size...
S>
213template <
class Derived,
Size...
S>
217 const torch::TensorOptions & options)
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor with the fixed base shape, filled with a given value.
Definition PrimitiveTensor.h:208
static Derived ones(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Tensor filled with ones given batch shape.
Definition PrimitiveTensor.h:200
PrimitiveTensor(const TensorBase< Derived2 > &tensor)
Converting copy constructor from another TensorBase-derived tensor.
Definition PrimitiveTensor.h:134
PrimitiveTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition PrimitiveTensor.h:145
PrimitiveTensor(const torch::Tensor &tensor, Size batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition PrimitiveTensor.h:110
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:163
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:41
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:178
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:44
PrimitiveTensor(const torch::Tensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another torch::Tensor given batch shape.
Definition PrimitiveTensor.h:121
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones.
Definition PrimitiveTensor.h:193
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:47
static Tensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:99
static Derived full(const TraceableTensorShape &batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:215
PrimitiveTensor()=default
Default constructor.
static Derived empty(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:170
static Derived zeros(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:185
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:192
static Tensor full(TensorShapeRef base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor with the given base shape, filled with a given value.
Definition Tensor.cxx:170
static Tensor ones(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:149
static Tensor empty(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:107
static Tensor zeros(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:128
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
Definition CrossRef.cxx:31
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
torch::TensorOptions & default_tensor_options()
Definition types.cxx:157
double Real
Definition types.h:31
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
Traceable tensor shape.
Definition types.h:81