27#include <torch/csrc/autograd/variable.h>
28#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
30#include "neml2/tensors/Tensor.h"
41template <
class Derived,
Size... S>
109template <
class Derived,
Size... S>
119template <
class Derived,
Size... S>
130template <
class Derived,
Size... S>
140template <
class Derived,
Size... S>
150template <
class Derived,
Size... S>
156template <
class Derived,
Size... S>
161 return Derived(torch::autograd::make_variable(
162 data.convert_to_tensor(options.requires_grad(
false)), options.requires_grad()));
165template <
class Derived,
Size... S>
172template <
class Derived,
Size... S>
180template <
class Derived,
Size... S>
187template <
class Derived,
Size... S>
195template <
class Derived,
Size... S>
202template <
class Derived,
Size... S>
210template <
class Derived,
Size... S>
217template <
class Derived,
Size... S>
static Derived full(const TraceableTensorShape &batch_shape, Real init, const TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:219
static Derived ones(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:204
PrimitiveTensor(const ATensor &tensor)
Construct from another ATensor and infer batch dimension.
Definition PrimitiveTensor.h:141
static Tensor identity_map(const TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:99
static Derived empty(const TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:167
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another ATensor given batch shape.
Definition PrimitiveTensor.h:120
static Derived create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred batch dimension.
Definition PrimitiveTensor.h:158
PrimitiveTensor(const Tensor &tensor)
Copy constructor.
Definition PrimitiveTensor.h:131
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:46
static Derived empty(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:174
static Derived ones(const TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:197
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:49
static Derived full(Real init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition PrimitiveTensor.h:212
PrimitiveTensor(const ATensor &tensor, Size batch_dim)
Construct from another ATensor given batch dimension.
Definition PrimitiveTensor.h:110
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:52
static Derived zeros(const TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:182
static Derived zeros(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:189
PrimitiveTensor()=default
Special member functions.
TensorBase()=default
Special member functions.
Size batch_dim() const
Definition TensorBaseImpl.h:164
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:178
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:198
static Tensor full(TensorShapeRef base_shape, Real init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:154
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:91
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:112
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:133
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:30
Definition DiagnosticsInterface.cxx:30
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:71
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:44
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:34
c10::TensorOptions TensorOptions
Definition types.h:63
Traceable tensor shape.
Definition TraceableTensorShape.h:38