27#include "neml2/models/Interpolation.h"
28#include "neml2/tensors/Scalar.h"
93 template <
typename T2>
94 static T2
mask(
const T2 & in,
const Scalar & m);
97 void set_value(
bool out,
bool dout_din,
bool d2out_din2)
override;
107template <
typename T2>
114 return T2(in.batch_expand_as(m).index({m})).batch_reshape(B);
The base class for interpolated variable.
Definition Interpolation.h:51
Linearly interpolate the parameter along a single axis.
Definition LinearInterpolation.h:77
const Scalar & _X
The abscissa values of the interpolant.
Definition LinearInterpolation.h:100
static T2 mask(const T2 &in, const Scalar &m)
Apply the mask tensor m on the input in.
Definition LinearInterpolation.h:109
LinearInterpolation(const OptionSet &options)
Definition LinearInterpolation.cxx:53
const Variable< Scalar > & _x
Argument of interpolation.
Definition LinearInterpolation.h:103
static OptionSet expected_options()
Definition LinearInterpolation.cxx:36
void set_value(bool out, bool dout_din, bool d2out_din2) override
The map between input -> output, and optionally its derivatives.
Definition LinearInterpolation.cxx:62
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:51
Scalar.
Definition Scalar.h:38
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:182
Concrete definition of a variable.
Definition VariableStore.h:41
Definition DiagnosticsInterface.cxx:30
TraceableTensorShape slice(Size start, Size end) const
Slice the shape, semantically the same as ArrayRef::slice, but traceable.
Definition TraceableTensorShape.cxx:59