// attention runs one multi-head scaled dot-product self-attention layer over
// x and writes the result back into x.
//
//   - m is the number of rows (positions) in x.
//   - h is the number of heads and n the number of features per head; every
//     row therefore has hn = h*n features, and head j owns feature columns
//     [j*n, (j+1)*n).
//   - x is an m×hn row-major matrix (row t is x[t*hn:(t+1)*hn]). It is read
//     completely before being overwritten with the m×hn output.
//   - wq, wk, wv and wo are hn×hn row-major matrices indexed [out*hn+in], so
//     projecting a row r computes r·Wᵀ.
//
// Every position attends to every position (no mask is applied). Scores are
// scaled by 1/sqrt(n) and normalised with a max-subtracted softmax per query
// row and head. Slices shorter than the sizes above cause an
// index-out-of-range panic. Scratch buffers for q, k, v, the attention
// weights and the pre-projection output are allocated on every call.
func attention(m, h, n int, x, wq, wk, wv, wo []float32) {
	hn := h * n

	// Q, K and V projections. q and k are m×hn row-major; v is stored
	// transposed (hn×m) so the weighted sum over positions reads it
	// contiguously.
	q := make([]float32, m*hn)
	k := make([]float32, m*hn)
	v := make([]float32, hn*m)
	for t := 0; t < m; t++ {
		// Every output feature of row t must be computed from input row t.
		xt := x[t*hn : (t+1)*hn]
		for o := 0; o < hn; o++ {
			wqo := wq[o*hn : (o+1)*hn]
			wko := wk[o*hn : (o+1)*hn]
			wvo := wv[o*hn : (o+1)*hn]
			var qs, ks, vs float32
			for i, xi := range xt {
				qs += xi * wqo[i]
				ks += xi * wko[i]
				vs += xi * wvo[i]
			}
			q[t*hn+o] = qs
			k[t*hn+o] = ks
			v[o*m+t] = vs
		}
	}

	scale := 1 / float32(math.Sqrt(float64(n)))
	a := make([]float32, m) // softmax weights for one (head, query row) pair
	y := make([]float32, m*hn)
	for off := 0; off < hn; off += n { // off = first feature column of the head
		for t := 0; t < m; t++ {
			qt := q[t*hn+off : t*hn+off+n]

			// Scaled scores of query t against every key, tracking the
			// maximum so every exponent in the softmax below is <= 0.
			maxScore := float32(-math.MaxFloat32)
			for u := 0; u < m; u++ {
				ku := k[u*hn+off : u*hn+off+n]
				var s float32
				for i, qi := range qt {
					s += qi * ku[i]
				}
				s *= scale
				if s > maxScore {
					maxScore = s
				}
				a[u] = s
			}

			var sum float32
			for u, s := range a {
				e := float32(math.Exp(float64(s - maxScore)))
				a[u] = e
				sum += e
			}
			inv := 1 / sum
			for u := range a {
				a[u] *= inv
			}

			// y[t, d] = Σ_u a[u]·v[u, d] for this head's feature columns d.
			for d := off; d < off+n; d++ {
				vd := v[d*m : (d+1)*m]
				var acc float32
				for u, w := range a {
					acc += w * vd[u]
				}
				y[t*hn+d] = acc
			}
		}
	}

	// Output projection, written back over x (x is not read after the Q/K/V
	// projection above, so overwriting it in place is safe).
	for t := 0; t < m; t++ {
		yt := y[t*hn : (t+1)*hn]
		for o := 0; o < hn; o++ {
			woo := wo[o*hn : (o+1)*hn]
			var acc float32
			for i, yi := range yt {
				acc += yi * woo[i]
			}
			x[t*hn+o] = acc
		}
	}
}
To receive a hint, submit unfixed code.