// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "neml2/tensors/functions/operators.h"
#include "neml2/tensors/tensors.h"
#include "neml2/tensors/assertions.h"

namespace neml2
{
///////////////////////////////////////////////////////////////////////////////
// Addition
///////////////////////////////////////////////////////////////////////////////
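// Element-wise addition between two tensors of the same type. The base shapes
// already agree by construction, so only the batch shapes are checked for
// broadcastability (in debug builds), and the result carries the broadcast
// batch dimension. The trailing static_assert(true) is an idiom that forces a
// semicolon at each macro invocation site.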
#define DEFINE_ADD_SELF(T) \
  T operator+(const T & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator+(a, b), utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

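// Mixed addition between a tensor and a (possibly batched) Scalar. The Scalar
// is base-unsqueezed to the tensor's base dimension so that the underlying
// at:: broadcast aligns base axes rather than batch axes. Addition commutes,
// so the Scalar-first overload simply forwards to the tensor-first one.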
#define DEFINE_ADD_SYM_SCALAR(T) \
  T operator+(const T & a, const Scalar & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator+(a, b.base_unsqueeze_to(a.base_dim())), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  T operator+(const Scalar & a, const T & b) { return b + a; } \
  static_assert(true)

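// Mixed addition between a tensor and a CScalar (a plain C++ arithmetic
// scalar). A CScalar carries no batch information, so no broadcast check is
// needed and the result keeps the tensor operand's batch sizes.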
#define DEFINE_ADD_SYM_REAL(T) \
  T operator+(const T & a, const CScalar & b) { return T(at::operator+(a, b), a.batch_sizes()); } \
  T operator+(const CScalar & a, const T & b) { return b + a; } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_ADD_SELF);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_ADD_SYM_SCALAR);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_ADD_SYM_REAL);
DEFINE_ADD_SELF(Tensor);
DEFINE_ADD_SYM_SCALAR(Tensor);
DEFINE_ADD_SYM_REAL(Tensor);
DEFINE_ADD_SELF(Scalar);
DEFINE_ADD_SYM_REAL(Scalar);
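
// A minimal usage sketch (assuming SR2 is one of the nonscalar primitive
// tensor types covered by FOR_ALL_NONSCALAR_PRIMITIVETENSOR):
//
//   SR2 a, b;        // same base shape; batch shapes must broadcast
//   Scalar s;        // batched scalar
//   CScalar c = 2.0; // plain C++ scalar
//   auto r1 = a + b; // element-wise; result has the broadcast batch dim
//   auto r2 = a + s; // s is base-unsqueezed to a.base_dim() first
//   auto r3 = s + a; // forwards to a + s
//   auto r4 = a + c; // c has no batch info; r4 keeps a.batch_sizes()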

///////////////////////////////////////////////////////////////////////////////
// Subtraction
///////////////////////////////////////////////////////////////////////////////
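// Subtraction mirrors the addition macros above, except that subtraction does
// not commute: the Scalar-first and CScalar-first overloads are spelled out
// explicitly instead of forwarding to their reversed counterparts.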
#define DEFINE_SUB_SELF(T) \
  T operator-(const T & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator-(a, b), utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

#define DEFINE_SUB_SYM_SCALAR(T) \
  T operator-(const T & a, const Scalar & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator-(a, b.base_unsqueeze_to(a.base_dim())), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  T operator-(const Scalar & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator-(a.base_unsqueeze_to(b.base_dim()), b), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

#define DEFINE_SUB_SYM_REAL(T) \
  T operator-(const T & a, const CScalar & b) { return T(at::operator-(a, b), a.batch_sizes()); } \
  T operator-(const CScalar & a, const T & b) { return T(at::operator-(a, b), b.batch_sizes()); } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_SUB_SELF);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_SUB_SYM_SCALAR);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_SUB_SYM_REAL);
DEFINE_SUB_SELF(Tensor);
DEFINE_SUB_SYM_SCALAR(Tensor);
DEFINE_SUB_SYM_REAL(Tensor);
DEFINE_SUB_SELF(Scalar);
DEFINE_SUB_SYM_REAL(Scalar);

///////////////////////////////////////////////////////////////////////////////
// Multiplication
///////////////////////////////////////////////////////////////////////////////
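// Multiplication follows the addition pattern: it commutes, so the Scalar-
// and CScalar-first overloads forward to their reversed counterparts.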
#define DEFINE_MUL_SELF(T) \
  T operator*(const T & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator*(a, b), utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

#define DEFINE_MUL_SYM_SCALAR(T) \
  T operator*(const T & a, const Scalar & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator*(a, b.base_unsqueeze_to(a.base_dim())), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  T operator*(const Scalar & a, const T & b) { return b * a; } \
  static_assert(true)

#define DEFINE_MUL_SYM_REAL(T) \
  T operator*(const T & a, const CScalar & b) { return T(at::operator*(a, b), a.batch_sizes()); } \
  T operator*(const CScalar & a, const T & b) { return b * a; } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_MUL_SYM_SCALAR);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_MUL_SYM_REAL);
DEFINE_MUL_SELF(Tensor);
DEFINE_MUL_SYM_SCALAR(Tensor);
DEFINE_MUL_SYM_REAL(Tensor);
DEFINE_MUL_SELF(Scalar);
DEFINE_MUL_SYM_REAL(Scalar);
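
// Note that DEFINE_MUL_SELF (and DEFINE_DIV_SELF below) is instantiated only
// for the dynamically-shaped Tensor and for Scalar, not for the nonscalar
// primitive types -- presumably because the element-wise product or quotient
// of two fixed-base-shape primitives (e.g. Mandel-stored symmetric tensors)
// is not in general a meaningful tensor of the same type.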

///////////////////////////////////////////////////////////////////////////////
// Division
///////////////////////////////////////////////////////////////////////////////
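// Division mirrors subtraction: it does not commute, so both operand orders
// are spelled out explicitly.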
#define DEFINE_DIV_SELF(T) \
  T operator/(const T & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator/(a, b), utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

#define DEFINE_DIV_SYM_SCALAR(T) \
  T operator/(const T & a, const Scalar & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator/(a, b.base_unsqueeze_to(a.base_dim())), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  T operator/(const Scalar & a, const T & b) \
  { \
    neml_assert_batch_broadcastable_dbg(a, b); \
    return T(at::operator/(a.base_unsqueeze_to(b.base_dim()), b), \
             utils::broadcast_batch_dim(a, b)); \
  } \
  static_assert(true)

#define DEFINE_DIV_SYM_REAL(T) \
  T operator/(const T & a, const CScalar & b) { return T(at::operator/(a, b), a.batch_sizes()); } \
  T operator/(const CScalar & a, const T & b) { return T(at::operator/(a, b), b.batch_sizes()); } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_DIV_SYM_SCALAR);
FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_DIV_SYM_REAL);
DEFINE_DIV_SELF(Tensor);
DEFINE_DIV_SYM_SCALAR(Tensor);
DEFINE_DIV_SYM_REAL(Tensor);
DEFINE_DIV_SELF(Scalar);
DEFINE_DIV_SYM_REAL(Scalar);

///////////////////////////////////////////////////////////////////////////////
// In-place addition
///////////////////////////////////////////////////////////////////////////////
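// The in-place operators are defined only against CScalar: a plain C++ scalar
// cannot change the operand's shape, so the batch structure is untouched. The
// temporary at::Tensor(a) is a handle sharing a's storage, so the in-place
// ATen op below modifies a itself.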
#define DEFINE_ADD_EQ(T) \
  T & operator+=(T & a, const CScalar & b) \
  { \
    at::Tensor(a) += b; \
    return a; \
  } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_ADD_EQ);
DEFINE_ADD_EQ(Tensor);
DEFINE_ADD_EQ(Scalar);

///////////////////////////////////////////////////////////////////////////////
// In-place subtraction
///////////////////////////////////////////////////////////////////////////////
#define DEFINE_SUB_EQ(T) \
  T & operator-=(T & a, const CScalar & b) \
  { \
    at::Tensor(a) -= b; \
    return a; \
  } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_SUB_EQ);
DEFINE_SUB_EQ(Tensor);
DEFINE_SUB_EQ(Scalar);

///////////////////////////////////////////////////////////////////////////////
// In-place multiplication
///////////////////////////////////////////////////////////////////////////////
#define DEFINE_MUL_EQ(T) \
  T & operator*=(T & a, const CScalar & b) \
  { \
    at::Tensor(a) *= b; \
    return a; \
  } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_MUL_EQ);
DEFINE_MUL_EQ(Tensor);
DEFINE_MUL_EQ(Scalar);

///////////////////////////////////////////////////////////////////////////////
// In-place division
///////////////////////////////////////////////////////////////////////////////
#define DEFINE_DIV_EQ(T) \
  T & operator/=(T & a, const CScalar & b) \
  { \
    at::Tensor(a) /= b; \
    return a; \
  } \
  static_assert(true)

FOR_ALL_NONSCALAR_PRIMITIVETENSOR(DEFINE_DIV_EQ);
DEFINE_DIV_EQ(Tensor);
DEFINE_DIV_EQ(Scalar);

} // namespace neml2