NEML2 2.0.0
Loading...
Searching...
No Matches
neml2::utils Namespace Reference

Functions

std::stringstream & operator>> (std::stringstream &in, torch::Tensor &)
 This is a dummy to prevent compilers from whining about not knowing how to >> torch::Tensor.
 
std::string join (const std::vector< std::string > &strs, const std::string &delim)
 
std::vector< std::string > split (const std::string &str, const std::string &delims)
 
std::string trim (const std::string &str, const std::string &white_space)
 
bool start_with (std::string_view str, std::string_view prefix)
 
bool end_with (std::string_view str, std::string_view suffix)
 
template<>
void parse_ (bool &val, const std::string &raw_str)
 
template<>
void parse_vector_ (std::vector< bool > &vals, const std::string &raw_str)
 
template<>
void parse_ (VariableName &val, const std::string &raw_str)
 
template<>
void parse_ (TensorShape &val, const std::string &raw_str)
 
std::string demangle (const char *name)
 Demangle a piece of cxx abi type information.
 
TraceableTensorShape extract_batch_sizes (const torch::Tensor &tensor, Size batch_dim)
 Extract the batch shape of a tensor given the batch dimension. The extracted batch shape will be traceable.
 
Size storage_size (TensorShapeRef shape)
 The flattened storage size of a tensor with given shape.
 
TensorShape pad_prepend (TensorShapeRef s, Size dim, Size pad=1)
 Pad shape s to dimension dim by prepending sizes of pad.
 
torch::Tensor pad_prepend (const torch::Tensor &s, Size dim, Size pad)
 
std::string indentation (int level, int indent)
 
TraceableTensorShape broadcast_batch_sizes (const std::vector< Tensor > &tensors)
 Find the broadcast batch shape of all the tensors. The returned batch shape will be traceable.
 
torch::Dtype same_dtype (const std::vector< Tensor > &tensors)
 Make sure all tensors have the same dtype and return the common dtype.
 
torch::Device same_device (const std::vector< Tensor > &tensors)
 Make sure all tensors have the same device and return the common device.
 
template<typename T >
void parse_ (T &val, const std::string &raw_str)
 
template<typename T >
T parse (const std::string &raw_str)
 
template<typename T >
void parse_vector_ (std::vector< T > &vals, const std::string &raw_str)
 
template<typename T >
std::vector< T > parse_vector (const std::string &raw_str)
 
template<typename T >
void parse_vector_vector_ (std::vector< std::vector< T > > &vals, const std::string &raw_str)
 
template<typename T >
std::vector< std::vector< T > > parse_vector_vector (const std::string &raw_str)
 
template<>
void parse_< bool > (bool &, const std::string &raw_str)
 
template<>
void parse_vector_< bool > (std::vector< bool > &, const std::string &raw_str)
 This special one is for the evil std::vector<bool>!
 
template<>
void parse_< TensorShape > (TensorShape &, const std::string &raw_str)
 
template<>
void parse_< VariableName > (VariableName &, const std::string &raw_str)
 
template<class... T>
bool sizes_same (T &&... shapes)
 Check if all shapes are the same.
 
template<class... T>
bool sizes_broadcastable (const T &... shapes)
 Check if the shapes are broadcastable.
 
template<class... T>
TensorShape broadcast_sizes (const T &... shapes)
 Return the broadcast shape of all the shapes.
 
template<typename... S>
TensorShape add_shapes (S &&... shape)
 
template<typename... S>
TraceableTensorShape add_traceable_shapes (S &&... shape)
 
template<typename T >
std::string stringify (const T &t)
 
template<>
std::string stringify (const bool &t)
 

Function Documentation

◆ add_shapes()

template<typename... S>
TensorShape add_shapes ( S &&... shape)

◆ add_traceable_shapes()

template<typename... S>
TraceableTensorShape add_traceable_shapes ( S &&... shape)

◆ broadcast_batch_sizes()

TraceableTensorShape broadcast_batch_sizes ( const std::vector< Tensor > & tensors)

Find the broadcast batch shape of all the tensors. The returned batch shape will be traceable.

See also
neml2::TraceableTensorShape

Pre-pad ones to the shapes

Broadcast

◆ broadcast_sizes()

template<class... T>
TensorShape broadcast_sizes ( const T &... shapes)

Return the broadcast shape of all the shapes.

◆ demangle()

std::string demangle ( const char * name)

Demangle a piece of cxx abi type information.

