// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "neml2/models/BufferStore.h"
#include "neml2/misc/assertions.h"
#include "neml2/base/NEML2Object.h"
#include "neml2/base/TensorName.h"
#include "neml2/base/Settings.h"
#include "neml2/tensors/tensors.h"
#include "neml2/tensors/TensorValue.h"

namespace neml2
{
BufferStore::BufferStore(NEML2Object * object)
  : _object(object)
{
}

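// Access the full name -> value map of declared buffers. Buffer storage lives on the host
// model, so this accessor must be called on the host.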
std::map<std::string, std::unique_ptr<TensorValueBase>> &
BufferStore::named_buffers()
{
  neml_assert(_object->host() == _object,
              "named_buffers() should only be called on the host model.");
  return _buffer_values;
}

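// Look up a single buffer by name on the host model; errors if no buffer with that name exists.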
TensorValueBase &
BufferStore::get_buffer(const std::string & name)
{
  neml_assert(_object->host() == _object, "This method should only be called on the host model.");
  neml_assert(_buffer_values.count(name), "Buffer named ", name, " does not exist.");
  return *_buffer_values[name];
}

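// Send all buffer values to the given tensor options (e.g. device and dtype) in place.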
void
BufferStore::send_buffers_to(const TensorOptions & options)
{
  for (auto && [name, buffer] : _buffer_values)
    buffer->to_(options);
}

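// Declare a buffer holding a raw tensor value. Non-host objects forward the declaration to the
// host model, prefixing the buffer name with the object's name and the configured separator.
// Declaring a name that already exists returns a reference to the existing value.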
template <typename T, typename>
const T &
BufferStore::declare_buffer(const std::string & name, const T & rawval)
{
  if (_object->host() != _object)
    return _object->host<BufferStore>()->declare_buffer(
        _object->name() + _object->settings().buffer_name_separator() + name, rawval);

  TensorValueBase * base_ptr = nullptr;

  // If the buffer already exists, return its reference
  if (_buffer_values.count(name))
    base_ptr = &get_buffer(name);
  else
  {
    auto val = std::make_unique<TensorValue<T>>(rawval);
    auto [it, success] = _buffer_values.emplace(name, std::move(val));
    base_ptr = it->second.get();
  }

  auto ptr = dynamic_cast<TensorValue<T> *>(base_ptr);
  neml_assert(ptr, "Internal error: Failed to cast buffer to a concrete type.");
  return ptr->value();
}

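// Declare a buffer from a cross-referenced tensor name, resolved through the factory.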
template <typename T, typename>
const T &
BufferStore::declare_buffer(const std::string & name, const TensorName<T> & tensorname)
{
  auto * factory = _object->factory();
  neml_assert(factory, "Internal error: factory != nullptr");
  return declare_buffer(name, tensorname.resolve(factory));
}

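// Declare a buffer whose value is read from the input option with the given name.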
template <typename T, typename>
const T &
BufferStore::declare_buffer(const std::string & name, const std::string & input_option_name)
{
  if (_object->input_options().contains(input_option_name))
    return declare_buffer<T>(name, _object->input_options().get<TensorName<T>>(input_option_name));

  throw NEMLException(
      "Trying to register buffer named " + name + " from input option named " + input_option_name +
      " of type " + utils::demangle(typeid(T).name()) +
      ". Make sure you provided the correct buffer name, option name, and buffer type.");
}

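// Explicit instantiations of declare_buffer for all primitive tensor types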
#define BUFFERSTORE_INSTANTIATE_TENSORBASE(T)                                                      \
  template const T & BufferStore::declare_buffer<T>(const std::string &, const T &);              \
  template const T & BufferStore::declare_buffer<T>(const std::string &, const TensorName<T> &);  \
  template const T & BufferStore::declare_buffer<T>(const std::string &, const std::string &)
FOR_ALL_TENSORBASE(BUFFERSTORE_INSTANTIATE_TENSORBASE);

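// Restore buffer values from the back of a JIT evaluation stack and pop them.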
void
BufferStore::assign_buffer_stack(jit::Stack & stack)
{
  const auto & buffers = _object->host<BufferStore>()->named_buffers();

  neml_assert_dbg(stack.size() >= buffers.size(),
                  "Stack size (",
                  stack.size(),
                  ") is smaller than the number of buffers in the model (",
                  buffers.size(),
                  ").");

  // Last n tensors in the stack are the buffers
  std::size_t i = stack.size() - buffers.size();
  for (auto && [name, buffer] : buffers)
  {
    const auto tensor = stack[i++].toTensor();
    (*buffer) = Tensor(tensor, tensor.dim() - Tensor(*buffer).base_dim());
  }

  // Drop the buffers from the stack
  jit::drop(stack, buffers.size());
}

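// Pack the current buffer values into a JIT evaluation stack.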
jit::Stack
BufferStore::collect_buffer_stack() const
{
  const auto & buffers = _object->host<BufferStore>()->named_buffers();
  jit::Stack stack;
  stack.reserve(buffers.size());
  for (auto && [name, buffer] : buffers)
    stack.emplace_back(Tensor(*buffer));
  return stack;
}
} // namespace neml2