27#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29#include "neml2/misc/types.h"
30#include "neml2/tensors/TensorBase.h"
31#include "neml2/misc/defaults.h"
58 template <
class Derived>
NEML2's enhanced tensor type.
Definition TensorBase.h:50
Size batch_dim() const
Definition TensorBaseImpl.h:168
const TraceableTensorShape & batch_sizes() const
Definition TensorBaseImpl.h:182
Tensor base_unsqueeze_to(Size n) const
Expand base dimension to a given size.
Definition Tensor.cxx:190
static Tensor identity(Size n, const TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition Tensor.cxx:178
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:156
Tensor(const TensorBase< Derived > &tensor)
Copy from TensorBase.
Definition Tensor.h:59
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:93
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:114
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
Definition Tensor.cxx:79
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:135
Tensor()=default
Special member functions.
TraceableTensorShape broadcast_batch_sizes(const std::vector< Tensor > &tensors)
Find the broadcast batch shape of all the tensors The returned batch shape will be traceable.
Definition Tensor.cxx:45
Definition DiagnosticsInterface.cxx:30
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:36
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable tensor shape.
Definition TraceableTensorShape.h:38