27#include "neml2/tensors/PrimitiveTensor.h"
58Scalar
abs(
const Scalar & a);
60template <
class Derived,
61 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
62 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
68 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
72template <
class Derived,
73 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
74 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
81template <
class Derived,
82 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
83 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
89 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
93template <
class Derived,
94 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
95 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
102template <
class Derived,
103 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
104 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
110 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
114template <
class Derived,
115 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
116 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
123Scalar
operator*(
const Scalar & a,
const Scalar & b);
125template <
class Derived,
126 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
127 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
133 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
137template <
class Derived,
138 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
139 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
145 net.insert(
net.end(), b.base_dim(), torch::indexing::None);
151template <
class Derived,
152 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
158 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
PrimitiveTensor()=default
Default constructor.
Scalar.
Definition Scalar.h:38
Scalar(Real init, const torch::TensorOptions &options)
Definition Scalar.cxx:30
static Scalar identity_map(const torch::TensorOptions &options=default_tensor_options())
The derivative of a Scalar with respect to itself.
Definition Scalar.cxx:36
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
Scalar minimum(const Scalar &a, const Scalar &b)
Minimum between two scalars.
Definition Scalar.cxx:44
Tensor pow(const Real &a, const Tensor &n)
Definition math.cxx:336
Definition CrossRef.cxx:31
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
void neml_assert_batch_broadcastable_dbg(const T &...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
auto operator/(const T1 &a, const T2 &b)
Definition Variable.h:367
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
torch::TensorOptions & default_tensor_options()
Definition types.cxx:157
double Real
Definition types.h:31
auto operator+(const T1 &a, const T2 &b)
Definition Variable.h:364
Scalar abs(const Scalar &a)
Absolute value.
Definition Scalar.cxx:61
auto operator-(const T1 &a, const T2 &b)
Definition Variable.h:365