27#include "neml2/tensors/TensorBase.h"
40torch::Dtype
same_dtype(
const std::vector<Tensor> & tensors);
43torch::Device
same_device(
const std::vector<Tensor> & tensors);
59 template <
class Derived2>
115Tensor
bmm(
const Tensor & a,
const Tensor & b);
124Tensor
bmv(
const Tensor & a,
const Tensor & v);
133Tensor
bvv(
const Tensor & a,
const Tensor & b);
136Tensor
operator*(
const Tensor & a,
const Tensor & b);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
static Tensor full(TensorShapeRef base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor of the given base shape, filled with the given value.
Definition Tensor.cxx:170
Tensor(const TensorBase< Derived2 > &tensor)
Copy (converting) constructor from any TensorBase-derived tensor.
Definition Tensor.h:60
static Tensor ones(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor of the given base shape, filled with ones.
Definition Tensor.cxx:149
static Tensor empty(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor of the given base shape.
Definition Tensor.cxx:107
static Tensor identity(Size n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition Tensor.cxx:192
Tensor()=default
Default constructor.
static Tensor zeros(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor of the given base shape, filled with zeros.
Definition Tensor.cxx:128
Tensor bmv(const Tensor &a, const Tensor &v)
Batched matrix-vector product.
Definition Tensor.cxx:223
Tensor bvv(const Tensor &a, const Tensor &b)
Batched vector-vector (dot) product.
Definition Tensor.cxx:238
Tensor bmm(const Tensor &a, const Tensor &b)
Batched matrix-matrix product.
Definition Tensor.cxx:208
torch::Dtype same_dtype(const std::vector< Tensor > &tensors)
Make sure all tensors have the same dtype and return the common dtype.
Definition Tensor.cxx:56
TraceableTensorShape broadcast_batch_sizes(const std::vector< Tensor > &tensors)
Find the broadcast batch shape of all the tensors. The returned batch shape will be traceable.
Definition Tensor.cxx:33
torch::Device same_device(const std::vector< Tensor > &tensors)
Make sure all tensors have the same device and return the common device.
Definition Tensor.cxx:76
Definition CrossRef.cxx:31
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
torch::TensorOptions & default_tensor_options()
Definition types.cxx:157
int64_t Size
Definition types.h:33
torch::IntArrayRef TensorShapeRef
Definition types.h:35
Traceable tensor shape.
Definition types.h:81