LCOV - code coverage report
Current view: top level - jit - TraceableTensorShape.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 80.0 % 45 36
Test Date: 2025-06-29 01:25:44 Functions: 80.0 % 10 8

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <ATen/ops/stack.h>
      26              : 
      27              : #include "neml2/jit/TraceableTensorShape.h"
      28              : #include "neml2/misc/assertions.h"
      29              : 
      30              : namespace neml2
      31              : {
      32        99005 : TraceableTensorShape::TraceableTensorShape(const TensorShape & shape)
      33              : {
      34       149278 :   for (const auto & size : shape)
      35        50273 :     emplace_back(size);
      36        99005 : }
      37              : 
      38      2784075 : TraceableTensorShape::TraceableTensorShape(TensorShapeRef shape)
      39              : {
      40      3672963 :   for (const auto & size : shape)
      41       888888 :     emplace_back(size);
      42      2784075 : }
      43              : 
      44            0 : TraceableTensorShape::TraceableTensorShape(Size shape)
      45            0 :   : TraceableTensorShape(TensorShapeRef({shape}))
      46              : {
      47            0 : }
      48              : 
      49         1281 : TraceableTensorShape::TraceableTensorShape(const ATensor & shape)
      50              : {
      51         1281 :   neml_assert_dbg(shape.dim() == 1, "TraceableTensorShape: shape must be 1D");
      52         1281 :   neml_assert_dbg(shape.scalar_type() == kInt64,
      53              :                   "TraceableTensorShape: shape must be of type int64");
      54         2603 :   for (Size i = 0; i < shape.size(0); i++)
      55         1322 :     emplace_back(shape[i]);
      56         1281 : }
      57              : 
      58              : TraceableTensorShape
      59       632888 : TraceableTensorShape::slice(Size start, Size end) const
      60              : {
      61       632888 :   if (start < 0)
      62            0 :     start += Size(size());
      63       632888 :   if (end < 0)
      64           60 :     end += Size(size());
      65              : 
      66       632888 :   return TraceableTensorShape(begin() + start, begin() + end);
      67              : }
      68              : 
      69              : TraceableTensorShape
      70            0 : TraceableTensorShape::slice(Size N) const
      71              : {
      72            0 :   if (N < 0)
      73            0 :     N += Size(size());
      74            0 :   return TraceableTensorShape(begin() + N, end());
      75              : }
      76              : 
      77              : TensorShape
      78      5509913 : TraceableTensorShape::concrete() const
      79              : {
      80      5509913 :   TensorShape s;
      81      7286543 :   for (const auto & size : *this)
      82      1776630 :     s.push_back(size.concrete());
      83      5509913 :   return s;
      84            0 : }
      85              : 
      86              : ATensor
      87        19937 : TraceableTensorShape::as_tensor() const
      88              : {
      89        19937 :   if (empty())
      90        16158 :     return ATensor();
      91              : 
      92         3779 :   auto sizes = std::vector<ATensor>(size());
      93         7606 :   for (std::size_t i = 0; i < size(); i++)
      94         3827 :     sizes[i] = at(i).as_tensor();
      95         3779 :   return at::stack(sizes);
      96         3779 : }
      97              : 
      98              : bool
      99      2696796 : operator==(const TraceableTensorShape & lhs, const TraceableTensorShape & rhs)
     100              : {
     101      2696796 :   return lhs.concrete() == rhs.concrete();
     102              : }
     103              : 
     104              : bool
     105        97938 : operator!=(const TraceableTensorShape & lhs, const TraceableTensorShape & rhs)
     106              : {
     107        97938 :   return !(lhs == rhs);
     108              : }
     109              : } // namespace neml2
        

Generated by: LCOV version 2.0-1