Background on GEMM

A simplified viewpoint

Consider a simplified version of GEMM with conformable real linear operators \((\mathbf{A},\mathbf{B},\mathbf{C})\):

\[\displaystyle\mathbf{C} = \alpha \cdot\,\operatorname{op}(\mathbf{A})\, \cdot \,\operatorname{op}(\mathbf{B}) + \,\beta \cdot \mathbf{C}.\]

Here, \(\operatorname{op}(\cdot)\) can return its argument either unchanged or transposed. Its action on \(\mathbf{A}\) and \(\mathbf{B}\) is determined by contextual information in the form of flags. The flag for \(\mathbf{A}\) is traditionally called \(\text{“}\texttt{opA}\text{”}\) and is interpreted as

\[\begin{split}\operatorname{op}(\mathbf{A}) = \begin{cases} \mathbf{A} & \text{ if } \texttt{opA} \texttt{ == NoTrans} \\ \mathbf{A}^T & \text{ if } \texttt{opA} \texttt{ == Trans} \end{cases}.\end{split}\]

The flag for \(\mathbf{B}\) is traditionally named \(\text{“}\texttt{opB}\text{”}\) and is interpreted similarly.
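The action of \(\operatorname{op}\) in this simplified viewpoint can be sketched in C++. This is a minimal sketch that assumes small dense matrices stored as vectors of rows. The names Op, Matrix, op_entry, op_rows, op_cols, and simple_gemm are hypothetical and are not part of any BLAS library. The transpose is never formed explicitly: reading entry \((i, j)\) of \(\mathbf{A}^T\) is the same as reading entry \((j, i)\) of \(\mathbf{A}\).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the opA / opB flags (not a real library's enum).
enum class Op { NoTrans, Trans };

// A small dense matrix stored as a vector of rows, for illustration only.
using Matrix = std::vector<std::vector<double>>;

// Entry (i, j) of op(M), computed without forming the transpose explicitly.
double op_entry(const Matrix& M, Op op, std::size_t i, std::size_t j) {
    return (op == Op::NoTrans) ? M[i][j] : M[j][i];
}

// Number of rows and columns of op(M); assumes M is non-empty.
std::size_t op_rows(const Matrix& M, Op op) {
    return (op == Op::NoTrans) ? M.size() : M.at(0).size();
}
std::size_t op_cols(const Matrix& M, Op op) {
    return (op == Op::NoTrans) ? M.at(0).size() : M.size();
}

// Simplified GEMM: C = alpha * op(A) * op(B) + beta * C.
void simple_gemm(Op opA, Op opB, double alpha, const Matrix& A,
                 const Matrix& B, double beta, Matrix& C) {
    const std::size_t m = op_rows(A, opA), k = op_cols(A, opA);
    const std::size_t n = op_cols(B, opB);
    // Conformability: op(A) is m x k, op(B) is k x n, and C is m x n.
    assert(op_rows(B, opB) == k);
    assert(C.size() == m && C.at(0).size() == n);
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            double acc = 0.0;
            for (std::size_t p = 0; p < k; ++p)
                acc += op_entry(A, opA, i, p) * op_entry(B, opB, p, j);
            C[i][j] = alpha * acc + beta * C[i][j];
        }
    }
}
```

With \(\mathbf{A} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}\), passing Op::Trans for \(\texttt{opA}\) multiplies by \(\begin{bmatrix} 1 & 3 \\ 2 & 4 \end{bmatrix}\) even though only \(\mathbf{A}\) itself is stored.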

An accurate description

The GEMM API accepts dimensions \((m, n, k)\), scalars \((\alpha, \beta)\), pointers \((A, B, C)\), and executes

\[\displaystyle\operatorname{mat}(C) = \alpha \cdot\, \underbrace{\operatorname{op}(\operatorname{mat}(A))}_{m \times k}\, \cdot \,\underbrace{\operatorname{op}(\operatorname{mat}(B))}_{k \times n} + \,\beta \cdot\underbrace{\operatorname{mat}(C)}_{m \times n}, \tag{1}\]

where \(\operatorname{mat}(\cdot)\) accepts a pointer and returns a matrix based on the following contextual information:

  • explicit or inferred dimensions (considering \((m, n, k)\) and \((\texttt{opA}, \texttt{opB})\) in Eq. 1),

  • a stride parameter associated with the pointer, and

  • a layout parameter that applies to all three matrices in Eq. 1.

We use the \(\text{“}\operatorname{mat}\text{”}\) operator only to help with exposition. No such operator appears in the GEMM API. For reference, here is a standard function signature for a version of GEMM that requires all three matrices in Eq. 1 to have a common numerical type, T.

template <typename T>
void gemm(
  blas::Layout ell, blas::Op opA, blas::Op opB, int m, int n, int k,
  T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C, int ldc
)
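To make Eq. 1 concrete, here is a minimal, unoptimized sketch of a function with the signature above, written as a plain triple loop. It assumes the indexing rule for \(\operatorname{mat}\) that is spelled out in the next section. The enums Layout and Op and the helper entry are hypothetical stand-ins defined so that the sketch compiles on its own; they are not the blas:: types of any particular library, and the sketch does not validate the stride parameters.

```cpp
#include <cassert>

// Hypothetical enums mirroring the flags in the signature; defined here
// only so that this sketch is self-contained.
enum class Layout { ColMajor, RowMajor };
enum class Op { NoTrans, Trans };

// Entry (i, j) of the matrix that mat(.) associates with pointer P and stride ld.
template <typename T>
T entry(Layout ell, const T* P, int ld, int i, int j) {
    return (ell == Layout::ColMajor) ? P[i + j * ld] : P[i * ld + j];
}

// Reference (slow) GEMM: mat(C) = alpha * op(mat(A)) * op(mat(B)) + beta * mat(C).
template <typename T>
void gemm(Layout ell, Op opA, Op opB, int m, int n, int k,
          T alpha, const T* A, int lda, const T* B, int ldb,
          T beta, T* C, int ldc) {
    assert(m >= 0 && n >= 0 && k >= 0);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            T acc = T(0);
            for (int p = 0; p < k; ++p) {
                // op(mat(A))_{i,p} is mat(A)_{p,i} when opA == Trans; likewise for B.
                T a = (opA == Op::NoTrans) ? entry(ell, A, lda, i, p)
                                           : entry(ell, A, lda, p, i);
                T b = (opB == Op::NoTrans) ? entry(ell, B, ldb, p, j)
                                           : entry(ell, B, ldb, j, p);
                acc += a * b;
            }
            // mat(C) uses the same layout as mat(A) and mat(B).
            T& c = (ell == Layout::ColMajor) ? C[i + j * ldc] : C[i * ldc + j];
            c = alpha * acc + beta * c;
        }
    }
}
```

For instance, a \(2 \times 3\) column-major matrix with \(\texttt{lda} = 2\) times a \(3 \times 2\) column-major matrix with \(\texttt{ldb} = 3\) writes a \(2 \times 2\) product into \(C\) with \(\texttt{ldc} = 2\). Production BLAS libraries reorganize these loops heavily for performance, but they compute the same quantity.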

A complete explanation of how \(\operatorname{mat}\) extracts submatrices from this contextual information is given below.

Details on \(\operatorname{mat}(\cdot)\)

The semantics of \(\operatorname{mat}\) can be understood by focusing on \(\mathbf{A} = \operatorname{mat}(A)\). First, there is the matter of the dimensions. These are inferred from \((m, k)\) and from \(\texttt{opA}\) in the way indicated by Eq. 1.

  • If \(\texttt{opA} \texttt{ == NoTrans}\), then \(\mathbf{A}\) is \(m \times k\).

  • If \(\texttt{opA} \texttt{ == Trans}\), then \(\mathbf{A}\) is \(k \times m\).
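This dimension inference fits in one small function. The sketch below assumes a hypothetical Op enum and a hypothetical helper named stored_dims, which takes the dimensions of \(\operatorname{op}(\operatorname{mat}(A))\) and returns those of \(\operatorname{mat}(A)\). The same rule gives the dimensions of \(\operatorname{mat}(B)\) if one passes \((k, n)\) and \(\texttt{opB}\) instead.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for the opA / opB flags.
enum class Op { NoTrans, Trans };

// Given that op(X) must be p x q, return the dimensions (rows, cols) of X itself.
std::pair<int, int> stored_dims(Op op, int p, int q) {
    assert(p >= 0 && q >= 0);
    return (op == Op::NoTrans) ? std::make_pair(p, q) : std::make_pair(q, p);
}
```

For \(\mathbf{A}\), calling stored_dims(opA, m, k) yields \((r, c)\).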

From here on, suppose that \(\mathbf{A}\) is \(r \times c\). The actual contents of \(\mathbf{A}\) are determined by the pointer, \(A\text{,}\) an explicitly declared stride parameter, \(\texttt{lda}\text{,}\) and a layout parameter, \(\texttt{ell}\text{,}\) according to the rule

\[\begin{split}\mathbf{A}_{i,j} = \begin{cases} A[\,i + j \cdot \texttt{lda}\,] & \text{ if } \texttt{ell == ColMajor} \\ A[\,i \cdot \texttt{lda} + j\,] & \text{ if } \texttt{ell == RowMajor} \end{cases}\end{split}\]

where we zero-index \(\mathbf{A}\) for consistency with indexing into buffers in C/C++.
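The indexing rule translates directly into code. In this minimal sketch the Layout enum and the functions mat_offset and mat_entry are hypothetical names chosen for illustration. A stride larger than strictly needed is allowed, which is the usual situation when \(\operatorname{mat}(A)\) is a block of a larger array.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the layout parameter ell.
enum class Layout { ColMajor, RowMajor };

// Zero-based offset into the buffer A of entry (i, j) of mat(A).
std::size_t mat_offset(Layout ell, std::size_t lda, std::size_t i, std::size_t j) {
    return (ell == Layout::ColMajor) ? i + j * lda : i * lda + j;
}

// Entry (i, j) of mat(A), read directly from the buffer.
template <typename T>
T mat_entry(Layout ell, const T* A, std::size_t lda, std::size_t i, std::size_t j) {
    assert(A != nullptr);
    return A[mat_offset(ell, lda, i, j)];
}
```

For example, a \(3 \times 2\) matrix stored column-major with \(\texttt{lda} = 4\) leaves one unused buffer slot after each column, and entry \((2, 1)\) lives at offset \(2 + 1 \cdot 4 = 6\).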

Only the leading \(r \times c\) submatrix of \(\operatorname{mat}(A)\) is accessed when computing Eq. 1. For this submatrix to be well-defined, it is necessary that

\[\begin{split}\texttt{lda} \geq \begin{cases} r & \text{ if } \texttt{ell == ColMajor} \\ c & \text{ if } \texttt{ell == RowMajor} \end{cases}.\end{split}\]

Most performance libraries check that this is the case on entry to GEMM and will raise an error if this condition isn’t satisfied.
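Combining the dimension inference with the stride requirement gives a check of the kind described in the preceding paragraph. This is a sketch with hypothetical names (Layout, Op, lda_is_valid). It also enforces \(\texttt{lda} \geq 1\), following the reference BLAS, which requires the leading dimension to be at least \(\max(1, \cdot)\) of the relevant dimension even when that dimension is zero.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical stand-ins for the layout and transpose flags.
enum class Layout { ColMajor, RowMajor };
enum class Op { NoTrans, Trans };

// Returns true if lda is admissible for mat(A), given (m, k), opA, and ell.
bool lda_is_valid(Layout ell, Op opA, int m, int k, int lda) {
    assert(m >= 0 && k >= 0);
    // Dimensions r x c of mat(A), inferred from (m, k) and opA.
    const int r = (opA == Op::NoTrans) ? m : k;
    const int c = (opA == Op::NoTrans) ? k : m;
    // Column-major needs room for a full column (r entries); row-major for a full row (c entries).
    const int required = (ell == Layout::ColMajor) ? r : c;
    // Following the reference BLAS, lda must also be at least 1.
    return lda >= std::max(1, required);
}
```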