27#include "neml2/tensors/Tensor.h"
28#include "neml2/models/utils.h"
61 virtual
void operator=(const
Tensor & val) = 0;
87 : _value(std::move(value))
virtual Size base_dim() const =0
TensorValueBase(const TensorValueBase &)=default
TensorValueBase()=default
TensorValueBase(TensorValueBase &&) noexcept=default
Size _cached_intmd_dim
Cached intermediate dimension.
Definition TensorValue.h:78
virtual void assign(const ATensor &val, TracerPrivilege key)=0
Secret assignment operator used by low-level operations such as jit tracing.
virtual TensorType type() const =0
Tensor type.
virtual void to_(const TensorOptions &)=0
Send the value to the target options.
virtual void requires_grad_(bool req=true)=0
Require grad.
void requires_grad_(bool req=true) override
Require grad.
void operator=(const Tensor &val) override
assignment operator
void to_(const TensorOptions &options) override
Send the value to the target options.
Size base_dim() const override
TensorType type() const override
Tensor type.
void assign(const ATensor &val, TracerPrivilege key) override
Secret assignment operator used by low-level operations such as jit tracing.
const T & operator()() const
Definition TensorValue.h:97
TensorValue(T value)
Definition TensorValue.h:86
Definition DiagnosticsInterface.h:31
at::Tensor ATensor
Definition types.h:42
int64_t Size
Definition types.h:71
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:66
A passkey to allow trusted classes to perform raw assignment to variables and parameters.
Definition utils.h:35