#include "neml2/models/map_types.h"
c10::Device Device
Definition types.h:66
std::map<LabeledAxisAccessor, ValueMap> DerivMap
Definition map_types_fwd.h:34
int64_t Size
Definition types.h:69
DerivMap derivmap_cat_reduce(std::vector<DerivMap> &&results, Size batch_dim)
Concatenate the corresponding tensors of several DerivMaps along the batch dimension batch_dim, producing a single DerivMap.
Definition derivmap_helpers.cxx:32
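To make the reduction concrete, here is a minimal sketch in standard C++ that does not use neml2 or libtorch. It assumes a hypothetical 2D `Tensor` stand-in (a vector of rows), plain string keys in place of `LabeledAxisAccessor`, a hypothetical helper `cat2`, and a hypothetical name `derivmap_cat_reduce_sketch`. Only the two-level nesting of `DerivMap` mirrors the real type. The sketch further assumes that every partial result carries the same key set, and it supports dimensions 0 and 1 only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a 2D "tensor" stored as a vector of rows, and string
// keys in place of neml2's LabeledAxisAccessor. Only the two-level nesting
// mirrors the real DerivMap.
using Tensor = std::vector<std::vector<double>>;
using ValueMap = std::map<std::string, Tensor>;
using DerivMap = std::map<std::string, ValueMap>;
using Size = std::int64_t;

// Hypothetical helper: concatenate two 2D tensors along `dim`
// (0 appends rows, 1 appends columns row by row).
Tensor cat2(const Tensor & a, const Tensor & b, Size dim)
{
  Tensor out = a;
  if (dim == 0)
    out.insert(out.end(), b.begin(), b.end());
  else if (dim == 1)
  {
    if (a.size() != b.size())
      throw std::invalid_argument("cat along dim 1 needs equal row counts");
    for (std::size_t i = 0; i < out.size(); ++i)
      out[i].insert(out[i].end(), b[i].begin(), b[i].end());
  }
  else
    throw std::invalid_argument("this sketch supports dim 0 or 1 only");
  return out;
}

// Sketch of a cat-reduce: start from the first partial result, then append the
// tensor stored under each (outer, inner) key pair of every later result along
// batch_dim. .at() throws std::out_of_range if the key sets differ.
DerivMap derivmap_cat_reduce_sketch(std::vector<DerivMap> && results, Size batch_dim)
{
  if (results.empty())
    return {};
  DerivMap out = std::move(results.front());
  for (std::size_t r = 1; r < results.size(); ++r)
    for (const auto & outer : results[r])
      for (const auto & inner : outer.second)
      {
        Tensor & acc = out.at(outer.first).at(inner.first);
        acc = cat2(acc, inner.second, batch_dim);
      }
  return out;
}
```

Taking `results` by rvalue reference lets the first partial result be moved into the output instead of copied; only the later pieces are appended.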
DerivMap derivmap_no_operation(DerivMap &&x)
No operation: returns the DerivMap unchanged.
Definition derivmap_helpers.cxx:74
DerivMap derivmap_move_device(DerivMap &&x, Device device)
Move all tensors in a DerivMap to a device.
Definition derivmap_helpers.cxx:64
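A similarly hedged sketch of the device move, again in standard C++ only. `Device` here is a hypothetical two-value enum rather than `c10::Device`, the stand-in `Tensor::to` merely copies the data and updates a device tag (a real tensor library would transfer the storage), and `derivmap_move_device_sketch` is a hypothetical name. The function takes the map by rvalue reference, rewrites every entry in place, and moves the same container back out, so the map's nodes are not copied.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for c10::Device.
enum class Device { CPU, CUDA };

// Hypothetical stand-in tensor that records which device it lives on.
struct Tensor
{
  std::vector<double> data;
  Device device = Device::CPU;

  // Return a copy tagged with device `d`; a real library would transfer storage.
  Tensor to(Device d) const
  {
    Tensor out = *this;
    out.device = d;
    return out;
  }
};

using ValueMap = std::map<std::string, Tensor>;
using DerivMap = std::map<std::string, ValueMap>;

// Sketch: take ownership of the map, replace every tensor with its copy on
// `device`, then move the same container back out to the caller.
DerivMap derivmap_move_device_sketch(DerivMap && x, Device device)
{
  for (auto & outer : x)
    for (auto & inner : outer.second)
      inner.second = inner.second.to(device);
  // A named rvalue-reference parameter is an lvalue, so std::move is needed
  // to move (not copy) it into the return value before C++20.
  return std::move(x);
}
```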