NEML2 2.1.0
Loading...
Searching...
No Matches
AxisLayout.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <vector>
28
29#include "neml2/misc/types.h"
30
31namespace neml2
32{
/// Body of the AxisLayout aggregate (the constructors below carry that name).
/// NOTE(review): the class-key line that opens this definition (listing line 33) and all
/// original doc comments are absent from this extract; the comments below restate the brief
/// descriptions from the generated symbol index and hedge anything that is not shown there.
34{
   /// Default constructor (compiler-generated). With the in-class initializers below, a
   /// default-constructed layout has a null parent, so is_view() returns false for it.
35 AxisLayout() = default;
36
   /// Enum for the structure represented by intermediate dimensions (if any).
   /// Stored in a uint8_t.
   /// NOTE(review): the enumerator lines (listing lines 40-41) are missing from this extract.
   /// The symbol index lists two enumerators:
   ///   DENSE - all intermediate dimensions are grouped into base dimensions;
   ///   BLOCK - intermediate dimensions represent blocks of variables.
38 enum class IStructure : uint8_t
39 {
42 };
43
   /// Construct a new axis layout.
   /// @param vars         Variable names, given as a vector of vectors
   ///                     (presumably one inner vector per variable group -- TODO confirm)
   /// @param intmd_shapes Intermediate shapes (taken by value)
   /// @param base_shapes  Base shapes (taken by value)
   /// @param istrs        IStructure values (presumably one per variable group -- TODO confirm)
52 AxisLayout(const std::vector<std::vector<VariableName>> & vars,
53 std::vector<TensorShape> intmd_shapes,
54 std::vector<TensorShape> base_shapes,
55 std::vector<IStructure> istrs);
56
   /// Construct a new axis layout by viewing into a parent layout.
   /// @param parent    Layout being viewed; held as a non-owning raw pointer
   ///                  (NOTE(review): presumably must outlive the view -- confirm)
   /// @param group_idx Index of the variable group (presumably within the parent -- TODO confirm)
   /// @param start     Start of the viewed range (units/inclusiveness not visible here)
   /// @param end       End of the viewed range (units/inclusiveness not visible here)
   /// @param offsets   Optional offsets; defaults to an empty vector
66 AxisLayout(const AxisLayout * parent,
67 std::size_t group_idx,
68 std::size_t start,
69 std::size_t end,
70 std::vector<std::size_t> offsets = {});
71
   /// Number of variable groups.
73 std::size_t ngroup() const;
   /// Starting and ending offsets of a variable group, returned as a pair.
75 std::pair<std::size_t, std::size_t> group_offsets(std::size_t) const;
   /// Contiguous view of the variable group.
77 AxisLayout group(std::size_t) const;
   /// Variable group IStructure. The unnamed index argument defaults to 0.
79 IStructure istr(std::size_t = 0) const;
   // NOTE(review): listing lines 80-82 are missing from this extract. The symbol index also
   // describes `AxisLayout view() const` ("Contiguous view of the entire layout"), which does
   // not appear in this listing -- presumably it was declared here; confirm against the header.
   /// Whether this is a view into a parent layout (true exactly when the parent pointer is set).
83 bool is_view() const { return _parent != nullptr; }
84
   /// Number of variables.
86 std::size_t nvar() const;
   /// Storage sizes of variables.
   /// @param include_intmd Presumably selects whether intermediate dimensions contribute to
   ///                      the reported sizes -- TODO confirm against the implementation
88 std::vector<Size> storage_sizes(bool include_intmd) const;
89
   /// Accessor for variable names (returned by value).
91 std::vector<VariableName> vars() const;
   /// Accessor for variable name (returned by const reference).
93 const VariableName & var(std::size_t) const;
   /// Accessor for variable intermediate shape (returned by const reference).
95 const TensorShape & intmd_sizes(std::size_t) const;
   /// Accessor for variable base shape (returned by const reference).
97 const TensorShape & base_sizes(std::size_t) const;
98
   /// Update intermediate shapes.
100 void update_intmd_shapes(const std::vector<TensorShape> &);
101
102private:
   // NOTE(review): the per-member doc comments are not part of this extract; the notes below
   // are inferred from member names and constructor parameters and should be confirmed.
   // Variable names (presumably the flattened list when this layout owns its data).
104 std::vector<VariableName> _vars;
   // Intermediate shapes (presumably one per variable).
106 std::vector<TensorShape> _intmd_shapes;
   // Base shapes (presumably one per variable).
108 std::vector<TensorShape> _base_shapes;
   // Offsets (presumably group boundaries, cf. group_offsets()).
110 std::vector<std::size_t> _offsets;
   // IStructure values (presumably one per variable group, cf. istr()).
112 std::vector<IStructure> _istrs;
113
   // Non-owning pointer to the parent layout; null unless this object is a view (see is_view()).
115 const AxisLayout * _parent = nullptr;
   // Group index, zero by default (presumably set by the view constructor).
117 std::size_t _group_idx = 0;
   // Start of the viewed range, zero by default (presumably set by the view constructor).
119 std::size_t _start = 0;
   // End of the viewed range, zero by default (presumably set by the view constructor).
121 std::size_t _end = 0;
122};
123
/// Equality comparison operator for two AxisLayout objects.
/// Declaration only: which members take part in the comparison is decided by the
/// out-of-line definition, which is not visible in this header.
125bool operator==(const AxisLayout &, const AxisLayout &);
126} // namespace neml2
Definition DiagnosticsInterface.h:31
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
bool operator==(const AxisLayout &, const AxisLayout &)
comparison operator
std::string VariableName
Definition types.h:75
neml2::AxisLayout
Definition AxisLayout.h:34
std::pair< std::size_t, std::size_t > group_offsets(std::size_t) const
Starting and ending offsets of a variable group.
std::vector< VariableName > vars() const
Accessor for variable names.
AxisLayout(const std::vector< std::vector< VariableName > > &vars, std::vector< TensorShape > intmd_shapes, std::vector< TensorShape > base_shapes, std::vector< IStructure > istrs)
Construct a new Axis Layout object.
std::size_t ngroup() const
Number of variable groups.
bool is_view() const
Whether this is a view into a parent layout.
Definition AxisLayout.h:83
const VariableName & var(std::size_t) const
Accessor for variable name.
AxisLayout(const AxisLayout *parent, std::size_t group_idx, std::size_t start, std::size_t end, std::vector< std::size_t > offsets={})
Construct a new Axis Layout object by viewing into a parent layout.
AxisLayout()=default
const TensorShape & base_sizes(std::size_t) const
Accessor for variable base shape.
void update_intmd_shapes(const std::vector< TensorShape > &)
Update intermediate shapes.
AxisLayout group(std::size_t) const
Contiguous view of the variable group.
std::vector< Size > storage_sizes(bool include_intmd) const
Storage sizes of variables.
IStructure
Enum for the structure represented by intermediate dimensions (if any).
Definition AxisLayout.h:39
@ DENSE
All intermediate dimensions are grouped into base dimensions.
Definition AxisLayout.h:40
@ BLOCK
Intermediate dimensions represent blocks of variables.
Definition AxisLayout.h:41
IStructure istr(std::size_t=0) const
Variable group IStructure.
std::size_t nvar() const
Number of variables.
const TensorShape & intmd_sizes(std::size_t) const
Accessor for variable intermediate shape.
AxisLayout view() const
Contiguous view of the entire layout.