27#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29#include "neml2/misc/types.h"
30#include "neml2/tensors/TensorBase.h"
31#include "neml2/misc/defaults.h"
61 template <
class Derived>
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Size intmd_dim() const
Definition TensorBaseImpl.h:182
static Tensor identity(Size n, const TensorOptions &options=default_tensor_options())
Identity tensor.
Tensor(const ATensor &tensor, const TraceableTensorShape &dynamic_shape, Size intmd_dim=0)
Construct from another ATensor with given dynamic shape.
static Tensor empty(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
static Tensor zeros(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor rand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor full(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Tensor(const TensorBase< Derived > &tensor)
Copy from TensorBase.
Definition Tensor.h:62
static Tensor rand(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Tensor()=default
Special member functions.
static Tensor ones(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Tensor(const ATensor &tensor, Size dynamic_dim, Size intmd_dim=0)
Construct from another ATensor with inferred dynamic shape.
static Tensor create(const TensorDataContainer &data, Size dynamic_dim, Size intmd_dim=0, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container.
TensorShape broadcast_intmd_sizes(const std::vector< Tensor > &tensors)
Find the broadcast intermediate shape of all the tensors.
TraceableTensorShape broadcast_dynamic_sizes(const std::vector< Tensor > &tensors)
Find the broadcast dynamic shape of all the tensors The returned dynamic shape will be traceable.
Definition DiagnosticsInterface.h:31
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
TensorOptions default_tensor_options()
Default floating point tensor options.
at::Tensor ATensor
Definition types.h:42
c10::ArrayRef< neml2::Tensor > TensorList
Definition Tensor.h:37
int64_t Size
Definition types.h:71
c10::Scalar CScalar
Definition types.h:43
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:42
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
Traceable tensor shape.
Definition TraceableTensorShape.h:38