Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/models/Variable.h"
26 : #include "neml2/models/Model.h"
27 : #include "neml2/models/DependencyResolver.h"
28 : #include "neml2/models/map_types.h"
29 : #include "neml2/tensors/tensors.h"
30 : #include "neml2/tensors/assertions.h"
31 : #include "neml2/tensors/functions/bmm.h"
32 : #include "neml2/jit/utils.h"
33 : #include "neml2/jit/TraceableTensorShape.h"
34 :
35 : namespace neml2
36 : {
37 1579 : VariableBase::VariableBase(VariableName name_in, Model * owner, TensorShapeRef list_shape)
38 1579 : : _name(std::move(name_in)),
39 1579 : _owner(owner),
40 3158 : _list_sizes(list_shape)
41 : {
42 1579 : }
43 :
44 : const Model &
45 11906 : VariableBase::owner() const
46 : {
47 11906 : neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
48 11906 : return *_owner;
49 : }
50 :
51 : Model &
52 10043 : VariableBase::owner()
53 : {
54 10043 : neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
55 10043 : return *_owner;
56 : }
57 :
58 : bool
59 826 : VariableBase::is_state() const
60 : {
61 826 : return _name.is_state();
62 : }
63 :
64 : bool
65 163 : VariableBase::is_old_state() const
66 : {
67 163 : return _name.is_old_state();
68 : }
69 :
70 : bool
71 167 : VariableBase::is_force() const
72 : {
73 167 : return _name.is_force();
74 : }
75 :
76 : bool
77 47 : VariableBase::is_old_force() const
78 : {
79 47 : return _name.is_old_force();
80 : }
81 :
82 : bool
83 134 : VariableBase::is_residual() const
84 : {
85 134 : return _name.is_residual();
86 : }
87 :
88 : bool
89 88 : VariableBase::is_parameter() const
90 : {
91 88 : return _name.is_parameter();
92 : }
93 :
94 : bool
95 388 : VariableBase::is_solve_dependent() const
96 : {
97 388 : return is_state() || is_residual() || is_parameter();
98 : }
99 :
100 : bool
101 1427 : VariableBase::is_dependent() const
102 : {
103 1427 : return !currently_solving_nonlinear_system() || is_solve_dependent();
104 : }
105 :
106 : TensorOptions
107 304 : VariableBase::options() const
108 : {
109 304 : return tensor().options();
110 : }
111 :
112 : Dtype
113 46 : VariableBase::scalar_type() const
114 : {
115 46 : return tensor().scalar_type();
116 : }
117 :
118 : Device
119 0 : VariableBase::device() const
120 : {
121 0 : return tensor().device();
122 : }
123 :
124 : Size
125 0 : VariableBase::dim() const
126 : {
127 0 : return tensor().dim();
128 : }
129 :
130 : TensorShapeRef
131 0 : VariableBase::sizes() const
132 : {
133 0 : return tensor().sizes();
134 : }
135 :
136 : Size
137 0 : VariableBase::size(Size dim) const
138 : {
139 0 : return tensor().size(dim);
140 : }
141 :
142 : bool
143 0 : VariableBase::batched() const
144 : {
145 0 : return tensor().batched();
146 : }
147 :
148 : Size
149 35906 : VariableBase::batch_dim() const
150 : {
151 35906 : return tensor().batch_dim();
152 : }
153 :
154 : Size
155 632828 : VariableBase::list_dim() const
156 : {
157 632828 : return Size(list_sizes().size());
158 : }
159 :
160 : Size
161 0 : VariableBase::base_dim() const
162 : {
163 0 : return Size(base_sizes().size());
164 : }
165 :
166 : TraceableTensorShape
167 1 : VariableBase::batch_sizes() const
168 : {
169 1 : return tensor().batch_sizes();
170 : }
171 :
172 : TensorShapeRef
173 684232 : VariableBase::list_sizes() const
174 : {
175 684232 : return _list_sizes;
176 : }
177 :
178 : TraceableSize
179 0 : VariableBase::batch_size(Size dim) const
180 : {
181 0 : return tensor().batch_size(dim);
182 : }
183 :
184 : Size
185 0 : VariableBase::base_size(Size dim) const
186 : {
187 0 : return base_sizes()[dim];
188 : }
189 :
190 : Size
191 1 : VariableBase::list_size(Size dim) const
192 : {
193 1 : return list_sizes()[dim];
194 : }
195 :
196 : Size
197 0 : VariableBase::base_storage() const
198 : {
199 0 : return utils::storage_size(base_sizes());
200 : }
201 :
202 : Size
203 2585 : VariableBase::assembly_storage() const
204 : {
205 2585 : return utils::storage_size(list_sizes()) * utils::storage_size(base_sizes());
206 : }
207 :
208 : bool
209 0 : VariableBase::requires_grad() const
210 : {
211 0 : return tensor().requires_grad();
212 : }
213 :
214 : Derivative
215 745 : VariableBase::d(const VariableBase & var)
216 : {
217 745 : neml_assert_dbg(owning(),
218 : "Cannot assign derivative to a referencing variable '",
219 745 : name(),
220 : "' with respect to '",
221 745 : var.name(),
222 : "'.");
223 745 : return Derivative({assembly_storage(), var.assembly_storage()}, &_derivs[var.name()]);
224 : }
225 :
226 : Derivative
227 206 : VariableBase::d(const VariableBase & var1, const VariableBase & var2)
228 : {
229 206 : neml_assert_dbg(owning(),
230 : "Cannot assign second derivative to a referencing variable '",
231 206 : name(),
232 : "' with respect to '",
233 206 : var1.name(),
234 : "' and '",
235 206 : var2.name(),
236 : "'.");
237 824 : return Derivative({assembly_storage(), var1.assembly_storage(), var2.assembly_storage()},
238 206 : &_sec_derivs[var1.name()][var2.name()]);
239 : }
240 :
241 : void
242 0 : VariableBase::request_AD(const VariableBase & u)
243 : {
244 0 : owner().request_AD(*this, u);
245 0 : }
246 :
247 : void
248 4 : VariableBase::request_AD(const std::vector<const VariableBase *> & us)
249 : {
250 18 : for (const auto & u : us)
251 : {
252 14 : neml_assert(u, "Cannot request AD for a null variable.");
253 14 : owner().request_AD(*this, *u);
254 : }
255 4 : }
256 :
257 : void
258 0 : VariableBase::request_AD(const VariableBase & u1, const VariableBase & u2)
259 : {
260 0 : owner().request_AD(*this, u1, u2);
261 0 : }
262 :
263 : void
264 3 : VariableBase::request_AD(const std::vector<const VariableBase *> & u1s,
265 : const std::vector<const VariableBase *> & u2s)
266 : {
267 15 : for (const auto & u1 : u1s)
268 60 : for (const auto & u2 : u2s)
269 : {
270 48 : neml_assert(u1, "Cannot request AD for a null variable.");
271 48 : neml_assert(u2, "Cannot request AD for a null variable.");
272 48 : owner().request_AD(*this, *u1, *u2);
273 : }
274 3 : }
275 :
276 : void
277 18315 : VariableBase::clear()
278 : {
279 18315 : neml_assert_dbg(owning(), "Cannot clear a referencing variable '", name(), "'.");
280 18315 : _derivs.clear();
281 18315 : _sec_derivs.clear();
282 18315 : }
283 :
284 : void
285 75 : VariableBase::apply_chain_rule(const DependencyResolver<Model, VariableName> & dep)
286 : {
287 96 : for (const auto & [model, var] : dep.outbound_items())
288 96 : if (var == name())
289 : {
290 75 : _derivs = total_derivatives(dep, model, var);
291 75 : return;
292 : }
293 : }
294 :
295 : void
296 16 : VariableBase::apply_second_order_chain_rule(const DependencyResolver<Model, VariableName> & dep)
297 : {
298 16 : for (const auto & [model, var] : dep.outbound_items())
299 16 : if (var == name())
300 : {
301 16 : _sec_derivs = total_second_derivatives(dep, model, var);
302 16 : return;
303 : }
304 : }
305 :
306 : void
307 975 : assign_or_add(Tensor & dest, const Tensor & val)
308 : {
309 975 : if (dest.defined())
310 69 : dest = dest + val;
311 : else
312 906 : dest = val;
313 975 : }
314 :
315 : ValueMap
316 397 : VariableBase::total_derivatives(const DependencyResolver<Model, VariableName> & dep,
317 : Model * model,
318 : const VariableName & yvar) const
319 : {
320 397 : ValueMap derivs;
321 :
322 1056 : for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
323 : {
324 659 : if (dep.inbound_items().count({model, uvar}))
325 446 : assign_or_add(derivs[uvar], dy_du);
326 : else
327 426 : for (const auto & depu : dep.item_providers().at({model, uvar}))
328 611 : for (const auto & [xvar, du_dx] : total_derivatives(dep, depu.parent, uvar))
329 611 : assign_or_add(derivs[xvar], bmm(dy_du, du_dx));
330 : }
331 :
332 397 : return derivs;
333 0 : }
334 :
335 : DerivMap
336 48 : VariableBase::total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
337 : Model * model,
338 : const VariableName & yvar) const
339 : {
340 48 : DerivMap sec_derivs;
341 :
342 83 : for (const auto & [u1var, d2y_du1] : model->output_variable(yvar).second_derivatives())
343 131 : for (const auto & [u2var, d2y_du1u2] : d2y_du1)
344 : {
345 96 : if (dep.inbound_items().count({model, u1var}) && dep.inbound_items().count({model, u2var}))
346 25 : assign_or_add(sec_derivs[u1var][u2var], d2y_du1u2);
347 71 : else if (dep.inbound_items().count({model, u1var}))
348 38 : for (const auto & depu2 : dep.item_providers().at({model, u2var}))
349 38 : for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
350 19 : assign_or_add(sec_derivs[u1var][x2var],
351 95 : Tensor(at::einsum("...ijq,...qk", {d2y_du1u2, du2_dxk}),
352 19 : utils::broadcast_batch_dim(d2y_du1u2, du2_dxk)));
353 52 : else if (dep.inbound_items().count({model, u2var}))
354 36 : for (const auto & depu1 : dep.item_providers().at({model, u1var}))
355 36 : for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
356 18 : assign_or_add(sec_derivs[x1var][u2var],
357 90 : Tensor(at::einsum("...ipk,...pj", {d2y_du1u2, du1_dxj}),
358 18 : utils::broadcast_batch_dim(d2y_du1u2, du1_dxj)));
359 : else
360 68 : for (const auto & depu1 : dep.item_providers().at({model, u1var}))
361 72 : for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
362 76 : for (const auto & depu2 : dep.item_providers().at({model, u2var}))
363 84 : for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
364 46 : assign_or_add(
365 46 : sec_derivs[x1var][x2var],
366 276 : Tensor(at::einsum("...ipq,...pj,...qk", {d2y_du1u2, du1_dxj, du2_dxk}),
367 72 : utils::broadcast_batch_dim(d2y_du1u2, du1_dxj, du2_dxk)));
368 : }
369 :
370 124 : for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
371 76 : if (!dep.inbound_items().count({model, uvar}))
372 64 : for (const auto & depu : dep.item_providers().at({model, uvar}))
373 47 : for (const auto & [x1var, d2u_dx1] : total_second_derivatives(dep, depu.parent, uvar))
374 38 : for (const auto & [x2var, d2u_dx1x2] : d2u_dx1)
375 23 : assign_or_add(sec_derivs[x1var][x2var],
376 115 : Tensor(at::einsum("...ip,...pjk", {dy_du, d2u_dx1x2}),
377 32 : utils::broadcast_batch_dim(dy_du, d2u_dx1x2)));
378 :
379 48 : return sec_derivs;
380 106 : }
381 :
382 : template <typename T>
383 : TensorType
384 1468 : Variable<T>::type() const
385 : {
386 1468 : return TensorTypeEnum<T>::value;
387 : }
388 :
389 : template <typename T>
390 : std::unique_ptr<VariableBase>
391 477 : Variable<T>::clone(const VariableName & name, Model * owner) const
392 : {
393 : if constexpr (std::is_same_v<T, Tensor>)
394 : {
395 : return std::move(std::make_unique<Variable<Tensor>>(
396 : name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes(), base_sizes()));
397 : }
398 : else
399 : {
400 1431 : return std::move(std::make_unique<Variable<T>>(
401 1908 : name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes()));
402 : }
403 : }
404 :
405 : template <typename T>
406 : void
407 724 : Variable<T>::ref(const VariableBase & var, bool ref_is_mutable)
408 : {
409 724 : neml_assert(!_ref || ref() == var.ref(),
410 : "Variable '",
411 724 : name(),
412 : "' cannot reference another variable '",
413 724 : var.name(),
414 : "' after it has been assigned a reference. \nThe "
415 : "existing reference '",
416 724 : ref()->name(),
417 : "' was declared by model '",
418 724 : ref()->owner().name(),
419 : "'. \nThe new reference is declared by model '",
420 724 : var.owner().name(),
421 : "'.");
422 724 : neml_assert(&var != this, "Variable '", name(), "' cannot reference itself.");
423 724 : neml_assert(var.ref() != this,
424 : "Variable '",
425 724 : name(),
426 : "' cannot reference a variable that is referencing itself.");
427 724 : const auto * var_ptr = dynamic_cast<const Variable<T> *>(var.ref());
428 724 : neml_assert(var_ptr,
429 : "Variable ",
430 724 : name(),
431 : " of type ",
432 724 : type(),
433 : " failed to reference another variable named ",
434 724 : var.name(),
435 : " of type ",
436 724 : var.type(),
437 : ": Dynamic cast failure.");
438 724 : _ref = var_ptr;
439 724 : _ref_is_mutable |= ref_is_mutable;
440 724 : }
441 :
442 : template <typename T>
443 : void
444 18315 : Variable<T>::zero(const TensorOptions & options)
445 : {
446 18315 : if (owning())
447 : {
448 : if constexpr (std::is_same_v<T, Tensor>)
449 : _value = T::zeros(list_sizes(), base_sizes(), options);
450 : else
451 18315 : _value = T::zeros(list_sizes(), options);
452 : }
453 : else
454 : {
455 0 : neml_assert_dbg(_ref_is_mutable,
456 : "Model '",
457 0 : owner().name(),
458 : "' is trying to zero a variable '",
459 0 : name(),
460 : "' declared by model '",
461 0 : ref()->owner().name(),
462 : "' , but the referenced variable is not mutable.");
463 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
464 0 : const_cast<VariableBase *>(ref())->zero(options);
465 : }
466 18315 : }
467 :
468 : template <typename T>
469 : void
470 19643 : Variable<T>::set(const Tensor & val)
471 : {
472 19643 : if (owning())
473 15006 : _value = T(val.base_reshape(utils::add_shapes(list_sizes(), base_sizes())),
474 30012 : utils::add_traceable_shapes(val.batch_sizes(), list_sizes()));
475 : else
476 : {
477 4637 : neml_assert_dbg(_ref_is_mutable,
478 : "Model '",
479 4637 : owner().name(),
480 : "' is trying to assign value to a variable '",
481 4637 : name(),
482 : "' declared by model '",
483 4637 : ref()->owner().name(),
484 : "' , but the referenced variable is not mutable.");
485 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
486 4637 : const_cast<VariableBase *>(ref())->set(val);
487 : }
488 19643 : }
489 :
490 : template <typename T>
491 : void
492 13558 : Variable<T>::set(const ATensor & val, bool force)
493 : {
494 13558 : if (owning())
495 : {
496 : if constexpr (std::is_same_v<T, Tensor>)
497 : _value = T(val, val.dim() - base_dim());
498 : else
499 8214 : _value = T(val);
500 : }
501 : else
502 : {
503 5344 : neml_assert_dbg(_ref_is_mutable || force,
504 : "Model '",
505 5344 : owner().name(),
506 : "' is trying to assign value to a variable '",
507 5344 : name(),
508 : "' declared by model '",
509 5344 : ref()->owner().name(),
510 : "' , but the referenced variable is not mutable.");
511 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
512 5344 : const_cast<VariableBase *>(ref())->set(val);
513 : }
514 13558 : }
515 :
516 : template <typename T>
517 : Tensor
518 0 : Variable<T>::get() const
519 : {
520 0 : return tensor().base_flatten();
521 : }
522 :
523 : template <typename T>
524 : Tensor
525 1045983 : Variable<T>::tensor() const
526 : {
527 1045983 : if (owning())
528 : {
529 632828 : neml_assert_dbg(_value.defined(), "Variable '", name(), "' has undefined value.");
530 632828 : auto batch_sizes = _value.batch_sizes().slice(0, _value.batch_dim() - list_dim());
531 632828 : return Tensor(_value, batch_sizes);
532 632828 : }
533 :
534 413155 : return ref()->tensor();
535 : }
536 :
537 : template <typename T>
538 : void
539 6 : Variable<T>::requires_grad_(bool req)
540 : {
541 6 : if (owning())
542 6 : _value.requires_grad_(req);
543 : else
544 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
545 0 : const_cast<VariableBase *>(ref())->requires_grad_(req);
546 6 : }
547 :
548 : template <typename T>
549 : void
550 618 : Variable<T>::operator=(const Tensor & val)
551 : {
552 618 : if (owning())
553 618 : _value = T(val);
554 : else
555 : {
556 0 : neml_assert_dbg(_ref_is_mutable,
557 : "Model '",
558 0 : owner().name(),
559 : "' is trying to assign value to a variable '",
560 0 : name(),
561 : "' declared by model '",
562 0 : ref()->owner().name(),
563 : "' , but the referenced variable is not mutable.");
564 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
565 0 : *const_cast<VariableBase *>(ref()) = val;
566 : }
567 618 : }
568 :
569 : template <typename T>
570 : void
571 18315 : Variable<T>::clear()
572 : {
573 18315 : if (owning())
574 : {
575 18315 : VariableBase::clear();
576 18315 : _value = T();
577 : }
578 : else
579 : {
580 0 : neml_assert_dbg(_ref_is_mutable,
581 : "Model '",
582 0 : owner().name(),
583 : "' is trying to clear a variable '",
584 0 : name(),
585 : "' declared by model '",
586 0 : ref()->owner().name(),
587 : "' , but the referenced variable is not mutable.");
588 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
589 0 : const_cast<VariableBase *>(ref())->clear();
590 : }
591 18315 : }
592 :
593 : #define INSTANTIATE_VARIABLE(T) template class Variable<T>
594 : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_VARIABLE);
595 :
596 : Derivative &
597 951 : Derivative::operator=(const Tensor & val)
598 : {
599 951 : *_deriv = val.base_reshape(_base_sizes);
600 951 : return *this;
601 : }
602 : }
|