// qkv computes the query, key and value projections of x and applies a
// position-dependent planar rotation to the query and key outputs.
//
// Layout, as established by the indexing below (hn = h*n):
//   - x is m rows of hn values, row-major.
//   - wq, wk, wv each hold hn rows of hn weights; row r produces output r.
//   - q and k receive m rows of hn values, row-major.
//   - v is written transposed: output r of input row i lands at v[r*m+i].
//
// Outputs are produced two at a time. Each consecutive (even, odd) output
// pair of q and k is rotated by the (sin, cos) entry taken from the table
// returned by sincos(m, n); the entry index advances by 2 per pair and wraps
// at n, so every run of n outputs reuses the same table row.
// Presumably h is the head count and n the per-head width; confirm with callers.
//
// n (and therefore hn) must be even; slices that are too short cause an
// index-out-of-range panic.
func qkv(m, h, n int, x, wq, wk, wv, q, k, v []float32) {
	hn := h * n
	rot := sincos(m, n)

	for row := 0; row < m; row++ {
		pair := 0 // offset of the current (sin, cos) entry inside this row of rot

		// Step by two: every iteration consumes weight rows out and out+1 and
		// writes outputs out and out+1. Stepping by one would overwrite half of
		// the results with misaligned pairs and run past the end of the weights.
		for out := 0; out < hn; out += 2 {
			var (
				q1, q2 float32
				k1, k2 float32
				v1, v2 float32
			)
			lo := out * hn // start of weight row out
			hi := lo + hn  // start of weight row out+1
			for in := 0; in < hn; in++ {
				xv := x[row*hn+in]
				q1 += xv * wq[lo+in]
				q2 += xv * wq[hi+in]
				k1 += xv * wk[lo+in]
				k2 += xv * wk[hi+in]
				v1 += xv * wv[lo+in]
				v2 += xv * wv[hi+in]
			}

			sin := rot[row*n+pair]
			cos := rot[row*n+pair+1]
			if pair += 2; pair == n {
				pair = 0
			}

			// Rotate the (q1, q2) and (k1, k2) pairs; v is stored unrotated.
			q[row*hn+out] = q1*cos - q2*sin
			q[row*hn+out+1] = q1*sin + q2*cos
			k[row*hn+out] = k1*cos - k2*sin
			k[row*hn+out+1] = k1*sin + k2*cos
			v[out*m+row] = v1
			v[(out+1)*m+row] = v2
		}
	}
}

var (
	// mut guards lut.
	mut sync.Mutex
	// lut caches the tables built by sincos, keyed by {m, n}.
	lut = make(map[[2]int][]float32)
)

// sincos returns a table of m rows by n columns in which, for every even
// column c, entries [r*n+c] and [r*n+c+1] hold the sine and cosine of
// r * 10000^(-c/n).
//
// Tables are built once per (m, n) and cached; the returned slice is shared
// between all callers and must be treated as read-only. n is expected to be
// even.
func sincos(m, n int) []float32 {
	mut.Lock()
	// Deferred so the lock is released even if building the table panics.
	defer mut.Unlock()

	key := [2]int{m, n}
	if tab, ok := lut[key]; ok {
		return tab
	}

	const base = 10000
	tab := make([]float32, m*n)
	recip := -1 / float64(n)
	for col := 0; col < n; col += 2 {
		// The frequency depends only on the column, so compute it once.
		freq := math.Pow(base, float64(col)*recip)
		for row := 0; row < m; row++ {
			sin, cos := math.Sincos(freq * float64(row))
			tab[row*n+col] = float32(sin)
			tab[row*n+col+1] = float32(cos)
		}
	}

	// Publish only a fully populated table.
	lut[key] = tab
	return tab
}
To receive a hint, submit unfixed code.