27#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29#include "neml2/misc/types.h"
30#include "neml2/tensors/TensorBase.h"
31#include "neml2/misc/defaults.h"
59 template <
class Derived>
NEML2's enhanced tensor type.
Definition TensorBase.h:77
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Size intmd_dim() const
Definition TensorBaseImpl.h:182
static Tensor identity(Size n, const TensorOptions &options=default_tensor_options())
Identity tensor.
Definition Tensor.cxx:221
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:172
Tensor(const TensorBase< Derived > &tensor)
Copy from TensorBase.
Definition Tensor.h:60
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:99
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:124
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
Definition Tensor.cxx:80
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:148
Tensor()=default
Special member functions.
static Tensor rand(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:197
TraceableTensorShape broadcast_dynamic_sizes(const std::vector< Tensor > &tensors)
Find the broadcast dynamic shape of all the tensors The returned dynamic shape will be traceable.
Definition Tensor.cxx:46
Definition DiagnosticsInterface.cxx:30
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
at::Tensor ATensor
Definition types.h:38
c10::ArrayRef< neml2::Tensor > TensorList
Definition Tensor.h:37
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:42
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable tensor shape.
Definition TraceableTensorShape.h:38