27#include <ATen/TensorOperators.h>
29#include "neml2/misc/types.h"
30#include "neml2/tensors/macros.h"
31#include "neml2/tensors/tensors_fwd.h"
// Expands to a class forward declaration, `class T`. The macro does not add a
// trailing semicolon; the invocation site has to supply it.
36#define FORWARD_DECLARATION(T) class T
// Helper macros that spell out free-function binary operator declarations.
// The last (or only) declaration produced by each helper has no trailing
// semicolon; the invocation site has to supply it.
//
// DECLARE_BINARY_OP(op, T1, T2, TR): declares `TR op(const T1 & a, const T2 & b)`.
40#define DECLARE_BINARY_OP(op, T1, T2, TR) TR op(const T1 & a, const T2 & b)
// DECLARE_BINARY_OP_SELF(op, T): both operands and the result have type T.
41#define DECLARE_BINARY_OP_SELF(op, T) DECLARE_BINARY_OP(op, T, T, T)
// DECLARE_BINARY_OP_SYM(op, T1, T2, TR): declares both operand orders,
// (T1, T2) and (T2, T1), each returning TR. The first declaration is terminated
// with `;` inside the macro; the second relies on the caller's semicolon.
42#define DECLARE_BINARY_OP_SYM(op, T1, T2, TR) \
43 DECLARE_BINARY_OP(op, T1, T2, TR); \
44 DECLARE_BINARY_OP(op, T2, T1, TR)
// DECLARE_BINARY_OP_NONCONST(op, T1, T2, TR): same shape as DECLARE_BINARY_OP,
// except that the left operand is taken by non-const reference (`T1 & a`).
45#define DECLARE_BINARY_OP_NONCONST(op, T1, T2, TR) TR op(T1 & a, const T2 & b)
// Declaration macros for operator+.
// DECLARE_ADD_SELF(T): declares T + T returning T.
56#define DECLARE_ADD_SELF(T) DECLARE_BINARY_OP_SELF(operator+, T)
// DECLARE_ADD_SYM_SCALAR(T): declares T + Scalar and Scalar + T, both returning T.
57#define DECLARE_ADD_SYM_SCALAR(T) DECLARE_BINARY_OP_SYM(operator+, T, Scalar, T)
// DECLARE_ADD_SYM_REAL(T): declares T + Real and Real + T, both returning T.
58#define DECLARE_ADD_SYM_REAL(T) DECLARE_BINARY_OP_SYM(operator+, T, Real, T)
// Tensor with Scalar, both operand orders.
63DECLARE_ADD_SYM_SCALAR(
Tensor);
// Tensor with Real, both operand orders.
64DECLARE_ADD_SYM_REAL(
Tensor);
// Scalar with Real, both operand orders.
66DECLARE_ADD_SYM_REAL(
Scalar);
// NOTE(review): DECLARE_ADD_SELF is not invoked anywhere in this excerpt. The
// embedded listing numbers skip, so its use may have been lost in extraction --
// confirm against the original header.
// The addition macros are section-local and are removed again here.
67#undef DECLARE_ADD_SELF
68#undef DECLARE_ADD_SYM_SCALAR
69#undef DECLARE_ADD_SYM_REAL
// Declaration macros for operator-.
// DECLARE_SUB_SELF(T): declares T - T returning T.
80#define DECLARE_SUB_SELF(T) DECLARE_BINARY_OP_SELF(operator-, T)
// DECLARE_SUB_SYM_SCALAR(T): declares T - Scalar and Scalar - T, both returning T.
81#define DECLARE_SUB_SYM_SCALAR(T) DECLARE_BINARY_OP_SYM(operator-, T, Scalar, T)
// DECLARE_SUB_SYM_REAL(T): declares T - Real and Real - T, both returning T.
82#define DECLARE_SUB_SYM_REAL(T) DECLARE_BINARY_OP_SYM(operator-, T, Real, T)
// Tensor with Scalar, both operand orders.
87DECLARE_SUB_SYM_SCALAR(
Tensor);
// Tensor with Real, both operand orders.
88DECLARE_SUB_SYM_REAL(
Tensor);
// Scalar with Real, both operand orders.
90DECLARE_SUB_SYM_REAL(
Scalar);
// NOTE(review): DECLARE_SUB_SELF is not invoked anywhere in this excerpt. The
// embedded listing numbers skip, so its use may have been lost in extraction --
// confirm against the original header.
// The subtraction macros are section-local and are removed again here.
91#undef DECLARE_SUB_SELF
92#undef DECLARE_SUB_SYM_SCALAR
93#undef DECLARE_SUB_SYM_REAL
// Declaration macros for operator*.
// DECLARE_MUL_SELF(T): declares T * T returning T.
104#define DECLARE_MUL_SELF(T) DECLARE_BINARY_OP_SELF(operator*, T)
// DECLARE_MUL_SYM_SCALAR(T): declares T * Scalar and Scalar * T, both returning T.
105#define DECLARE_MUL_SYM_SCALAR(T) DECLARE_BINARY_OP_SYM(operator*, T, Scalar, T)
// DECLARE_MUL_SYM_REAL(T): declares T * Real and Real * T, both returning T.
106#define DECLARE_MUL_SYM_REAL(T) DECLARE_BINARY_OP_SYM(operator*, T, Real, T)
// Tensor with Scalar, both operand orders.
110DECLARE_MUL_SYM_SCALAR(
Tensor);
// Tensor with Real, both operand orders.
111DECLARE_MUL_SYM_REAL(
Tensor);
// Scalar with Real, both operand orders.
113DECLARE_MUL_SYM_REAL(
Scalar);
// NOTE(review): DECLARE_MUL_SELF is not invoked anywhere in this excerpt. The
// embedded listing numbers skip, so its use may have been lost in extraction --
// confirm against the original header.
// The multiplication macros are section-local and are removed again here.
114#undef DECLARE_MUL_SELF
115#undef DECLARE_MUL_SYM_SCALAR
116#undef DECLARE_MUL_SYM_REAL
// Declaration macros for operator/.
// DECLARE_DIV_SELF(T): declares T / T returning T.
127#define DECLARE_DIV_SELF(T) DECLARE_BINARY_OP_SELF(operator/, T)
// DECLARE_DIV_SYM_SCALAR(T): declares T / Scalar and Scalar / T, both returning T.
128#define DECLARE_DIV_SYM_SCALAR(T) DECLARE_BINARY_OP_SYM(operator/, T, Scalar, T)
// DECLARE_DIV_SYM_REAL(T): declares T / Real and Real / T, both returning T.
129#define DECLARE_DIV_SYM_REAL(T) DECLARE_BINARY_OP_SYM(operator/, T, Real, T)
// Tensor with Scalar, both operand orders.
133DECLARE_DIV_SYM_SCALAR(
Tensor);
// Tensor with Real, both operand orders.
134DECLARE_DIV_SYM_REAL(
Tensor);
// Scalar with Real, both operand orders.
136DECLARE_DIV_SYM_REAL(
Scalar);
// NOTE(review): DECLARE_DIV_SELF is not invoked anywhere in this excerpt. The
// embedded listing numbers skip, so its use may have been lost in extraction --
// confirm against the original header.
// The division macros are section-local and are removed again here.
137#undef DECLARE_DIV_SELF
138#undef DECLARE_DIV_SYM_SCALAR
139#undef DECLARE_DIV_SYM_REAL
// Declaration macros for the compound-assignment operators with a Real
// right-hand side. Each expands to `T & op(T & a, const Real & b)`: the left
// operand is a non-const reference and the result type is `T &`.
// NOTE(review): neither the invocations nor the matching #undef lines for these
// four macros appear in this excerpt; the embedded listing numbers jump between
// them, so those lines were probably dropped -- confirm against the original
// header.
//
// DECLARE_ADD_EQ(T): operator+=.
150#define DECLARE_ADD_EQ(T) DECLARE_BINARY_OP_NONCONST(operator+=, T, Real, T &)
// DECLARE_SUB_EQ(T): operator-=.
165#define DECLARE_SUB_EQ(T) DECLARE_BINARY_OP_NONCONST(operator-=, T, Real, T &)
// DECLARE_MUL_EQ(T): operator*=.
180#define DECLARE_MUL_EQ(T) DECLARE_BINARY_OP_NONCONST(operator*=, T, Real, T &)
// DECLARE_DIV_EQ(T): operator/=.
195#define DECLARE_DIV_EQ(T) DECLARE_BINARY_OP_NONCONST(operator/=, T, Real, T &)
// Remove the generic binary-operator helper macros once all declarations have
// been emitted.
// NOTE(review): DECLARE_BINARY_OP_NONCONST is not undefined in the visible
// lines; the matching #undef may sit just past the end of this excerpt --
// confirm against the original header.
202#undef DECLARE_BINARY_OP
203#undef DECLARE_BINARY_OP_SELF
204#undef DECLARE_BINARY_OP_SYM
Scalar.
Definition Scalar.h:38
Definition DiagnosticsInterface.cxx:30
FOR_ALL_TENSORBASE(INSTANTIATE_TENSORNAME)
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_ADD_SELF)