27#include "neml2/models/Interpolation.h"
28#include "neml2/tensors/Scalar.h"
/// The map between input -> output, and optionally its derivatives.
///
/// Declared with `override`, so this replaces a virtual declared in a base class.
/// The definition is located in LinearInterpolation.cxx.
///
/// NOTE(review): the three flags presumably request the output value, its first derivative and
/// its second derivative with respect to the input, respectively -- confirm against the
/// definition before relying on this.
84 void set_value(
bool out,
bool dout_din,
bool d2out_din2)
override;
/// Declaration of the const member template `mask`.
///
/// Takes a value `in` of the deduced type `T2` and a `Scalar` `m`, and returns a `T2`.
/// The out-of-class definition indexes the batch-expanded `in` with `m`.
///
/// NOTE(review): `m` is presumably a boolean mask selecting batch entries of `in` -- confirm
/// against the callers.
97 template <
typename T2>
98 T2 mask(
const T2 & in,
const Scalar & m)
const;
// Out-of-class definition of the member template LinearInterpolation<T>::mask.
//
// NOTE(review): this listing is incomplete. The embedded source line numbers jump from 102 to
// 104 and from 104 to 109, so the return type, the braces and the definition of `B` are not
// visible here. `B` is presumably derived from the batch sizes of `m`, but that cannot be
// confirmed from the visible lines -- check the original header.
102template <
typename T2>
104LinearInterpolation<T>::mask(
const T2 & in,
const Scalar & m)
const
// Expand `in` to the batch shape of `m`, index it with `m`, wrap the result as a T2, and
// reshape its batch dimensions to `B`.
109 return T2(in.batch_expand_as(m).index({m})).batch_reshape(B);
Interpolation(const OptionSet &options)
Definition Interpolation.cxx:57
LinearInterpolation(const OptionSet &options)
Definition LinearInterpolation.cxx:44
static OptionSet expected_options()
Definition LinearInterpolation.cxx:35
void set_value(bool out, bool dout_din, bool d2out_din2) override
The map between input -> output, and optionally its derivatives.
Definition LinearInterpolation.cxx:51
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:52
Scalar.
Definition Scalar.h:38
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:178
Definition DiagnosticsInterface.cxx:30
TraceableTensorShape slice(Size start, Size end) const
Slice the shape, semantically the same as ArrayRef::slice, but traceable.
Definition TraceableTensorShape.cxx:59