LCOV - code coverage report
Current view: top level - jit - TraceableTensorShape.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 80.0 % 45 36
Test Date: 2025-10-02 16:03:03 Functions: 80.0 % 10 8

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <ATen/ops/stack.h>
      26              : 
      27              : #include "neml2/jit/TraceableTensorShape.h"
      28              : #include "neml2/misc/assertions.h"
      29              : 
      30              : namespace neml2
      31              : {
      32       107287 : TraceableTensorShape::TraceableTensorShape(const TensorShape & shape)
                        // Builds this traceable shape from `shape`: each entry of `shape` is appended to
                        // this object, in iteration order.
                        // NOTE(review): emplace_back is not defined in this file; presumably it is provided
                        // by the container base declared in TraceableTensorShape.h -- confirm.
      33              : {
      34       160806 :   for (const auto & size : shape)
      35        53519 :     emplace_back(size);
      36       107287 : }
      37              : 
      38      2744575 : TraceableTensorShape::TraceableTensorShape(TensorShapeRef shape)
                        // Builds this traceable shape from `shape`: each entry of `shape` is appended to
                        // this object, in iteration order.
                        // NOTE(review): `shape` is taken by value, presumably because TensorShapeRef is a
                        // cheap non-owning view -- confirm against its declaration.
      39              : {
      40      3668948 :   for (const auto & size : shape)
      41       924373 :     emplace_back(size);
      42      2744575 : }
      43              : 
      44            0 : TraceableTensorShape::TraceableTensorShape(Size shape)
                        // Builds a shape from the single value `shape` by delegating to the
                        // TensorShapeRef constructor with a one-element braced list.
                        // NOTE(review): if TensorShapeRef is a non-owning view, the braced list only has
                        // to outlive the delegated constructor call, which copies the entry -- confirm.
      45            0 :   : TraceableTensorShape(TensorShapeRef({shape}))
      46              : {
      47            0 : }
      48              : 
      49         1329 : TraceableTensorShape::TraceableTensorShape(const ATensor & shape)
                        // Builds this traceable shape from the tensor `shape`: for each index i along
                        // dimension 0, the indexed element shape[i] is appended to this object.
                        // `shape` is checked, through neml_assert_dbg, to be 1D and of scalar type kInt64.
                        // NOTE(review): neml_assert_dbg is presumably only active in debug builds, in
                        // which case these preconditions are unchecked otherwise -- confirm in
                        // neml2/misc/assertions.h.
      50              : {
      51         1329 :   neml_assert_dbg(shape.dim() == 1, "TraceableTensorShape: shape must be 1D");
      52         1329 :   neml_assert_dbg(shape.scalar_type() == kInt64,
      53              :                   "TraceableTensorShape: shape must be of type int64");
                          // Each entry is appended as the result of indexing the tensor, not as a plain
                          // integer read out of it.
      54         2783 :   for (Size i = 0; i < shape.size(0); i++)
      55         1454 :     emplace_back(shape[i]);
      56         1329 : }
      57              : 
      58              : TraceableTensorShape
                        // Returns a new TraceableTensorShape built from the entries in the half-open
                        // range [start, end) of this shape. A negative `start` or `end` is offset by
                        // size() before use, so it counts from the back.
                        // No range check is performed here on the adjusted indices.
                        // NOTE(review): callers are presumably expected to pass indices that land in
                        // [0, size()] after adjustment, with start <= end -- confirm.
      59       674813 : TraceableTensorShape::slice(Size start, Size end) const
      60              : {
                          // size() is converted to Size so that the addition is done in the index type.
      61       674813 :   if (start < 0)
      62            0 :     start += Size(size());
      63       674813 :   if (end < 0)
      64          116 :     end += Size(size());
      65              : 
                          // The parameter `end` hides the member function end() in this scope, so the
                          // upper iterator is formed as begin() + end.
      66       674813 :   return TraceableTensorShape(begin() + start, begin() + end);
      67              : }
      68              : 
      69              : TraceableTensorShape
                        // Returns a new TraceableTensorShape built from the entries starting at index N
                        // and running to the end of this shape. A negative N is offset by size() before
                        // use, so it counts from the back.
                        // No range check is performed here on the adjusted index.
      70            0 : TraceableTensorShape::slice(Size N) const
      71              : {
      72            0 :   if (N < 0)
      73            0 :     N += Size(size());
      74            0 :   return TraceableTensorShape(begin() + N, end());
      75              : }
      76              : 
      77              : TensorShape
                        // Returns a TensorShape holding, in order, the result of calling concrete() on
                        // every entry of this shape. The result is returned by value.
      78      5429079 : TraceableTensorShape::concrete() const
      79              : {
      80      5429079 :   TensorShape s;
      81      7277095 :   for (const auto & size : *this)
      82      1848016 :     s.push_back(size.concrete());
      83      5429079 :   return s;
      84            0 : }
      85              : 
      86              : ATensor
                        // Returns this shape as a single tensor: each entry is converted with its own
                        // as_tensor() and the per-entry tensors are combined with at::stack.
                        // An empty shape yields a default-constructed ATensor instead.
      87        20150 : TraceableTensorShape::as_tensor() const
      88              : {
                          // Early return for the empty shape, so at::stack is never called on an empty
                          // list.
      89        20150 :   if (empty())
      90        16275 :     return ATensor();
      91              : 
                          // One slot per entry, filled in the same order as the entries of this shape.
      92         3875 :   auto sizes = std::vector<ATensor>(size());
      93         7942 :   for (std::size_t i = 0; i < size(); i++)
      94         4067 :     sizes[i] = at(i).as_tensor();
                          // Called with at::stack's default arguments.
      95         3875 :   return at::stack(sizes);
      96         3875 : }
      97              : 
      98              : bool
                        // Equality of two traceable shapes is decided on their concrete values: both
                        // operands are converted with concrete() and the resulting TensorShape objects
                        // are compared. Each call therefore builds two temporary TensorShape objects.
      99      2651813 : operator==(const TraceableTensorShape & lhs, const TraceableTensorShape & rhs)
     100              : {
     101      2651813 :   return lhs.concrete() == rhs.concrete();
     102              : }
     103              : 
     104              : bool
                        // Inequality is defined as the negation of operator== for the same operands.
     105       106180 : operator!=(const TraceableTensorShape & lhs, const TraceableTensorShape & rhs)
     106              : {
     107       106180 :   return !(lhs == rhs);
     108              : }
     109              : } // namespace neml2
        

Generated by: LCOV version 2.0-1