27#include "neml2/tensors/Tensor.h"
57 std::vector<CacheKey> _cached_keys;
60 std::vector<Tensor> _cached_tensors;
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
constexpr auto kFloat64
Definition types.h:54
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ScalarType Dtype
Definition types.h:67
const Tensor & operator()(const TensorOptions &)
Get the tensor with the given tensor options. If the tensor does not exist in the cache,...
TensorCache(std::function< Tensor(const TensorOptions &)> &&)
Construct a new Tensor Cache object.