Problem Nomenclature
This page is deprecated as of 2024-09-30 and will be removed in ROCm 6.4. New documentation is under active development.
- Standard GEMM has 4 variants (2 free indices i and j, and 1 summation index l):
  - NN (N: nontranspose): C[i,j] = Sum[l] A[i,l] * B[l,j]
  - NT (T: transpose): C[i,j] = Sum[l] A[i,l] * B[j,l]
  - TN: C[i,j] = Sum[l] A[l,i] * B[l,j]
  - TT: C[i,j] = Sum[l] A[l,i] * B[j,l]
- C[i,j,k] = Sum[l] A[i,l,k] * B[l,j,k] (batched-GEMM; 2 free indices, 1 batched index k and 1 summation index l)
- C[i,j] = Sum[k,l] A[i,k,l] * B[j,l,k] (2D summation)
- C[i,j,k,l,m] = Sum[n] A[i,k,m,l,n] * B[j,k,l,n,m] (GEMM with 3 batched indices)
- C[i,j,k,l,m] = Sum[n,o] A[i,k,m,o,n] * B[j,m,l,n,o] (4 free indices, 2 summation indices and 1 batched index)
- C[i,j,k,l] = Sum[m,n] A[i,j,m,n,l] * B[m,n,k,j,l] (batched image convolution mapped to 7D tensor contraction)
- and even more complex contractions
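The index notation above works much like Einstein-summation subscripts (as in `numpy.einsum`). The following is a minimal pure-Python reference evaluator for it, not Tensile code. It assumes single-letter index names and stores tensors as dicts from coordinate tuples to values, a hypothetical representation chosen for readability rather than Tensile's memory layout. The usage section computes the same product through each of the four GEMM variants by feeding transposed copies of A and B.

```python
from itertools import product


def contract(c_idx, a_idx, b_idx, A, B, sizes):
    """Naive reference for C[c_idx] = Sum[bound] A[a_idx] * B[b_idx].

    Index strings use one letter per index (e.g. "ij", "il", "lj").
    Tensors are dicts mapping coordinate tuples to numbers, and `sizes`
    maps every index letter to its extent.
    """
    # Bound (summation) indices: indices of A and B that do not appear in C.
    bound = [x for x in dict.fromkeys(a_idx + b_idx) if x not in c_idx]
    C = {}
    # The free and batch indices of C are iterated in the outer loop...
    for c_coord in product(*(range(sizes[x]) for x in c_idx)):
        env = dict(zip(c_idx, c_coord))
        total = 0
        # ...and the bound indices are summed over in the inner loop.
        for s_coord in product(*(range(sizes[x]) for x in bound)):
            env.update(zip(bound, s_coord))
            total += A[tuple(env[x] for x in a_idx)] * B[tuple(env[x] for x in b_idx)]
        C[c_coord] = total
    return C


def from_matrix(rows):
    """Convert a nested list into the dict representation used above."""
    return {(r, c): v for r, row in enumerate(rows) for c, v in enumerate(row)}


def transpose(M):
    """Swap the two coordinates of a 2-D dict tensor."""
    return {(c, r): v for (r, c), v in M.items()}


A = from_matrix([[1, 2, 3], [4, 5, 6]])    # A[i,l] with i=2, l=3
B = from_matrix([[1, 0], [0, 1], [1, 1]])  # B[l,j] with l=3, j=2
sizes = {"i": 2, "j": 2, "l": 3}

# The same product expressed through each of the four GEMM variants.
nn = contract("ij", "il", "lj", A, B, sizes)
nt = contract("ij", "il", "jl", A, transpose(B), sizes)
tn = contract("ij", "li", "lj", transpose(A), B, sizes)
tt = contract("ij", "li", "jl", transpose(A), transpose(B), sizes)
print(nn)  # {(0, 0): 4, (0, 1): 5, (1, 0): 10, (1, 1): 11}
```

Because `contract` takes arbitrary index strings, the batched and multi-index examples above can be evaluated the same way (for instance `contract("ijk", "ilk", "ljk", ...)` for batched GEMM); it is only meant as a slow correctness reference.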
The indices describe the dimensionality of the problem being solved. A GEMM operation takes two 2-dimensional matrices as input (4 input dimensions in total) and contracts them along one shared dimension, which removes that dimension from both inputs (2 of the 4 dimensions), leaving a 2-dimensional result.
Whenever an index shows up in multiple tensors, those tensors must be the same size along that dimension but they may have different strides.
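The size rule can be checked mechanically. Below is a minimal sketch, assuming each tensor is described by a hypothetical `(index_string, shape)` pair with one letter per index; it collects one extent per index and rejects conflicts. Strides are deliberately left out, since the rule only constrains sizes.

```python
def index_sizes(operands):
    """Build an index -> extent map from (index_string, shape) pairs.

    Raises ValueError if an index letter is given two different extents.
    Strides are not checked: tensors may share an index while laying it
    out in memory with different strides.
    """
    sizes = {}
    for idx, shape in operands:
        if len(idx) != len(shape):
            raise ValueError(f"{idx!r} has {len(idx)} indices but shape {shape}")
        for name, extent in zip(idx, shape):
            # The first tensor that mentions an index fixes its extent.
            known = sizes.setdefault(name, extent)
            if known != extent:
                raise ValueError(
                    f"index {name!r} has extent {extent} here but {known} elsewhere")
    return sizes


# Batched GEMM: C[i,j,k] = Sum[l] A[i,l,k] * B[l,j,k]
print(index_sizes([("ijk", (2, 4, 8)), ("ilk", (2, 3, 8)), ("ljk", (3, 4, 8))]))
# {'i': 2, 'j': 4, 'k': 8, 'l': 3}
```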
There are 3 categories of indices/dimensions that Tensile deals with: free, batch and bound.
Free indices are the indices of tensor C that come in pairs: one index of each pair shows up in tensor A, while the other shows up in tensor B. In the example above with 4 free indices, i, j, k and l are the free indices of tensor C; indices i and k come from tensor A, and indices j and l come from tensor B.
Batch indices are the indices of tensor C that show up in both tensor A and tensor B. For example, the batched-GEMM example above differs from the standard GEMM example only by the additional index k, which appears in C, A and B. In the batched-GEMM example, k is the batch index: it batches together multiple independent GEMMs.
The final category is bound indices, also called summation indices. These indices do not show up in tensor C; they show up in the summation symbol (for example, Sum[l] in the GEMM examples) and in both tensors A and B. The inner products (pairwise multiply, then sum) are performed along these indices.
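The three definitions translate directly into set membership tests. The sketch below, again assuming single-letter index names given as strings (a hypothetical input format, not a Tensile API), sorts the indices of a problem into free indices from A, free indices from B, batch indices and bound indices. It assumes a well-formed problem: an index that appears in only one input and not in C is simply ignored.

```python
def classify(c_idx, a_idx, b_idx):
    """Split single-letter indices into free, batch and bound categories."""
    in_a, in_b, in_c = set(a_idx), set(b_idx), set(c_idx)
    return {
        # Free: in C and in exactly one of A or B.
        "free_a": [x for x in c_idx if x in in_a and x not in in_b],
        "free_b": [x for x in c_idx if x in in_b and x not in in_a],
        # Batch: in C, A and B.
        "batch": [x for x in c_idx if x in in_a and x in in_b],
        # Bound (summation): in A and B but not in C.
        "bound": [x for x in a_idx if x in in_b and x not in in_c],
    }


# C[i,j,k,l,m] = Sum[n,o] A[i,k,m,o,n] * B[j,m,l,n,o]
print(classify("ijklm", "ikmon", "jmlno"))
# {'free_a': ['i', 'k'], 'free_b': ['j', 'l'], 'batch': ['m'], 'bound': ['o', 'n']}
```

For this example the output matches the description given earlier: 4 free indices (i and k from A, j and l from B), 1 batch index (m) and 2 summation indices (n and o).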
Problems supported by Tensile must meet the following conditions:
- There must be at least one pair of free indices.