mini_jit::einsum
-
struct EinsumNode
Public Functions
-
inline EinsumNode(std::vector<int64_t> const &output_dimension_ids, std::string tensor_expression, EinsumNode *left_child, EinsumNode *right_child)
-
inline ~EinsumNode()
-
inline int64_t get_number_of_children() const
Public Members
-
std::vector<int64_t> m_output_dimension_ids
The IDs of the dimensions in the output tensor.
-
std::vector<int64_t> m_dimension_ids
The IDs of the dimensions in the operation.
-
mini_jit::dtype_t m_dtype = mini_jit::dtype_t::fp32
The data type of the tensor.
-
mini_jit::ptype_t m_prim_first_touch = mini_jit::ptype_t::none
Primitive type for the first touch kernel.
-
mini_jit::ptype_t m_prim_main = mini_jit::ptype_t::none
Primitive type for the main kernel.
-
mini_jit::ptype_t m_prim_last_touch = mini_jit::ptype_t::none
Primitive type for the last touch kernel.
-
std::vector<mini_jit::dim_t> m_dim_types
Dimension types of the loops (m, n, k, c).
-
std::vector<mini_jit::exec_t> m_exec_types
Execution types of the loops (seq, shared, prim).
-
std::vector<int64_t> m_dim_sizes
Sizes of the dimensions (loops).
-
std::vector<int64_t> m_strides_in0
Strides of the first input tensor.
-
std::vector<int64_t> m_strides_in1
Strides of the second input tensor.
-
std::vector<int64_t> m_strides_out
Strides of the output tensor.
-
int64_t m_tensor_size = 1
Size of the output tensor.
-
void *m_tensor_out = nullptr
The output tensor for this node.
-
mini_jit::TensorOperation m_operation
The tensor operation associated with this node.
-
std::string m_tensor_expression = ""
String representation of the einsum expression.
-
EinsumNode *m_left_child = nullptr
The left child node in the einsum tree.
-
EinsumNode *m_right_child = nullptr
The right child node in the einsum tree.
-
double m_computational_operations = 0.0
The number of operations performed by this node.
-
inline EinsumNode(std::vector<int64_t> const &output_dimension_ids, std::string tensor_expression, EinsumNode *left_child, EinsumNode *right_child)
-
class EinsumTree
- #include <EinsumTree.h>
The EinsumTree class provides static methods that transform an einsum expression string into a tree structure of nodes and tensor operations.
Public Functions
-
EinsumTree() = delete
Deleted constructor that prevents instantiation of EinsumTree, which only provides static functions.
Public Static Functions
-
static EinsumNode *parse_einsum_expression(std::string const &einsum_expression, std::vector<int64_t> &dimension_sizes)
Parses the einsum expression and creates an einsum tree. If the dimensions are not in an optimal order, permutation nodes are inserted.
- Parameters:
einsum_expression – The string representation of the einsum operation.
dimension_sizes – A vector to store the sizes of the dimensions used in the expression.
- Returns:
The root EinsumNode representing the output of the parsed expression.
-
static void execute(EinsumNode *root_node, std::vector<int64_t> &dimension_sizes, std::map<std::string, void const*> &tensor_inputs)
Executes the einsum tree using tensor operations.
- Parameters:
root_node – The root EinsumNode of the einsum tree.
dimension_sizes – A vector containing the sizes of the dimensions used in the expression.
tensor_inputs – A map containing the input tensors for the einsum operation.
-
static void optimize_einsum_nodes(EinsumNode *root_node, int64_t thread_target, int64_t max_kernel_size, int64_t min_kernel_size)
Optimizes the given einsum tree by applying transformations to the dimensions, sizes and strides.
- Parameters:
root_node – The root node of the einsum tree.
thread_target – The target number of threads for parallel execution.
max_kernel_size – The maximum size of the kernel to be used.
min_kernel_size – The minimum size of the kernel to be used.
-
static void lower_einsum_nodes_to_tensor_operations(EinsumNode *root_node, std::vector<int64_t> &dimension_sizes, mini_jit::dtype_t dtype)
Lowers the given einsum tree to executable tensor operations.
- Parameters:
root_node – The root node of the einsum tree.
dimension_sizes – The array with the dimension sizes sorted by id.
dtype – The data type of the tensor.
-
static std::string to_string(EinsumNode *root_node)
Converts the einsum tree to a string representation.
- Parameters:
root_node – The root node of the einsum tree.
- Returns:
A string representation of the einsum tree.
-
EinsumTree() = delete