27#include <torch/csrc/autograd/variable.h>
28#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
30#include "neml2/misc/errors.h"
31#include "neml2/tensors/Tensor.h"
32#include "neml2/tensors/shape_utils.h"
43template <
class Derived,
Size... S>
111template <
class Derived,
Size... S>
121template <
class Derived,
Size... S>
132template <
class Derived,
Size... S>
142template <
class Derived,
Size... S>
144 :
TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
152template <
class Derived,
Size... S>
155 return Tensor(*
this, this->batch_sizes());
158template <
class Derived,
Size... S>
163 return Derived(torch::autograd::make_variable(
164 data.convert_to_tensor(options.requires_grad(
false)), options.requires_grad()));
167template <
class Derived,
Size... S>
174template <
class Derived,
Size... S>
179 return Tensor::empty(batch_shape, const_base_sizes, options);
182template <
class Derived,
Size... S>
189template <
class Derived,
Size... S>
194 return Tensor::zeros(batch_shape, const_base_sizes, options);
197template <
class Derived,
Size... S>
204template <
class Derived,
Size... S>
209 return Tensor::ones(batch_shape, const_base_sizes, options);
212template <
class Derived,
Size... S>
219template <
class Derived,
Size... S>
225 return Tensor::full(batch_shape, const_base_sizes, init, options);
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:45
static Derived full(const CScalar &init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition PrimitiveTensor.h:214
static Derived ones(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:206
PrimitiveTensor(const ATensor &tensor)
Construct from another ATensor and infer batch dimension.
Definition PrimitiveTensor.h:143
static Tensor identity_map(const TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:101
static Derived empty(const TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:169
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another ATensor given batch shape.
Definition PrimitiveTensor.h:122
static Derived create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred batch dimension.
Definition PrimitiveTensor.h:160
PrimitiveTensor(const Tensor &tensor)
Copy constructor.
Definition PrimitiveTensor.h:133
static Derived full(const TraceableTensorShape &batch_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:221
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:48
static Derived empty(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:176
static Derived ones(const TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:199
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:51
PrimitiveTensor(const ATensor &tensor, Size batch_dim)
Construct from another ATensor given batch dimension.
Definition PrimitiveTensor.h:112
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:54
static Derived zeros(const TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:184
static Derived zeros(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:191
PrimitiveTensor()=default
Special member functions.
NEML2's enhanced tensor type.
Definition TensorBase.h:50
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:168
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:202
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:156
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:93
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:114
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:135
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:32
Definition DiagnosticsInterface.cxx:30
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:36
Traceable tensor shape.
Definition TraceableTensorShape.h:38