27#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29#include "neml2/tensors/TensorBase.h"
30#include "neml2/misc/defaults.h"
31#include "neml2/jit/types.h"
58 template <
class Derived>
Size batch_dim() const
Definition TensorBaseImpl.h:164
const TraceableTensorShape & batch_sizes() const
Definition TensorBaseImpl.h:178
static Tensor identity(Size n, const TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition Tensor.cxx:176
static Tensor full(TensorShapeRef base_shape, Real init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:154
Tensor(const TensorBase< Derived > &tensor)
Copy from TensorBase.
Definition Tensor.h:59
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:91
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:112
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
Definition Tensor.cxx:77
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:133
Tensor()=default
Special member functions.
TraceableTensorShape broadcast_batch_sizes(const std::vector< Tensor > &tensors)
Find the broadcast batch shape of all the tensors The returned batch shape will be traceable.
Definition Tensor.cxx:43
Definition DiagnosticsInterface.cxx:30
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:44
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:34
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72
Traceable tensor shape.
Definition TraceableTensorShape.h:38