◆ end_with()

bool end_with ( std::string_view str,
std::string_view suffix )

◆ extract_batch_sizes()

TraceableTensorShape extract_batch_sizes ( const torch::Tensor & tensor,
Size batch_dim )

Extract the batch shape of a tensor given the batch dimension. The extracted batch shape will be traceable.

See also
neml2::TraceableTensorShape

◆ indentation()

std::string indentation ( int level,
int indent )

◆ join()

std::string join ( const std::vector< std::string > & strs,
const std::string & delim )

◆ operator>>()

std::stringstream & operator>> ( std::stringstream & in,
torch::Tensor &  )

This is a dummy to prevent compilers from whining about not knowing how to >> torch::Tensor.

◆ pad_prepend() [1/2]

torch::Tensor pad_prepend ( const torch::Tensor & s,
Size dim,
Size pad )

◆ pad_prepend() [2/2]

TensorShape pad_prepend ( TensorShapeRef s,
Size dim,
Size pad = 1 )

Pad shape s to dimension dim by prepending sizes of pad.

Parameters
s: The original shape to pad
dim: The resulting dimension
pad: The values used to pad the shape, defaults to 1
Returns
TensorShape The padded shape with dimension dim

◆ parse()

template<typename T >
T parse ( const std::string & raw_str)

◆ parse_() [1/4]

template<>
void parse_ ( bool & val,
const std::string & raw_str )

◆ parse_() [2/4]

template<typename T >
void parse_ ( T & val,
const std::string & raw_str )

◆ parse_() [3/4]

template<>
void parse_ ( TensorShape & val,
const std::string & raw_str )

◆ parse_() [4/4]

template<>
void parse_ ( VariableName & val,
const std::string & raw_str )

◆ parse_< bool >()

template<>
void parse_< bool > ( bool & ,
const std::string & raw_str )

◆ parse_< TensorShape >()

template<>
void parse_< TensorShape > ( TensorShape & ,
const std::string & raw_str )

◆ parse_< VariableName >()

template<>
void parse_< VariableName > ( VariableName & ,
const std::string & raw_str )

◆ parse_vector()

template<typename T >
std::vector< T > parse_vector ( const std::string & raw_str)

◆ parse_vector_() [1/2]

template<>
void parse_vector_ ( std::vector< bool > & vals,
const std::string & raw_str )

◆ parse_vector_() [2/2]

template<typename T >
void parse_vector_ ( std::vector< T > & vals,
const std::string & raw_str )

◆ parse_vector_< bool >()

template<>
void parse_vector_< bool > ( std::vector< bool > & ,
const std::string & raw_str )

This special one is for the evil std::vector<bool>!

◆ parse_vector_vector()

template<typename T >
std::vector< std::vector< T > > parse_vector_vector ( const std::string & raw_str)

◆ parse_vector_vector_()

template<typename T >
void parse_vector_vector_ ( std::vector< std::vector< T > > & vals,
const std::string & raw_str )

◆ same_device()

torch::Device same_device ( const std::vector< Tensor > & tensors)

Make sure all tensors have the same device and return the common device.

◆ same_dtype()

torch::Dtype same_dtype ( const std::vector< Tensor > & tensors)

Make sure all tensors have the same dtype and return the common dtype.

◆ sizes_broadcastable()

template<class... T>
bool sizes_broadcastable ( const T &... shapes)

Check if the shapes are broadcastable.

Shapes are said to be broadcastable if, starting from the trailing dimension and iterating backward, the dimension sizes either are equal, one of them is 1, or one of them does not exist.

◆ sizes_same()

template<class... T>
bool sizes_same ( T &&... shapes)

Check if all shapes are the same.

◆ split()

std::vector< std::string > split ( const std::string & str,
const std::string & delims )

◆ start_with()

bool start_with ( std::string_view str,
std::string_view prefix )

◆ storage_size()

Size storage_size ( TensorShapeRef shape)

The flattened storage size of a tensor with given shape.

For example,

storage_size({}) == 1;
storage_size({0}) == 0;
storage_size({1}) == 1;
storage_size({1, 2, 3}) == 6;
storage_size({5, 1, 1}) == 5;
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55

◆ stringify() [1/2]

template<>
std::string stringify ( const bool & t)
inline

◆ stringify() [2/2]

template<typename T >
std::string stringify ( const T & t)

◆ trim()

std::string trim ( const std::string & str,
const std::string & white_space )