29 #include "neml2/misc/types.h"
30 #include "neml2/base/LabeledAxisAccessor.h"
31 #include "neml2/tensors/indexing.h"
126 const std::vector<const LabeledAxis *> &
subaxes()
const;
136 const std::vector<std::pair<Size, Size>> &
subaxis_slices()
const;
155 void ensure_setup_dbg()
const;
167 std::map<std::string, std::pair<TensorShape, TensorShape>> _variables;
170 std::map<std::string, std::shared_ptr<LabeledAxis>> _subaxes;
180 std::map<LabeledAxisAccessor, std::size_t> _variable_to_id_map;
182 std::vector<LabeledAxisAccessor> _id_to_variable_map;
184 std::vector<Size> _id_to_variable_size_map;
186 std::vector<std::pair<Size, Size>> _id_to_variable_slice_map;
188 std::vector<TensorShape> _id_to_intmd_sizes_map;
190 std::vector<TensorShape> _id_to_base_sizes_map;
201 std::vector<const LabeledAxis *> _sorted_subaxes;
203 std::map<std::string, std::size_t> _subaxis_to_id_map;
205 std::vector<std::string> _id_to_subaxis_map;
207 std::vector<Size> _id_to_subaxis_size_map;
209 std::vector<std::pair<Size, Size>> _id_to_subaxis_slice_map;
213 std::ostream &
operator<<(std::ostream & os,
const LabeledAxis & axis);
215 bool operator==(
const LabeledAxis & a,
const LabeledAxis & b);
217 bool operator!=(
const LabeledAxis & a,
const LabeledAxis & b);
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
A labeled axis is used to associate the layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:48
bool equals(const LabeledAxis &other) const
Check to see if two axes are equivalent.
Definition LabeledAxis.cxx:537
std::size_t nsubaxis() const
Definition LabeledAxis.cxx:414
LabeledAxis()=default
Empty constructor.
bool is_setup() const
Whether the axis has been set up.
Definition LabeledAxis.h:60
bool has_variable(const LabeledAxisAccessor &name) const
Check the existence of a variable by its name.
Definition LabeledAxis.cxx:301
const std::vector< TensorShape > & variable_base_sizes() const
Get the variable base shapes (in assembly order)
Definition LabeledAxis.cxx:407
LabeledAxisAccessor disqualify(const LabeledAxisAccessor &accessor) const
Return the disqualified name of an item (i.e., with the prefix removed)
Definition LabeledAxis.cxx:75
const std::vector< std::pair< Size, Size > > & variable_slices() const
Get the variable slicing indices (in assembly order)
Definition LabeledAxis.cxx:356
const std::vector< Size > & variable_sizes() const
Get the variable storage sizes (in assembly order)
Definition LabeledAxis.cxx:370
void add_variable(const LabeledAxisAccessor &name, TensorShapeRef intmd_sizes, TensorShapeRef base_sizes)
Add a variable with known intermediate and base shapes (which determine its storage size).
Definition LabeledAxis.cxx:97
friend std::ostream & operator<<(std::ostream &os, const LabeledAxis &axis)
Definition LabeledAxis.cxx:567
std::size_t nvariable() const
Definition LabeledAxis.cxx:287
const std::pair< Size, Size > & variable_slice(const LabeledAxisAccessor &name) const
Get the slicing indices of a variable by name.
Definition LabeledAxis.cxx:363
void set_intmd_sizes(const LabeledAxisAccessor &name, TensorShapeRef shape)
Set intermediate shape of a variable.
Definition LabeledAxis.cxx:117
std::vector< LabeledAxisAccessor > variable_names_unsrt() const
Get variable names in unsorted order.
Definition LabeledAxis.cxx:333
const std::vector< std::pair< Size, Size > > & subaxis_slices() const
Get the sub-axis slicing indices (in assembly order)
Definition LabeledAxis.cxx:492
std::vector< std::string > subaxis_names_unsrt() const
Get sub-axis names in unsorted order.
Definition LabeledAxis.cxx:475
LabeledAxis & add_subaxis(const std::string &name)
Add a sub-axis.
Definition LabeledAxis.cxx:87
const std::vector< TensorShape > & variable_intmd_sizes() const
Get the variable intermediate shapes (in assembly order)
Definition LabeledAxis.cxx:400
bool has_subaxis(const LabeledAxisAccessor &name) const
Check the existence of a sub-axis by its name.
Definition LabeledAxis.cxx:420
std::pair< Size, Size > subaxis_slice(const LabeledAxisAccessor &name) const
Get the slicing indices of a sub-axis by name.
Definition LabeledAxis.cxx:499
indexing::Slice slice(const LabeledAxisAccessor &name) const
Get the slicing indices of a variable or a local sub-axis.
Definition LabeledAxis.cxx:264
void setup_layout()
Setup the layout of all items recursively.
Definition LabeledAxis.cxx:129
Size subaxis_size(const LabeledAxisAccessor &name) const
Get the storage size of a sub-axis by name.
Definition LabeledAxis.cxx:524
const std::vector< LabeledAxisAccessor > & variable_names() const
Get the variable names.
Definition LabeledAxis.cxx:349
std::size_t variable_id(const LabeledAxisAccessor &name) const
Get the assembly ID of a variable.
Definition LabeledAxis.cxx:319
Size variable_size(const LabeledAxisAccessor &name) const
Get the storage size of a variable by name.
Definition LabeledAxis.cxx:377
void clear()
De-initialize the axis.
Definition LabeledAxis.cxx:37
Size size() const
Get the storage size of the entire axis.
Definition LabeledAxis.cxx:197
const std::vector< Size > & subaxis_sizes() const
Get the sub-axis storage sizes (in assembly order)
Definition LabeledAxis.cxx:517
std::size_t subaxis_id(const std::string &name) const
Get the assembly ID of a sub-axis.
Definition LabeledAxis.cxx:433
LabeledAxisAccessor qualify(const LabeledAxisAccessor &accessor) const
Return the fully qualified name of an item (useful when this axis is a sub-axis)
Definition LabeledAxis.cxx:69
const std::vector< std::string > & subaxis_names() const
Get the sub-axis names.
Definition LabeledAxis.cxx:485
const LabeledAxis & subaxis(const LabeledAxisAccessor &name) const
Get a sub-axis by name.
Definition LabeledAxis.cxx:453
const std::vector< const LabeledAxis * > & subaxes() const
Get the sub-axes (in assembly order)
Definition LabeledAxis.cxx:446
Definition DiagnosticsInterface.cxx:30
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:597
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:603
int64_t Size
Definition types.h:65
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67