Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/user_tensors/UserTensor.h"
26 : #include "neml2/misc/assertions.h"
27 :
28 : namespace neml2
29 : {
// NOTE(review): presumably registers UserTensor with the NEML2 object registry under the
// alias "Tensor"; the macro is defined outside this file -- confirm against its definition.
30 : register_NEML2_object_alias(UserTensor, "Tensor");
31 :
32 : OptionSet
33 2 : UserTensor::expected_options()
34 : {
35 2 : OptionSet options = UserTensorBase::expected_options();
36 2 : options.doc() = "Construct a Tensor from a vector of values. The vector will be reshaped "
37 2 : "according to the specified batch and base shapes.";
38 :
39 4 : options.set<std::vector<double>>("values");
40 2 : options.set("values").doc() = "Values in this (flattened) tensor";
41 :
42 6 : options.set<TensorShape>("batch_shape") = {};
43 2 : options.set("batch_shape").doc() = "Batch shape";
44 :
45 6 : options.set<TensorShape>("base_shape") = {};
46 2 : options.set("base_shape").doc() = "Base shape";
47 :
48 2 : return options;
49 0 : }
50 :
51 12 : UserTensor::UserTensor(const OptionSet & options)
52 24 : : Tensor(Tensor::empty(options.get<TensorShape>("batch_shape"),
53 24 : options.get<TensorShape>("base_shape"),
54 12 : default_tensor_options())),
55 48 : UserTensorBase(options)
56 : {
57 12 : auto vals = options.get<std::vector<double>>("values");
58 12 : auto flat = Tensor::create(vals, default_tensor_options());
59 12 : if (vals.size() == size_t(this->base_storage()))
60 27 : this->index_put_({indexing::Ellipsis}, flat.reshape(this->base_sizes()));
61 3 : else if (vals.size() == size_t(utils::storage_size(this->sizes())))
62 6 : this->index_put_({indexing::Ellipsis}, flat.reshape(this->sizes()));
63 : else
64 1 : neml_assert(false,
65 : "Number of values ",
66 1 : vals.size(),
67 : " must equal to either the base storage size ",
68 2 : this->base_storage(),
69 : " or the total storage size ",
70 2 : utils::storage_size(this->sizes()));
71 26 : }
72 : } // namespace neml2
|