/* Exchange the two ints that x and y point to. */
static inline void swap_ints(int *x, int *y) {
    int tmp = *x;
    *x = *y;
    *y = tmp;
}

/**
 * Transpose, in place, the n x n block whose top-left element is a[0].
 *
 * Element (i, j) of the block lives at a[i * N + j], so N is the row stride
 * of the enclosing array and n is the side of the block being transposed
 * (the top-level call for a full matrix passes n == N).
 *
 * Blocks of side <= 32 are transposed directly. Larger blocks are split into
 * four k x k quadrants (k = n / 2); each quadrant is transposed recursively
 * and then the top-right and bottom-left quadrants are exchanged. When n is
 * odd, the quadrants only cover the leading (n - 1) x (n - 1) part, so the
 * last row and last column are exchanged separately.
 */
void transpose(int *a, int n, int N) {
    enum { BASE_CASE_SIZE = 32 };

    if (n <= BASE_CASE_SIZE) {
        /* Swap every below-diagonal element with its mirror image. */
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                swap_ints(&a[i * N + j], &a[j * N + i]);
            }
        }
        return;
    }

    const int k = n / 2;

    /* Transpose each quadrant in place. */
    transpose(a, k, N);
    transpose(a + k, k, N);
    transpose(a + k * N, k, N);
    transpose(a + k * N + k, k, N);

    /* Exchange the (already transposed) top-right and bottom-left quadrants. */
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < k; j++) {
            swap_ints(&a[i * N + (j + k)], &a[(i + k) * N + j]);
        }
    }

    /* Odd n: row/column n - 1 was not part of any quadrant. */
    if (n & 1) {
        const int last = n - 1;
        for (int i = 0; i < last; i++) {
            swap_ints(&a[i * N + last], &a[last * N + i]);
        }
    }
}
// don't forget to initialize c[][] with zeroes for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) for (int k = 0; k < n; k++) c[i * n + j] += a[i * n + k] * b[k * n + j];
它总共需要访问 $O(N^3)$ 个块,因为每次标量乘法都需要读取一个单独的块。
一个众所周知的优化是首先转置B。
// Transpose b in place by exchanging each below-diagonal element with its
// mirror image (or use the faster recursive transpose from before).
// b is indexed as a flat n x n array, matching the multiplication loops.
// NOTE(review): swap is defined outside this snippet; it is assumed to
// exchange its two arguments in place — confirm.
for (int i = 0; i < n; i++) {
    for (int j = 0; j < i; j++) {
        swap(b[j * n + i], b[i * n + j]);
    }
}
// Accumulate the product into c, reading b as if it had already been
// transposed: entry (col, t) of b is used where the plain product would
// use entry (t, col).
for (int row = 0; row < n; row++) {
    for (int col = 0; col < n; col++) {
        for (int t = 0; t < n; t++) {
            // Note the indices: b is indexed [col * n + t], so both a and b
            // are read along a row as t advances.
            c[row * n + col] += a[row * n + t] * b[col * n + t];
        }
    }
}
/**
 * Accumulate the product of two n x n blocks into a third: c += a * b.
 *
 * Each block is addressed through its top-left element with row stride N,
 * i.e. element (i, j) of a block x is x[i * N + j]. The result is added to
 * whatever c already holds, so c must be zeroed by the caller to obtain the
 * plain product.
 *
 * Blocks of side <= 32 are multiplied with a direct triple loop. Larger
 * blocks are split into k x k quadrants (k = n / 2) and the eight quadrant
 * products are computed recursively. When n is odd the quadrants only cover
 * the leading (n - 1) x (n - 1) part, so the contributions involving the
 * last row/column are added in a final pass.
 */
void matmul(const float *a, const float *b, float *c, int n, int N) {
    enum { BASE_CASE_SIZE = 32 };

    if (n <= BASE_CASE_SIZE) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int t = 0; t < n; t++) {
                    c[i * N + j] += a[i * N + t] * b[t * N + j];
                }
            }
        }
        return;
    }

    const int k = n / 2;

    /* c11 += a11 b11 + a12 b21 */
    matmul(a, b, c, k, N);
    matmul(a + k, b + k * N, c, k, N);

    /* c12 += a11 b12 + a12 b22 */
    matmul(a, b + k, c + k, k, N);
    matmul(a + k, b + k * N + k, c + k, k, N);

    /* c21 += a21 b11 + a22 b21 */
    matmul(a + k * N, b, c + k * N, k, N);
    matmul(a + k * N + k, b + k * N, c + k * N, k, N);

    /* c22 += a21 b12 + a22 b22 */
    matmul(a + k * N, b + k, c + k * N + k, k, N);
    matmul(a + k * N + k, b + k * N + k, c + k * N + k, k, N);

    if (n & 1) {
        const int last = n - 1;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                /*
                 * Interior entries already hold the sum over t < last and
                 * only lack the t == last term; entries in the last row or
                 * column were not touched by the recursion and need the
                 * full sum.
                 */
                const int first = (i < last && j < last) ? last : 0;
                for (int t = first; t < n; t++) {
                    c[i * N + j] += a[i * N + t] * b[t * N + j];
                }
            }
        }
    }
}