Bug 61: Attention Deficit Disorder
- This code implements the forward pass of transformer multi-head scaled dot-product attention.
- m is the number of embeddings, h is the number of heads, and n is the per-head query/key/value length.
- x is a row-major in/out matrix of m embeddings (m rows); each embedding is h*n floats (h*n columns).
- wq, wk, wv, wo are row-major query, key, value, output weight matrices; each has h*n rows, h*n columns.
- q is the row-major query matrix (m rows, h*n columns). q is the product of x and the transpose of wq.
- k is the row-major key matrix (m rows, h*n columns). k is the product of x and the transpose of wk.
- v is the row-major value matrix (h*n rows, m columns). v is the product of wv and the transpose of x.
- q is used as h disjoint submatrices (one submatrix per head); each submatrix has m rows, n columns.
- k is used as h disjoint submatrices (one submatrix per head); each submatrix has m rows, n columns.
- v is used as h disjoint submatrices (one submatrix per head); each submatrix has n rows, m columns.
- Each head's submatrix of q is multiplied against the transpose of that head's submatrix of k.
- Each element of the resulting attention score matrix (m rows, m columns) is divided by sqrt(n).
- Softmax (with subtraction of the maximum, for numerical stability) is applied to each row separately.
- Each head's final matrix is the product of its softmax matrix and the transpose of its v submatrix.
- Final matrices for all h heads are stored as submatrices of row-major matrix y (m rows, h*n columns).
- x is overwritten with the product of y and the transpose of matrix wo (mixing the h heads' final matrices within each row).
Fix The Tiny Bug In This Go Code: