mini_jit::einsum

struct EinsumNode

Public Functions

inline EinsumNode(std::vector<int64_t> const &output_dimension_ids, std::string tensor_expression, EinsumNode *left_child, EinsumNode *right_child)
inline ~EinsumNode()
inline int64_t get_number_of_children() const
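The constructor takes raw pointers to the left and right children, so a node can have zero, one, or two children. The sketch below assumes that a leaf (0 children) is an input tensor, a unary node (1 child) is something like the permutation nodes mentioned for `parse_einsum_expression`, and a binary node (2 children) is a contraction. It also assumes that the destructor owns and releases the whole subtree. `NodeSketch` is a hypothetical stand-in that models only the child pointers, not the real `EinsumNode`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for EinsumNode: only the child pointers are modeled.
struct NodeSketch {
  NodeSketch *m_left_child = nullptr;
  NodeSketch *m_right_child = nullptr;

  NodeSketch(NodeSketch *left, NodeSketch *right) : m_left_child(left), m_right_child(right) {
    // In this sketch a right child without a left child is malformed.
    assert(left != nullptr || right == nullptr);
  }

  // Assumption: a node owns its children and releases the whole subtree.
  ~NodeSketch() {
    delete m_left_child;
    delete m_right_child;
  }

  // Owning raw pointers: forbid copies to avoid double deletion.
  NodeSketch(NodeSketch const &) = delete;
  NodeSketch &operator=(NodeSketch const &) = delete;

  // 0 = leaf (input tensor), 1 = unary node (e.g. permutation), 2 = binary contraction.
  int64_t get_number_of_children() const {
    return (m_left_child != nullptr ? 1 : 0) + (m_right_child != nullptr ? 1 : 0);
  }
};
```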

Public Members

std::vector<int64_t> m_output_dimension_ids

The IDs of the dimensions in the output tensor.

std::vector<int64_t> m_dimension_ids

The IDs of the dimensions in the operation.

mini_jit::dtype_t m_dtype = mini_jit::dtype_t::fp32

The data type of the tensor.

mini_jit::ptype_t m_prim_first_touch = mini_jit::ptype_t::none

Primitive type for the first touch kernel.

mini_jit::ptype_t m_prim_main = mini_jit::ptype_t::none

Primitive type for the main kernel.

mini_jit::ptype_t m_prim_last_touch = mini_jit::ptype_t::none

Primitive type for the last touch kernel.
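The three primitive types describe what happens before, during, and after the reduction into an output block. The following scalar model is a sketch only: it assumes a zeroing first-touch kernel, a GEMM-like main kernel, and a ReLU last-touch kernel as example choices, since the concrete `mini_jit::ptype_t` values are not listed in this reference. `run_with_touches` is a hypothetical name, and all matrices are column-major.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of the three kernels around a reduction over k:
//   first touch -> zero the output once, before the first k-iteration
//   main        -> accumulate a GEMM-like product for every k-iteration
//   last touch  -> apply an element-wise ReLU once, after the last k-iteration
// Column-major layout: a is m x k, b is k x n, c is m x n.
void run_with_touches(std::vector<float> const &a, std::vector<float> const &b,
                      std::vector<float> &c, int64_t m, int64_t n, int64_t k) {
  assert(static_cast<int64_t>(a.size()) == m * k);
  assert(static_cast<int64_t>(b.size()) == k * n);
  assert(static_cast<int64_t>(c.size()) == m * n);

  std::fill(c.begin(), c.end(), 0.0f);  // first touch: zero

  for (int64_t p = 0; p < k; ++p)       // main: one rank-1 update per k-iteration
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        c[i + j * m] += a[i + p * m] * b[p + j * k];

  for (float &v : c) v = std::max(v, 0.0f);  // last touch: ReLU
}
```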

std::vector<mini_jit::dim_t> m_dim_types

Dimension types of the loops (m, n, k, c).

std::vector<mini_jit::exec_t> m_exec_types

Execution types of the loops (seq, shared, prim).

std::vector<int64_t> m_dim_sizes

Sizes of the dimensions (loops).

std::vector<int64_t> m_strides_in0

Strides of the first input tensor.

std::vector<int64_t> m_strides_in1

Strides of the second input tensor.

std::vector<int64_t> m_strides_out

Strides of the output tensor.

int64_t m_tensor_size = 1

Size of the output tensor.
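Each stride vector has one entry per loop, so a tensor's strides must be expressed in the loop order rather than in the tensor's own dimension order. The sketch below assumes a row-major layout, where the last listed dimension of a tensor is contiguous. It gives a loop stride 0 when the tensor does not contain that dimension and computes the element count as the product of the tensor's dimension sizes. `loop_strides` and `tensor_size` are hypothetical helpers, not part of the documented API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: strides of one tensor, expressed per loop dimension.
//   loop_ids   : dimension IDs of the loops (cf. m_dimension_ids)
//   tensor_ids : dimension IDs of the tensor in memory order, last one contiguous
//   dim_sizes  : size of every dimension, indexed by dimension ID
// Loops over dimensions the tensor does not contain get stride 0.
std::vector<int64_t> loop_strides(std::vector<int64_t> const &loop_ids,
                                  std::vector<int64_t> const &tensor_ids,
                                  std::vector<int64_t> const &dim_sizes) {
  std::vector<int64_t> strides(loop_ids.size(), 0);
  int64_t stride = 1;
  // Walk the tensor dimensions from innermost (contiguous) to outermost.
  for (auto it = tensor_ids.rbegin(); it != tensor_ids.rend(); ++it) {
    assert(*it >= 0 && static_cast<std::size_t>(*it) < dim_sizes.size());
    for (std::size_t l = 0; l < loop_ids.size(); ++l) {
      if (loop_ids[l] == *it) strides[l] = stride;
    }
    stride *= dim_sizes[*it];
  }
  return strides;
}

// Number of elements of a tensor: product of the sizes of its dimensions.
int64_t tensor_size(std::vector<int64_t> const &tensor_ids,
                    std::vector<int64_t> const &dim_sizes) {
  int64_t size = 1;
  for (int64_t id : tensor_ids) size *= dim_sizes[id];
  return size;
}
```

For a contraction with inputs `[0,1]` and `[1,2]`, output `[0,2]`, sizes `{2, 3, 4}`, and loop order `{0, 2, 1}`, this yields in0 strides `{3, 0, 1}`, in1 strides `{0, 1, 4}`, out strides `{4, 1, 0}`, and an output size of 8 elements.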

void *m_tensor_out = nullptr

The output tensor for this node.

mini_jit::TensorOperation m_operation

The tensor operation associated with this node.

std::string m_tensor_expression = ""

String representation of the einsum expression.

EinsumNode *m_left_child = nullptr

The left child node in the einsum tree.

EinsumNode *m_right_child = nullptr

The right child node in the einsum tree.

double m_computational_operations = 0.0

The number of operations performed by this node.
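A common way to count the work of a contraction is the same one used for GEMM (2·M·N·K): each point of the full iteration space performs one multiply and one add. The sketch assumes this convention and counts unary nodes such as permutations as 0 operations, since they only move data. `count_operations` is a hypothetical helper, and the library's exact counting rule is not stated here.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed convention: a binary contraction performs one multiply and one add
// per point of its iteration space (2 operations); unary nodes only move data.
double count_operations(std::vector<int64_t> const &loop_sizes, int64_t number_of_children) {
  assert(number_of_children >= 0 && number_of_children <= 2);
  if (number_of_children != 2) return 0.0;
  double points = 1.0;
  for (int64_t size : loop_sizes) points *= static_cast<double>(size);
  return 2.0 * points;
}
```

Dividing such a count by the measured runtime gives a FLOPS estimate for a node.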

class EinsumTree
#include <EinsumTree.h>

The EinsumTree class provides methods to transform a string einsum expression into a tree structure based on nodes and tensor operations.

Public Functions

EinsumTree() = delete

Deleted default constructor. EinsumTree only provides static functions and is not meant to be instantiated.

Public Static Functions

static EinsumNode *parse_einsum_expression(std::string const &einsum_expression, std::vector<int64_t> &dimension_sizes)

Parses the einsum expression and creates an einsum tree. If the dimensions are not in an optimal order, permutation nodes are inserted.

Parameters:
  • einsum_expression – The string representation of the einsum operation.

  • dimension_sizes – A vector to store the sizes of the dimensions used in the expression.

Returns:

The root EinsumNode representing the output of the parsed expression.
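The accepted syntax of `einsum_expression` is not specified in this reference. The recursive-descent sketch below therefore assumes a hypothetical bracket grammar: dimension IDs are comma-separated integers in brackets, a node is written `left,right->[out]` or `input->[out]`, and a nested subtree is wrapped in its own brackets, as in `[[0,1],[1,2]->[0,2]],[2,3]->[0,3]`. `SketchParser` and `ParsedNode` are illustrative names. They only build the tree shape and do not fill `dimension_sizes`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal node: only the output dimension IDs and the children.
struct ParsedNode {
  std::vector<int64_t> out_ids;
  std::unique_ptr<ParsedNode> left;
  std::unique_ptr<ParsedNode> right;
};

// Recursive-descent parser for the ASSUMED grammar
//   expr := term ["," term] "->" "[" ids "]"
//   term := "[" ids "]" | "[" expr "]"
//   ids  := int {"," int}
class SketchParser {
 public:
  explicit SketchParser(std::string text) : s_(std::move(text)) {}

  std::unique_ptr<ParsedNode> parse() {
    auto root = parse_expr();
    if (pos_ != s_.size()) throw std::runtime_error("trailing characters");
    return root;
  }

 private:
  std::string s_;
  std::size_t pos_ = 0;

  bool peek(char c) const { return pos_ < s_.size() && s_[pos_] == c; }

  void expect(char c) {
    if (!peek(c)) throw std::runtime_error(std::string("expected '") + c + "'");
    ++pos_;
  }

  std::vector<int64_t> parse_ids() {
    std::vector<int64_t> ids;
    do {
      if (!ids.empty()) expect(',');
      std::size_t start = pos_;
      while (pos_ < s_.size() && std::isdigit(static_cast<unsigned char>(s_[pos_]))) ++pos_;
      if (start == pos_) throw std::runtime_error("expected a dimension ID");
      ids.push_back(std::stoll(s_.substr(start, pos_ - start)));
    } while (peek(','));
    return ids;
  }

  std::unique_ptr<ParsedNode> parse_term() {
    expect('[');
    std::unique_ptr<ParsedNode> node;
    if (peek('[')) {
      node = parse_expr();                    // nested subtree
    } else {
      node = std::make_unique<ParsedNode>();  // leaf: an input tensor
      node->out_ids = parse_ids();
    }
    expect(']');
    return node;
  }

  std::unique_ptr<ParsedNode> parse_expr() {
    auto node = std::make_unique<ParsedNode>();
    node->left = parse_term();
    if (peek(',')) {  // a second operand makes this a binary node
      ++pos_;
      node->right = parse_term();
    }
    expect('-');
    expect('>');
    expect('[');
    node->out_ids = parse_ids();
    expect(']');
    return node;
  }
};
```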

static void execute(EinsumNode *root_node, std::vector<int64_t> &dimension_sizes, std::map<std::string, void const*> &tensor_inputs)

Executes the einsum tree using tensor operations.

Parameters:
  • root_node – The root EinsumNode of the einsum tree.

  • dimension_sizes – A vector containing the sizes of the dimensions used in the expression.

  • tensor_inputs – A map containing the input tensors for the einsum operation.
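A tree like this is naturally evaluated in post-order, so both children of a node are available before the node itself runs. The sketch assumes that `tensor_inputs` is keyed by each leaf's expression string and that leaves are never computed. Instead of running a `TensorOperation`, it records the order in which inner nodes would execute. `ExecNode` and `execution_order` are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical node: `expr` doubles as the lookup key for leaf input tensors.
struct ExecNode {
  std::string expr;
  ExecNode *left = nullptr;
  ExecNode *right = nullptr;
};

// Post-order traversal: children are handled before their parent.
// Leaves must have an entry in tensor_inputs; inner nodes are appended to `order`
// at the point where the real implementation would execute their operation.
void execution_order(ExecNode const *node,
                     std::map<std::string, void const *> const &tensor_inputs,
                     std::vector<std::string> &order) {
  if (node == nullptr) return;
  assert(node->left != nullptr || node->right == nullptr);
  if (node->left == nullptr) {  // leaf: an input tensor, nothing to compute
    if (tensor_inputs.find(node->expr) == tensor_inputs.end())
      throw std::runtime_error("missing input tensor for leaf " + node->expr);
    return;
  }
  execution_order(node->left, tensor_inputs, order);
  execution_order(node->right, tensor_inputs, order);
  order.push_back(node->expr);
}
```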

static void optimize_einsum_nodes(EinsumNode *root_node, int64_t thread_target, int64_t max_kernel_size, int64_t min_kernel_size)

Optimize the given einsum tree by applying transformations to the dimensions, sizes and strides.

Parameters:
  • root_node – The root node of the einsum tree.

  • thread_target – The target number of threads for parallel execution.

  • max_kernel_size – The maximum size of the kernel to be used.

  • min_kernel_size – The minimum size of the kernel to be used.
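The exact transformations are not described here. The sketch below shows two plausible rules that use these parameters and should be read as assumptions, not as the library's algorithm. The first rule splits a dimension larger than `max_kernel_size` into an outer loop and an inner kernel loop, using the largest factor within `[min_kernel_size, max_kernel_size]`. The second rule marks outer loops as shared until their combined iteration count reaches `thread_target`. `split_dimension` and `count_shared_loops` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical rule: split a loop of `size` iterations into {outer, inner}, where
// inner is the largest factor of size within [min_kernel_size, max_kernel_size].
// Returns {1, size} when the loop already fits or no suitable factor exists.
std::pair<int64_t, int64_t> split_dimension(int64_t size, int64_t max_kernel_size,
                                            int64_t min_kernel_size) {
  assert(min_kernel_size >= 1 && min_kernel_size <= max_kernel_size);
  if (size <= max_kernel_size) return {1, size};
  for (int64_t f = max_kernel_size; f >= min_kernel_size; --f) {
    if (size % f == 0) return {size / f, f};
  }
  return {1, size};
}

// Hypothetical rule: mark outer loops as shared until their combined iteration
// count reaches thread_target. Returns how many leading loops become shared.
std::size_t count_shared_loops(std::vector<int64_t> const &outer_sizes, int64_t thread_target) {
  int64_t iterations = 1;
  std::size_t n = 0;
  while (n < outer_sizes.size() && iterations < thread_target) {
    iterations *= outer_sizes[n];
    ++n;
  }
  return n;
}
```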

static void lower_einsum_nodes_to_tensor_operations(EinsumNode *root_node, std::vector<int64_t> &dimension_sizes, mini_jit::dtype_t dtype)

Lower the given einsum tree to executable tensor operations.

Parameters:
  • root_node – The root node of the einsum tree.

  • dimension_sizes – The vector of dimension sizes, indexed by dimension ID.

  • dtype – The data type of the tensor.
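Part of lowering a binary node is assigning each dimension one of the types listed for `m_dim_types`. The sketch assumes the usual GEMM-style reading: m appears in in0 and out, n appears in in1 and out, k appears in both inputs but not in out, and c appears in all three tensors. `DimType` is a hypothetical mirror of `mini_jit::dim_t`, and `classify` is an illustrative helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

enum class DimType { m, n, k, c };  // hypothetical mirror of mini_jit::dim_t

static bool contains(std::vector<int64_t> const &ids, int64_t id) {
  return std::find(ids.begin(), ids.end(), id) != ids.end();
}

// Classify one dimension of a binary contraction out = in0 x in1 by where it occurs.
DimType classify(int64_t id, std::vector<int64_t> const &in0,
                 std::vector<int64_t> const &in1, std::vector<int64_t> const &out) {
  bool a = contains(in0, id), b = contains(in1, id), o = contains(out, id);
  if (a && b && o) return DimType::c;   // present everywhere
  if (a && !b && o) return DimType::m;  // first input and output
  if (!a && b && o) return DimType::n;  // second input and output
  if (a && b && !o) return DimType::k;  // contracted (summed over)
  throw std::invalid_argument("dimension cannot be classified");
}
```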

static std::string to_string(EinsumNode *root_node)

Convert the einsum tree to a string representation.

Parameters:
  • root_node – The root node of the einsum tree.

Returns:

A string representation of the einsum tree.
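The output format of `to_string` is not specified here. A simple pre-order rendering is one possibility, with one node per line and children indented two spaces per level. `PrintNode` and `to_string_sketch` are hypothetical names, and the layout is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical node carrying only its expression string and children.
struct PrintNode {
  std::string expr;
  PrintNode *left = nullptr;
  PrintNode *right = nullptr;
};

// Pre-order: print the node, then its children one indentation level deeper.
void append_tree(PrintNode const *node, int depth, std::string &out) {
  if (node == nullptr) return;
  out += std::string(static_cast<std::size_t>(2 * depth), ' ') + node->expr + '\n';
  append_tree(node->left, depth + 1, out);
  append_tree(node->right, depth + 1, out);
}

std::string to_string_sketch(PrintNode const *root) {
  std::string out;
  append_tree(root, 0, out);
  return out;
}
```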