27#include "neml2/tensors/Tensor.h"
32template <
typename F,
typename T1,
typename T2>
36 return Tensor(
f(a, b.batch_unsqueeze(-1)), b.batch_sizes());
40template <
typename F,
typename T1,
typename T2>
48template <
typename F,
typename T1,
typename T2>
52 return Tensor(
f(a.batch_unsqueeze(-1), b.batch_unsqueeze(-2)), a.batch_dim() - 1)
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:404
Definition CrossRef.cxx:31
Tensor list_derivative_outer_product_b(F &&f, const T1 &a, const T2 &b)
Outer product on lists, where the second input is a list tensor.
Definition list_tensors.h:42
Tensor list_derivative_outer_product_ab(F &&f, const T1 &a, const T2 &b)
Outer product on lists, where both inputs are list tensors.
Definition list_tensors.h:50
Tensor list_derivative_outer_product_a(F &&f, const T1 &a, const T2 &b)
Outer product on lists, where the first input is a list tensor.
Definition list_tensors.h:34