27#include <ATen/ScalarOps.h>
28#include <torch/csrc/autograd/variable.h>
29#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
33#include "neml2/misc/defaults.h"
34#include "neml2/misc/errors.h"
35#include "neml2/misc/types.h"
36#include "neml2/tensors/Tensor.h"
37#include "neml2/tensors/functions/utils.h"
38#include "neml2/tensors/functions/stack.h"
49template <
class Derived,
Size... S>
80 template <
class Derived2>
135 template <
typename... Args,
138 [[nodiscard]]
static Derived
fill(Args &&... args);
145 template <
typename... Args>
157template <
class Derived,
Size... S>
159 :
TensorBase<Derived>(tensor, tensor.dim() - const_base_dim - intmd_dim, intmd_dim)
164template <
class Derived,
Size... S>
168 :
TensorBase<Derived>(tensor, dynamic_dim, intmd_dim)
171 throw NEMLException(
"Inconsistent dimensions when constructing PrimitiveTensor. Expected "
172 "tensor to have dynamic dimension " +
173 std::to_string(
dynamic_dim) +
", intmd dimension " +
174 std::to_string(
intmd_dim) +
", and base dimension " +
176 std::to_string(tensor.dim()) +
" dimensions.");
180template <
class Derived,
Size... S>
184 :
TensorBase<Derived>(tensor, dynamic_shape, intmd_dim)
189template <
class Derived,
Size... S>
190template <
class Derived2>
197template <
class Derived,
Size... S>
202 if (this->base_sizes() != const_base_sizes)
207template <
class Derived,
Size... S>
210 return neml2::Tensor(*
this, this->dynamic_sizes(), this->intmd_dim());
213template <
class Derived,
Size... S>
219 return Derived(torch::autograd::make_variable(
220 data.convert_to_tensor(options.requires_grad(
false)), options.requires_grad()),
225template <
class Derived,
Size... S>
232template <
class Derived,
Size... S>
238 return Tensor::empty(dynamic_shape, intmd_shape, const_base_sizes, options);
241template <
class Derived,
Size... S>
248template <
class Derived,
Size... S>
254 return Tensor::zeros(dynamic_shape, intmd_shape, const_base_sizes, options);
257template <
class Derived,
Size... S>
264template <
class Derived,
Size... S>
270 return Tensor::ones(dynamic_shape, intmd_shape, const_base_sizes, options);
273template <
class Derived,
Size... S>
280template <
class Derived,
Size... S>
287 return Tensor::full(dynamic_shape, intmd_shape, const_base_sizes, init, options);
290template <
class Derived,
Size... S>
297template <
class Derived,
Size... S>
303 return Tensor::rand(dynamic_shape, intmd_shape, const_base_sizes, options);
306template <
class Derived,
Size... S>
311 std::vector<ATensor> tensors_einsum(tensors_aligned.size());
312 for (std::size_t j = 0; j < tensors_aligned.size(); ++j)
313 tensors_einsum[j] = tensors_aligned[j];
314 auto res = at::einsum(equation, tensors_einsum);
315 return Derived(res, i);
318template <
class Tuple, std::size_t... I>
322 return std::vector<neml2::Tensor>{
323 neml2::Tensor(at::scalar_to_tensor(std::get<I>(std::forward<Tuple>(t)), options.device())
324 .to(options.dtype()),
328template <
class Derived,
Size... S>
329template <
typename... Args,
typename>
333 if constexpr (
sizeof...(Args) == const_base_numel)
335 if constexpr ((std::is_convertible_v<Args, neml2::Tensor> && ...))
339 "All input tensors must be scalar-like (no base dimensions)");
341 return base_stack({std::forward<Args>(args)...}).base_reshape(const_base_sizes);
343 else if constexpr ((std::is_convertible_v<Args, CScalar> && ...))
349 else if constexpr (
sizeof...(Args) == const_base_numel + 1)
351 auto tup = std::forward_as_tuple(std::forward<Args>(args)...);
352 const auto & options = std::get<
sizeof...(Args) - 1>(tup);
353 auto vals =
make_tensors(tup, std::make_index_sequence<
sizeof...(Args) - 1>{}, options);
357 throw NEMLException(
"Invalid argument types to PrimitiveTensor::fill");
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:51
static Derived full(const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:275
static Derived empty(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:234
static Derived empty(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:227
PrimitiveTensor(const TensorBase< Derived2 > &tensor)
Copy constructor.
Definition PrimitiveTensor.h:191
PrimitiveTensor(const ATensor &tensor, Size dynamic_dim, Size intmd_dim)
Construct from an ATensor and extract dynamic shape given dynamic dimension.
Definition PrimitiveTensor.h:165
static Derived zeros(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:250
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:57
static Derived rand(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:292
static Derived ones(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:259
Scalar operator()(Args... i) const
Single-element accessor.
Definition Scalar.h:53
std::integer_sequence< Size, S... > base_sizes_sequence
Base shape sequence.
Definition PrimitiveTensor.h:54
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:60
static Derived fill(Args &&... args)
Definition PrimitiveTensor.h:331
static Derived einsum(c10::string_view equation, TensorList tensors)
Einstein summation along base dimensions.
Definition PrimitiveTensor.h:308
static Derived ones(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:266
static Derived rand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:299
PrimitiveTensor(const ATensor &tensor, Size intmd_dim)
Construct from an ATensor and infer dynamic shape.
Definition PrimitiveTensor.h:158
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition PrimitiveTensor.h:199
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &dynamic_shape, Size intmd_dim)
Construct from an ATensor given dynamic shape.
Definition PrimitiveTensor.h:181
static Derived zeros(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:243
PrimitiveTensor()=default
Special member functions.
static constexpr Size const_base_numel
The base numel.
Definition PrimitiveTensor.h:63
static Derived create(const TensorDataContainer &data, Size intmd_dim=0, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred batch dimension.
Definition PrimitiveTensor.h:215
static Derived full(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:282
Scalar.
Definition Scalar.h:38
NEML2's enhanced tensor type.
Definition TensorBase.h:77
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:172
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:99
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:124
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
Definition Tensor.cxx:80
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:148
static Tensor rand(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:197
std::pair< std::vector< Tensor >, Size > align_intmd_dim(TensorList tensors)
Definition utils.cxx:30
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
at::Tensor ATensor
Definition types.h:38
c10::ArrayRef< neml2::Tensor > TensorList
Definition Tensor.h:37
auto make_tensors(Tuple &&t, std::index_sequence< I... >, const TensorOptions &options)
Definition PrimitiveTensor.h:320
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
Tensor base_stack(const std::vector< Tensor > &tensors, Size d)
Definition stack.cxx:64
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:42
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable tensor shape.
Definition TraceableTensorShape.h:38