27 #include "neml2/tensors/Tensor.h"
32 template <
typename F,
typename T1,
typename T2>
36 return Tensor(std::forward<F>(f)(a, b.batch_unsqueeze(-1)), b.batch_sizes());
40 template <
typename F,
typename T1,
typename T2>
44 return Tensor(std::forward<F>(f)(a.batch_unsqueeze(-1), b), a.batch_sizes());
49 template <
typename F,
typename T1,
typename T2>
53 return Tensor(std::forward<F>(f)(a.batch_unsqueeze(-1), b.batch_unsqueeze(-2)), a.batch_dim() - 1);
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:438
Definition DiagnosticsInterface.cxx:30
Tensor list_derivative_outer_product_b(F &&f, const T1 &a, const T2 &b)
outer product on lists, where the second input is a list tensor
Definition list_tensors.h:42
Tensor list_derivative_outer_product_ab(F &&f, const T1 &a, const T2 &b)
outer product on lists where both inputs are list tensors
Definition list_tensors.h:51
Tensor list_derivative_outer_product_a(F &&f, const T1 &a, const T2 &b)
outer product on lists, where the first input is a list tensor
Definition list_tensors.h:34