29#include "neml2/misc/types.h"
53 std::vector<TensorShape> intmd_shapes,
54 std::vector<TensorShape> base_shapes,
55 std::vector<IStructure> istrs);
67 std::size_t group_idx,
70 std::vector<std::size_t> offsets = {});
83 bool is_view()
const {
return _parent !=
nullptr; }
91 std::vector<VariableName>
vars()
const;
104 std::vector<VariableName> _vars;
106 std::vector<TensorShape> _intmd_shapes;
108 std::vector<TensorShape> _base_shapes;
110 std::vector<std::size_t> _offsets;
112 std::vector<IStructure> _istrs;
117 std::size_t _group_idx = 0;
119 std::size_t _start = 0;
121 std::size_t _end = 0;
Definition DiagnosticsInterface.h:31
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
bool operator==(const AxisLayout &, const AxisLayout &)
comparison operator
std::string VariableName
Definition types.h:75
Definition AxisLayout.h:34
std::pair< std::size_t, std::size_t > group_offsets(std::size_t) const
Starting and ending offsets of a variable group.
std::vector< VariableName > vars() const
Accessor for variable names.
AxisLayout(const std::vector< std::vector< VariableName > > &vars, std::vector< TensorShape > intmd_shapes, std::vector< TensorShape > base_shapes, std::vector< IStructure > istrs)
Construct a new Axis Layout object.
std::size_t ngroup() const
Number of variable groups.
bool is_view() const
Whether this is a view into a parent layout.
Definition AxisLayout.h:83
const VariableName & var(std::size_t) const
Accessor for variable name.
AxisLayout(const AxisLayout *parent, std::size_t group_idx, std::size_t start, std::size_t end, std::vector< std::size_t > offsets={})
Construct a new Axis Layout object by viewing into a parent layout.
const TensorShape & base_sizes(std::size_t) const
Accessor for variable base shape.
void update_intmd_shapes(const std::vector< TensorShape > &)
Update intermediate shapes.
AxisLayout group(std::size_t) const
Contiguous view of the variable group.
std::vector< Size > storage_sizes(bool include_intmd) const
Storage sizes of variables.
IStructure
Enum for the structure represented by intermediate dimensions (if any).
Definition AxisLayout.h:39
@ DENSE
All intermediate dimensions are grouped into base dimensions.
Definition AxisLayout.h:40
@ BLOCK
Intermediate dimensions represent blocks of variables.
Definition AxisLayout.h:41
IStructure istr(std::size_t=0) const
Variable group IStructure.
std::size_t nvar() const
Number of variables.
const TensorShape & intmd_sizes(std::size_t) const
Accessor for variable intermediate shape.
AxisLayout view() const
Contiguous view of the entire layout.