LCOV - code coverage report
Current view: top level - tensors - Tensor.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 95.1 % 82 78
Test Date: 2025-06-29 01:25:44 Functions: 100.0 % 17 17

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <torch/csrc/autograd/variable.h>
      26              : #include "neml2/tensors/Tensor.h"
      27              : #include "neml2/tensors/shape_utils.h"
      28              : #include "neml2/misc/assertions.h"
      29              : #include "neml2/jit/types.h"
      30              : 
      31              : namespace neml2
      32              : {
      33              : namespace utils
      34              : {
      35              : ATensor
      36         3779 : pad_prepend(const ATensor & s, Size dim, Size pad)
      37              : {
      38         3779 :   neml_assert_dbg(s.defined(), "pad_prepend: shape must be defined");
      39         3779 :   neml_assert_dbg(s.scalar_type() == kInt64, "pad_prepend: shape must be of type int64");
      40         3779 :   neml_assert_dbg(s.dim() == 1, "pad_prepend: shape must be 1D");
      41        22674 :   return at::cat({at::full({dim - s.size(0)}, pad, s.options()), s});
      42        11337 : }
      43              : 
      44              : TraceableTensorShape
      45         6674 : broadcast_batch_sizes(const std::vector<Tensor> & tensors)
      46              : {
      47         6674 :   Size dim = 0;
      48         6674 :   auto shapes = std::vector<ATensor>{};
      49        26626 :   for (const auto & t : tensors)
      50        19952 :     if (t.defined())
      51              :     {
      52        19937 :       dim = t.batch_dim() > dim ? t.batch_dim() : dim;
      53        19937 :       const auto shape = t.batch_sizes().as_tensor();
      54        19937 :       if (shape.defined())
      55         3779 :         shapes.push_back(shape);
      56        19937 :     }
      57         6674 :   if (shapes.empty())
      58         5393 :     return TraceableTensorShape(TensorShape{});
      59              :   /// Pre-pad ones to the shapes
      60         5060 :   for (auto & s : shapes)
      61         3779 :     s = pad_prepend(s, dim, 1);
      62              :   /// Braodcast
      63         1281 :   const auto all_shapes = at::stack(shapes);
      64         1281 :   return std::get<0>(at::max(all_shapes, 0));
      65         6674 : }
      66              : } // namespace utils
      67              : 
      68        29731 : Tensor::Tensor(const ATensor & tensor, Size batch_dim)
      69        29731 :   : TensorBase(tensor, batch_dim)
      70              : {
      71        29731 : }
      72              : 
      73      2529846 : Tensor::Tensor(const ATensor & tensor, const TraceableTensorShape & batch_shape)
      74      2529846 :   : TensorBase(tensor, batch_shape)
      75              : {
      76      2529846 : }
      77              : 
      78              : Tensor
      79          839 : Tensor::create(const TensorDataContainer & data, const TensorOptions & options)
      80              : {
      81          839 :   return create(data, 0, options);
      82              : }
      83              : 
      84              : Tensor
      85          846 : Tensor::create(const TensorDataContainer & data, Size batch_dim, const TensorOptions & options)
      86              : {
      87         1692 :   return Tensor(torch::autograd::make_variable(data.convert_to_tensor(options.requires_grad(false)),
      88          846 :                                                options.requires_grad()),
      89         1692 :                 batch_dim);
      90              : }
      91              : 
      92              : Tensor
      93            2 : Tensor::empty(TensorShapeRef base_shape, const TensorOptions & options)
      94              : {
      95            2 :   return Tensor(at::empty(base_shape, options), 0);
      96              : }
      97              : 
      98              : Tensor
      99          171 : Tensor::empty(const TraceableTensorShape & batch_shape,
     100              :               TensorShapeRef base_shape,
     101              :               const TensorOptions & options)
     102              : {
     103              :   // Record batch shape
     104          358 :   for (Size i = 0; i < (Size)batch_shape.size(); ++i)
     105          187 :     if (const auto * const si = batch_shape[i].traceable())
     106            0 :       jit::tracer::ArgumentStash::stashIntArrayRefElem(
     107            0 :           "size", batch_shape.size() + base_shape.size(), i, *si);
     108              : 
     109          342 :   return Tensor(at::empty(utils::add_shapes(batch_shape.concrete(), base_shape), options),
     110          342 :                 batch_shape);
     111              : }
     112              : 
     113              : Tensor
     114           51 : Tensor::zeros(TensorShapeRef base_shape, const TensorOptions & options)
     115              : {
     116           51 :   return Tensor(at::zeros(base_shape, options), 0);
     117              : }
     118              : 
     119              : Tensor
     120        18529 : Tensor::zeros(const TraceableTensorShape & batch_shape,
     121              :               TensorShapeRef base_shape,
     122              :               const TensorOptions & options)
     123              : {
     124              :   // Record batch shape
     125        18959 :   for (Size i = 0; i < (Size)batch_shape.size(); ++i)
     126          430 :     if (const auto * const si = batch_shape[i].traceable())
     127           27 :       jit::tracer::ArgumentStash::stashIntArrayRefElem(
     128            9 :           "size", batch_shape.size() + base_shape.size(), i, *si);
     129              : 
     130        37058 :   return Tensor(at::zeros(utils::add_shapes(batch_shape.concrete(), base_shape), options),
     131        37058 :                 batch_shape);
     132              : }
     133              : 
     134              : Tensor
     135          135 : Tensor::ones(TensorShapeRef base_shape, const TensorOptions & options)
     136              : {
     137          135 :   return Tensor(at::ones(base_shape, options), 0);
     138              : }
     139              : 
     140              : Tensor
     141          105 : Tensor::ones(const TraceableTensorShape & batch_shape,
     142              :              TensorShapeRef base_shape,
     143              :              const TensorOptions & options)
     144              : {
     145              :   // Record batch shape
     146          240 :   for (Size i = 0; i < (Size)batch_shape.size(); ++i)
     147          135 :     if (const auto * const si = batch_shape[i].traceable())
     148           27 :       jit::tracer::ArgumentStash::stashIntArrayRefElem(
     149            9 :           "size", batch_shape.size() + base_shape.size(), i, *si);
     150              : 
     151          210 :   return Tensor(at::ones(utils::add_shapes(batch_shape.concrete(), base_shape), options),
     152          210 :                 batch_shape);
     153              : }
     154              : 
     155              : Tensor
     156         1669 : Tensor::full(TensorShapeRef base_shape, const CScalar & init, const TensorOptions & options)
     157              : {
     158         1669 :   return Tensor(at::full(base_shape, init, options), 0);
     159              : }
     160              : 
     161              : Tensor
     162          280 : Tensor::full(const TraceableTensorShape & batch_shape,
     163              :              TensorShapeRef base_shape,
     164              :              const CScalar & init,
     165              :              const TensorOptions & options)
     166              : {
     167              :   // Record batch shape
     168          863 :   for (Size i = 0; i < (Size)batch_shape.size(); ++i)
     169          583 :     if (const auto * const si = batch_shape[i].traceable())
     170            0 :       jit::tracer::ArgumentStash::stashIntArrayRefElem(
     171            0 :           "size", batch_shape.size() + base_shape.size(), i, *si);
     172              : 
     173          560 :   return Tensor(at::full(utils::add_shapes(batch_shape.concrete(), base_shape), init, options),
     174          560 :                 batch_shape);
     175              : }
     176              : 
     177              : Tensor
     178          193 : Tensor::identity(Size n, const TensorOptions & options)
     179              : {
     180          193 :   return Tensor(at::eye(n, options), 0);
     181              : }
     182              : 
     183              : Tensor
     184            3 : Tensor::identity(const TraceableTensorShape & batch_shape, Size n, const TensorOptions & options)
     185              : {
     186            3 :   return identity(n, options).batch_expand_copy(batch_shape);
     187              : }
     188              : 
     189              : Tensor
     190            4 : Tensor::base_unsqueeze_to(Size n) const
     191              : {
     192            4 :   neml_assert_dbg(n >= base_dim(),
     193              :                   "base_unsqueeze_to: n (",
     194              :                   n,
     195              :                   ") must be greater than or equal to base_dim (",
     196            4 :                   base_dim(),
     197              :                   ").");
     198            8 :   indexing::TensorIndices net{indexing::Ellipsis};
     199            4 :   net.insert(net.end(), n - base_dim(), indexing::None);
     200           12 :   return Tensor(index(net), batch_sizes());
     201            8 : }
     202              : } // end namespace neml2
        

Generated by: LCOV version 2.0-1