Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/models/Variable.h"
26 : #include "neml2/models/Model.h"
27 : #include "neml2/models/DependencyResolver.h"
28 : #include "neml2/tensors/tensors.h"
29 : #include "neml2/misc/assertions.h"
30 : #include "neml2/tensors/functions/bmm.h"
31 : #include "neml2/jit/utils.h"
32 : #include "neml2/jit/TraceableTensorShape.h"
33 :
34 : namespace neml2
35 : {
36 1764 : VariableBase::VariableBase(VariableName name_in, Model * owner, TensorShapeRef list_shape)
37 1764 : : _name(std::move(name_in)),
38 1764 : _owner(owner),
39 3528 : _list_sizes(list_shape)
40 : {
41 1764 : }
42 :
43 : const Model &
44 12303 : VariableBase::owner() const
45 : {
46 12303 : neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
47 12303 : return *_owner;
48 : }
49 :
50 : Model &
51 10217 : VariableBase::owner()
52 : {
53 10217 : neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
54 10217 : return *_owner;
55 : }
56 :
57 : bool
58 845 : VariableBase::is_state() const
59 : {
60 845 : return _name.is_state();
61 : }
62 :
63 : bool
64 166 : VariableBase::is_old_state() const
65 : {
66 166 : return _name.is_old_state();
67 : }
68 :
69 : bool
70 171 : VariableBase::is_force() const
71 : {
72 171 : return _name.is_force();
73 : }
74 :
75 : bool
76 48 : VariableBase::is_old_force() const
77 : {
78 48 : return _name.is_old_force();
79 : }
80 :
81 : bool
82 136 : VariableBase::is_residual() const
83 : {
84 136 : return _name.is_residual();
85 : }
86 :
87 : bool
88 90 : VariableBase::is_parameter() const
89 : {
90 90 : return _name.is_parameter();
91 : }
92 :
93 : bool
94 397 : VariableBase::is_solve_dependent() const
95 : {
96 397 : return is_state() || is_residual() || is_parameter();
97 : }
98 :
99 : bool
100 1573 : VariableBase::is_dependent() const
101 : {
102 1573 : return !currently_solving_nonlinear_system() || is_solve_dependent();
103 : }
104 :
105 : TensorOptions
106 399 : VariableBase::options() const
107 : {
108 399 : return tensor().options();
109 : }
110 :
111 : Dtype
112 48 : VariableBase::scalar_type() const
113 : {
114 48 : return tensor().scalar_type();
115 : }
116 :
117 : Device
118 0 : VariableBase::device() const
119 : {
120 0 : return tensor().device();
121 : }
122 :
123 : Size
124 0 : VariableBase::dim() const
125 : {
126 0 : return tensor().dim();
127 : }
128 :
129 : TensorShapeRef
130 0 : VariableBase::sizes() const
131 : {
132 0 : return tensor().sizes();
133 : }
134 :
135 : Size
136 0 : VariableBase::size(Size dim) const
137 : {
138 0 : return tensor().size(dim);
139 : }
140 :
141 : bool
142 0 : VariableBase::batched() const
143 : {
144 0 : return tensor().batched();
145 : }
146 :
147 : Size
148 36971 : VariableBase::batch_dim() const
149 : {
150 36971 : return tensor().batch_dim();
151 : }
152 :
153 : Size
154 674697 : VariableBase::list_dim() const
155 : {
156 674697 : return Size(list_sizes().size());
157 : }
158 :
159 : Size
160 0 : VariableBase::base_dim() const
161 : {
162 0 : return Size(base_sizes().size());
163 : }
164 :
165 : TraceableTensorShape
166 1 : VariableBase::batch_sizes() const
167 : {
168 1 : return tensor().batch_sizes();
169 : }
170 :
171 : TensorShapeRef
172 729996 : VariableBase::list_sizes() const
173 : {
174 729996 : return _list_sizes;
175 : }
176 :
177 : TraceableSize
178 0 : VariableBase::batch_size(Size dim) const
179 : {
180 0 : return tensor().batch_size(dim);
181 : }
182 :
183 : Size
184 0 : VariableBase::base_size(Size dim) const
185 : {
186 0 : return base_sizes()[dim];
187 : }
188 :
189 : Size
190 1 : VariableBase::list_size(Size dim) const
191 : {
192 1 : return list_sizes()[dim];
193 : }
194 :
195 : Size
196 0 : VariableBase::base_storage() const
197 : {
198 0 : return utils::storage_size(base_sizes());
199 : }
200 :
201 : Size
202 2869 : VariableBase::assembly_storage() const
203 : {
204 2869 : return utils::storage_size(list_sizes()) * utils::storage_size(base_sizes());
205 : }
206 :
207 : bool
208 0 : VariableBase::requires_grad() const
209 : {
210 0 : return tensor().requires_grad();
211 : }
212 :
213 : Derivative
214 816 : VariableBase::d(const VariableBase & var)
215 : {
216 816 : neml_assert_dbg(owning(),
217 : "Cannot assign derivative to a referencing variable '",
218 816 : name(),
219 : "' with respect to '",
220 816 : var.name(),
221 : "'.");
222 816 : return Derivative({assembly_storage(), var.assembly_storage()}, &_derivs[var.name()]);
223 : }
224 :
225 : Derivative
226 233 : VariableBase::d(const VariableBase & var1, const VariableBase & var2)
227 : {
228 233 : neml_assert_dbg(owning(),
229 : "Cannot assign second derivative to a referencing variable '",
230 233 : name(),
231 : "' with respect to '",
232 233 : var1.name(),
233 : "' and '",
234 233 : var2.name(),
235 : "'.");
236 932 : return Derivative({assembly_storage(), var1.assembly_storage(), var2.assembly_storage()},
237 233 : &_sec_derivs[var1.name()][var2.name()]);
238 : }
239 :
240 : void
241 0 : VariableBase::request_AD(const VariableBase & u)
242 : {
243 0 : owner().request_AD(*this, u);
244 0 : }
245 :
246 : void
247 4 : VariableBase::request_AD(const std::vector<const VariableBase *> & us)
248 : {
249 18 : for (const auto & u : us)
250 : {
251 14 : neml_assert(u, "Cannot request AD for a null variable.");
252 14 : owner().request_AD(*this, *u);
253 : }
254 4 : }
255 :
256 : void
257 0 : VariableBase::request_AD(const VariableBase & u1, const VariableBase & u2)
258 : {
259 0 : owner().request_AD(*this, u1, u2);
260 0 : }
261 :
262 : void
263 3 : VariableBase::request_AD(const std::vector<const VariableBase *> & u1s,
264 : const std::vector<const VariableBase *> & u2s)
265 : {
266 15 : for (const auto & u1 : u1s)
267 60 : for (const auto & u2 : u2s)
268 : {
269 48 : neml_assert(u1, "Cannot request AD for a null variable.");
270 48 : neml_assert(u2, "Cannot request AD for a null variable.");
271 48 : owner().request_AD(*this, *u1, *u2);
272 : }
273 3 : }
274 :
275 : void
276 20175 : VariableBase::clear()
277 : {
278 20175 : neml_assert_dbg(owning(), "Cannot clear a referencing variable '", name(), "'.");
279 20175 : clear_derivatives();
280 20175 : }
281 :
282 : void
283 20781 : VariableBase::clear_derivatives()
284 : {
285 20781 : _derivs.clear();
286 20781 : _sec_derivs.clear();
287 20781 : }
288 :
289 : void
290 85 : VariableBase::apply_chain_rule(const DependencyResolver<Model, VariableName> & dep)
291 : {
292 109 : for (const auto & [model, var] : dep.outbound_items())
293 109 : if (var == name())
294 : {
295 85 : _derivs = total_derivatives(dep, model, var);
296 85 : return;
297 : }
298 : }
299 :
300 : void
301 16 : VariableBase::apply_second_order_chain_rule(const DependencyResolver<Model, VariableName> & dep)
302 : {
303 16 : for (const auto & [model, var] : dep.outbound_items())
304 16 : if (var == name())
305 : {
306 16 : _sec_derivs = total_second_derivatives(dep, model, var);
307 16 : return;
308 : }
309 : }
310 :
311 : static void
312 1025 : assign_or_add(Tensor & dest, const Tensor & val)
313 : {
314 1025 : if (dest.defined())
315 72 : dest = dest + val;
316 : else
317 953 : dest = val;
318 1025 : }
319 :
320 : ValueMap
321 418 : VariableBase::total_derivatives(const DependencyResolver<Model, VariableName> & dep,
322 : Model * model,
323 : const VariableName & yvar) const
324 : {
325 418 : ValueMap derivs;
326 :
327 1121 : for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
328 : {
329 703 : if (dep.inbound_items().count({model, uvar}))
330 479 : assign_or_add(derivs[uvar], dy_du);
331 : else
332 448 : for (const auto & depu : dep.item_providers().at({model, uvar}))
333 639 : for (const auto & [xvar, du_dx] : total_derivatives(dep, depu.parent, uvar))
334 639 : assign_or_add(derivs[xvar], bmm(dy_du, du_dx));
335 : }
336 :
337 418 : return derivs;
338 0 : }
339 :
340 : DerivMap
341 48 : VariableBase::total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
342 : Model * model,
343 : const VariableName & yvar) const
344 : {
345 48 : DerivMap sec_derivs;
346 :
347 83 : for (const auto & [u1var, d2y_du1] : model->output_variable(yvar).second_derivatives())
348 131 : for (const auto & [u2var, d2y_du1u2] : d2y_du1)
349 : {
350 96 : if (dep.inbound_items().count({model, u1var}) && dep.inbound_items().count({model, u2var}))
351 25 : assign_or_add(sec_derivs[u1var][u2var], d2y_du1u2);
352 71 : else if (dep.inbound_items().count({model, u1var}))
353 38 : for (const auto & depu2 : dep.item_providers().at({model, u2var}))
354 38 : for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
355 19 : assign_or_add(sec_derivs[u1var][x2var],
356 95 : Tensor(at::einsum("...ijq,...qk", {d2y_du1u2, du2_dxk}),
357 19 : utils::broadcast_batch_dim(d2y_du1u2, du2_dxk)));
358 52 : else if (dep.inbound_items().count({model, u2var}))
359 36 : for (const auto & depu1 : dep.item_providers().at({model, u1var}))
360 36 : for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
361 18 : assign_or_add(sec_derivs[x1var][u2var],
362 90 : Tensor(at::einsum("...ipk,...pj", {d2y_du1u2, du1_dxj}),
363 18 : utils::broadcast_batch_dim(d2y_du1u2, du1_dxj)));
364 : else
365 68 : for (const auto & depu1 : dep.item_providers().at({model, u1var}))
366 72 : for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
367 76 : for (const auto & depu2 : dep.item_providers().at({model, u2var}))
368 84 : for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
369 46 : assign_or_add(
370 46 : sec_derivs[x1var][x2var],
371 276 : Tensor(at::einsum("...ipq,...pj,...qk", {d2y_du1u2, du1_dxj, du2_dxk}),
372 72 : utils::broadcast_batch_dim(d2y_du1u2, du1_dxj, du2_dxk)));
373 : }
374 :
375 124 : for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
376 76 : if (!dep.inbound_items().count({model, uvar}))
377 64 : for (const auto & depu : dep.item_providers().at({model, uvar}))
378 47 : for (const auto & [x1var, d2u_dx1] : total_second_derivatives(dep, depu.parent, uvar))
379 38 : for (const auto & [x2var, d2u_dx1x2] : d2u_dx1)
380 23 : assign_or_add(sec_derivs[x1var][x2var],
381 115 : Tensor(at::einsum("...ip,...pjk", {dy_du, d2u_dx1x2}),
382 32 : utils::broadcast_batch_dim(dy_du, d2u_dx1x2)));
383 :
384 48 : return sec_derivs;
385 106 : }
386 :
387 : template <typename T>
388 : TensorType
389 1630 : Variable<T>::type() const
390 : {
391 1630 : return TensorTypeEnum<T>::value;
392 : }
393 :
394 : template <typename T>
395 : std::unique_ptr<VariableBase>
396 538 : Variable<T>::clone(const VariableName & name, Model * owner) const
397 : {
398 : if constexpr (std::is_same_v<T, Tensor>)
399 : {
400 : return std::move(std::make_unique<Variable<Tensor>>(
401 : name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes(), base_sizes()));
402 : }
403 : else
404 : {
405 1614 : return std::move(std::make_unique<Variable<T>>(
406 2152 : name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes()));
407 : }
408 : }
409 :
410 : template <typename T>
411 : void
412 805 : Variable<T>::ref(const VariableBase & var, bool ref_is_mutable)
413 : {
414 805 : neml_assert(!_ref || ref() == var.ref(),
415 : "Variable '",
416 805 : name(),
417 : "' cannot reference another variable '",
418 805 : var.name(),
419 : "' after it has been assigned a reference. \nThe "
420 : "existing reference '",
421 805 : ref()->name(),
422 : "' was declared by model '",
423 805 : ref()->owner().name(),
424 : "'. \nThe new reference is declared by model '",
425 805 : var.owner().name(),
426 : "'.");
427 805 : neml_assert(&var != this, "Variable '", name(), "' cannot reference itself.");
428 805 : neml_assert(var.ref() != this,
429 : "Variable '",
430 805 : name(),
431 : "' cannot reference a variable that is referencing itself.");
432 805 : const auto * var_ptr = dynamic_cast<const Variable<T> *>(var.ref());
433 805 : neml_assert(var_ptr,
434 : "Variable ",
435 805 : name(),
436 : " of type ",
437 805 : type(),
438 : " failed to reference another variable named ",
439 805 : var.name(),
440 : " of type ",
441 805 : var.type(),
442 : ": Dynamic cast failure.");
443 805 : _ref = var_ptr;
444 805 : _ref_is_mutable |= ref_is_mutable;
445 805 : }
446 :
447 : template <typename T>
448 : void
449 20179 : Variable<T>::zero(const TensorOptions & options)
450 : {
451 20179 : if (owning())
452 : {
453 : if constexpr (std::is_same_v<T, Tensor>)
454 : _value = T::zeros(list_sizes(), base_sizes(), options);
455 : else
456 20179 : _value = T::zeros(list_sizes(), options);
457 : }
458 : else
459 : {
460 0 : neml_assert_dbg(_ref_is_mutable,
461 : "Model '",
462 0 : owner().name(),
463 : "' is trying to zero a variable '",
464 0 : name(),
465 : "' declared by model '",
466 0 : ref()->owner().name(),
467 : "' , but the referenced variable is not mutable.");
468 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
469 0 : const_cast<VariableBase *>(ref())->zero(options);
470 : }
471 20179 : }
472 :
473 : template <typename T>
474 : void
475 20513 : Variable<T>::set(const Tensor & val)
476 : {
477 20513 : if (owning())
478 15849 : _value = T(val.base_reshape(utils::add_shapes(list_sizes(), base_sizes())),
479 31698 : utils::add_traceable_shapes(val.batch_sizes(), list_sizes()));
480 : else
481 : {
482 4664 : neml_assert_dbg(_ref_is_mutable,
483 : "Model '",
484 4664 : owner().name(),
485 : "' is trying to assign value to a variable '",
486 4664 : name(),
487 : "' declared by model '",
488 4664 : ref()->owner().name(),
489 : "' , but the referenced variable is not mutable.");
490 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
491 4664 : const_cast<VariableBase *>(ref())->set(val);
492 : }
493 20513 : }
494 :
495 : template <typename T>
496 : void
497 14389 : Variable<T>::set(const ATensor & val, bool force)
498 : {
499 14389 : if (owning())
500 : {
501 : if constexpr (std::is_same_v<T, Tensor>)
502 : _value = T(val, val.dim() - base_dim());
503 : else
504 8898 : _value = T(val);
505 : }
506 : else
507 : {
508 5491 : neml_assert_dbg(_ref_is_mutable || force,
509 : "Model '",
510 5491 : owner().name(),
511 : "' is trying to assign value to a variable '",
512 5491 : name(),
513 : "' declared by model '",
514 5491 : ref()->owner().name(),
515 : "' , but the referenced variable is not mutable.");
516 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
517 5491 : const_cast<VariableBase *>(ref())->set(val);
518 : }
519 14389 : }
520 :
521 : template <typename T>
522 : Tensor
523 0 : Variable<T>::get() const
524 : {
525 0 : return tensor().base_flatten();
526 : }
527 :
528 : template <typename T>
529 : Tensor
530 1092351 : Variable<T>::tensor() const
531 : {
532 1092351 : if (owning())
533 : {
534 674697 : neml_assert_dbg(_value.defined(), "Variable '", name(), "' has undefined value.");
535 674697 : auto batch_sizes = _value.batch_sizes().slice(0, _value.batch_dim() - list_dim());
536 674697 : return Tensor(_value, batch_sizes);
537 674697 : }
538 :
539 417654 : return ref()->tensor();
540 : }
541 :
542 : template <typename T>
543 : void
544 6 : Variable<T>::requires_grad_(bool req)
545 : {
546 6 : if (owning())
547 6 : _value.requires_grad_(req);
548 : else
549 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
550 0 : const_cast<VariableBase *>(ref())->requires_grad_(req);
551 6 : }
552 :
553 : template <typename T>
554 : void
555 684 : Variable<T>::operator=(const Tensor & val)
556 : {
557 684 : if (owning())
558 684 : _value = T(val);
559 : else
560 : {
561 0 : neml_assert_dbg(_ref_is_mutable,
562 : "Model '",
563 0 : owner().name(),
564 : "' is trying to assign value to a variable '",
565 0 : name(),
566 : "' declared by model '",
567 0 : ref()->owner().name(),
568 : "' , but the referenced variable is not mutable.");
569 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
570 0 : *const_cast<VariableBase *>(ref()) = val;
571 : }
572 684 : }
573 :
574 : template <typename T>
575 : void
576 20175 : Variable<T>::clear()
577 : {
578 20175 : if (owning())
579 : {
580 20175 : VariableBase::clear();
581 20175 : _value = T();
582 : }
583 : else
584 : {
585 0 : neml_assert_dbg(_ref_is_mutable,
586 : "Model '",
587 0 : owner().name(),
588 : "' is trying to clear a variable '",
589 0 : name(),
590 : "' declared by model '",
591 0 : ref()->owner().name(),
592 : "' , but the referenced variable is not mutable.");
593 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
594 0 : const_cast<VariableBase *>(ref())->clear();
595 : }
596 20175 : }
597 :
598 : #define INSTANTIATE_VARIABLE(T) template class Variable<T>
599 : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_VARIABLE);
600 :
601 : Derivative &
602 1049 : Derivative::operator=(const Tensor & val)
603 : {
604 1049 : if (!_deriv->defined())
605 1047 : *_deriv = val.base_reshape(_base_sizes);
606 : else
607 2 : *_deriv = *_deriv + val.base_reshape(_base_sizes);
608 1049 : return *this;
609 : }
610 : }
|