Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include <torch/csrc/autograd/variable.h>
26 : #include "neml2/tensors/Tensor.h"
27 : #include "neml2/tensors/shape_utils.h"
28 : #include "neml2/misc/assertions.h"
29 : #include "neml2/jit/types.h"
30 :
31 : namespace neml2
32 : {
33 : namespace utils
34 : {
35 : ATensor
36 3779 : pad_prepend(const ATensor & s, Size dim, Size pad)
37 : {
38 3779 : neml_assert_dbg(s.defined(), "pad_prepend: shape must be defined");
39 3779 : neml_assert_dbg(s.scalar_type() == kInt64, "pad_prepend: shape must be of type int64");
40 3779 : neml_assert_dbg(s.dim() == 1, "pad_prepend: shape must be 1D");
41 22674 : return at::cat({at::full({dim - s.size(0)}, pad, s.options()), s});
42 11337 : }
43 :
44 : TraceableTensorShape
45 6674 : broadcast_batch_sizes(const std::vector<Tensor> & tensors)
46 : {
47 6674 : Size dim = 0;
48 6674 : auto shapes = std::vector<ATensor>{};
49 26626 : for (const auto & t : tensors)
50 19952 : if (t.defined())
51 : {
52 19937 : dim = t.batch_dim() > dim ? t.batch_dim() : dim;
53 19937 : const auto shape = t.batch_sizes().as_tensor();
54 19937 : if (shape.defined())
55 3779 : shapes.push_back(shape);
56 19937 : }
57 6674 : if (shapes.empty())
58 5393 : return TraceableTensorShape(TensorShape{});
59 : /// Pre-pad ones to the shapes
60 5060 : for (auto & s : shapes)
61 3779 : s = pad_prepend(s, dim, 1);
62 : /// Braodcast
63 1281 : const auto all_shapes = at::stack(shapes);
64 1281 : return std::get<0>(at::max(all_shapes, 0));
65 6674 : }
66 : } // namespace utils
67 :
// Wrap a raw ATensor, interpreting its first `batch_dim` dimensions as batch
// dimensions; the remainder are base dimensions.
Tensor::Tensor(const ATensor & tensor, Size batch_dim)
  : TensorBase(tensor, batch_dim)
{
}
72 :
// Wrap a raw ATensor with an explicit (possibly traceable) batch shape; the
// trailing dimensions not covered by `batch_shape` are base dimensions.
Tensor::Tensor(const ATensor & tensor, const TraceableTensorShape & batch_shape)
  : TensorBase(tensor, batch_shape)
{
}
77 :
78 : Tensor
79 839 : Tensor::create(const TensorDataContainer & data, const TensorOptions & options)
80 : {
81 839 : return create(data, 0, options);
82 : }
83 :
84 : Tensor
85 846 : Tensor::create(const TensorDataContainer & data, Size batch_dim, const TensorOptions & options)
86 : {
87 1692 : return Tensor(torch::autograd::make_variable(data.convert_to_tensor(options.requires_grad(false)),
88 846 : options.requires_grad()),
89 1692 : batch_dim);
90 : }
91 :
92 : Tensor
93 2 : Tensor::empty(TensorShapeRef base_shape, const TensorOptions & options)
94 : {
95 2 : return Tensor(at::empty(base_shape, options), 0);
96 : }
97 :
98 : Tensor
99 171 : Tensor::empty(const TraceableTensorShape & batch_shape,
100 : TensorShapeRef base_shape,
101 : const TensorOptions & options)
102 : {
103 : // Record batch shape
104 358 : for (Size i = 0; i < (Size)batch_shape.size(); ++i)
105 187 : if (const auto * const si = batch_shape[i].traceable())
106 0 : jit::tracer::ArgumentStash::stashIntArrayRefElem(
107 0 : "size", batch_shape.size() + base_shape.size(), i, *si);
108 :
109 342 : return Tensor(at::empty(utils::add_shapes(batch_shape.concrete(), base_shape), options),
110 342 : batch_shape);
111 : }
112 :
113 : Tensor
114 51 : Tensor::zeros(TensorShapeRef base_shape, const TensorOptions & options)
115 : {
116 51 : return Tensor(at::zeros(base_shape, options), 0);
117 : }
118 :
119 : Tensor
120 18529 : Tensor::zeros(const TraceableTensorShape & batch_shape,
121 : TensorShapeRef base_shape,
122 : const TensorOptions & options)
123 : {
124 : // Record batch shape
125 18959 : for (Size i = 0; i < (Size)batch_shape.size(); ++i)
126 430 : if (const auto * const si = batch_shape[i].traceable())
127 27 : jit::tracer::ArgumentStash::stashIntArrayRefElem(
128 9 : "size", batch_shape.size() + base_shape.size(), i, *si);
129 :
130 37058 : return Tensor(at::zeros(utils::add_shapes(batch_shape.concrete(), base_shape), options),
131 37058 : batch_shape);
132 : }
133 :
134 : Tensor
135 135 : Tensor::ones(TensorShapeRef base_shape, const TensorOptions & options)
136 : {
137 135 : return Tensor(at::ones(base_shape, options), 0);
138 : }
139 :
140 : Tensor
141 105 : Tensor::ones(const TraceableTensorShape & batch_shape,
142 : TensorShapeRef base_shape,
143 : const TensorOptions & options)
144 : {
145 : // Record batch shape
146 240 : for (Size i = 0; i < (Size)batch_shape.size(); ++i)
147 135 : if (const auto * const si = batch_shape[i].traceable())
148 27 : jit::tracer::ArgumentStash::stashIntArrayRefElem(
149 9 : "size", batch_shape.size() + base_shape.size(), i, *si);
150 :
151 210 : return Tensor(at::ones(utils::add_shapes(batch_shape.concrete(), base_shape), options),
152 210 : batch_shape);
153 : }
154 :
155 : Tensor
156 1669 : Tensor::full(TensorShapeRef base_shape, const CScalar & init, const TensorOptions & options)
157 : {
158 1669 : return Tensor(at::full(base_shape, init, options), 0);
159 : }
160 :
161 : Tensor
162 280 : Tensor::full(const TraceableTensorShape & batch_shape,
163 : TensorShapeRef base_shape,
164 : const CScalar & init,
165 : const TensorOptions & options)
166 : {
167 : // Record batch shape
168 863 : for (Size i = 0; i < (Size)batch_shape.size(); ++i)
169 583 : if (const auto * const si = batch_shape[i].traceable())
170 0 : jit::tracer::ArgumentStash::stashIntArrayRefElem(
171 0 : "size", batch_shape.size() + base_shape.size(), i, *si);
172 :
173 560 : return Tensor(at::full(utils::add_shapes(batch_shape.concrete(), base_shape), init, options),
174 560 : batch_shape);
175 : }
176 :
177 : Tensor
178 193 : Tensor::identity(Size n, const TensorOptions & options)
179 : {
180 193 : return Tensor(at::eye(n, options), 0);
181 : }
182 :
183 : Tensor
184 3 : Tensor::identity(const TraceableTensorShape & batch_shape, Size n, const TensorOptions & options)
185 : {
186 3 : return identity(n, options).batch_expand_copy(batch_shape);
187 : }
188 :
189 : Tensor
190 4 : Tensor::base_unsqueeze_to(Size n) const
191 : {
192 4 : neml_assert_dbg(n >= base_dim(),
193 : "base_unsqueeze_to: n (",
194 : n,
195 : ") must be greater than or equal to base_dim (",
196 4 : base_dim(),
197 : ").");
198 8 : indexing::TensorIndices net{indexing::Ellipsis};
199 4 : net.insert(net.end(), n - base_dim(), indexing::None);
200 12 : return Tensor(index(net), batch_sizes());
201 8 : }
202 : } // end namespace neml2
|