缓存遗忘算法

外部内存模型中,有两种类型的高效算法:

  • 一种是缓存感知算法(Cache-aware algorithms),已知缓存块大小BB和内存总大小MM
  • 一种是缓存遗忘算法(Cache-oblivious algorithms),则不需要知道特定的缓存参数下,在任何BBMM情况下都能高效运行;

例如,外部合并排序就是一种缓存感知算法,而非缓存遗忘算法:我们需要知道系统的内存特性,即可用内存与块大小的比例,以找到执行kk路合并排序的正确kk值。

缓存遗忘算法很有趣,因为无论在缓存层次结构中的哪一级,它们都会自动变得最优,而不仅仅是在为其特别调整的那一级。在这篇文章中,我们将探讨它在矩阵计算中的一些应用。

矩阵转置

假设我们有一个大小为N×NN×N的方阵AA,我们需要对它进行转置。根据定义实现的朴素方法大概会是这样:

for (int i = 0; i < n; i++)
for (int j = 0; j < i; j++)
swap(a[j * N + i], a[i * N + j]);

这里我们使用了一个指向内存起始区域的单一指针,而不是二维数组,以更明确地说明其内存操作。

这段代码的I/O复杂性为O(N2)O(N^2),这是由于在执行算法的过程中,数据不是按照存储顺序依次处理的,而是跳跃式地读取和写入,这就产生了大量的I/O开销。即使尝试更改迭代变量的顺序,情况可能会发生变化(如从行优先变为列优先),但总体的I/O复杂性并没有改变,仍然是O(N2)O(N^2)

算法

缓存遗忘算法依赖于以下的块矩阵等式:

(ABCD)T=(ATCTBTDT){\bigg(\begin{matrix} A & B \\ C & D \end{matrix}\bigg)}^T = \bigg(\begin{matrix} A^T & C^T \\ B^T & D^T \end{matrix}\bigg)

我们可以使用分治法(divide-and-conquer,D&C)递归地解决问题:

  1. 将输入矩阵分为4个较小的矩阵。
  2. 递归地对每一个进行转置。
  3. 通过交换反对角的子矩阵来合并结果。

在矩阵上实现分治法比在数组上稍微复杂一些,但是主要的思想是一样的。这里使用了“视图”(views)这种技术,它可以直接在原始数据结构上创建一个引用,而不需要创建新的数据副本,这样就可以降低数据操作的开销。另外,当数据规模变小到可以被放入L1缓存(即最接近CPU的缓存级别,当你不知道cache尺寸,可以选择像32×3232×32这个的小尺寸)时,可以切换到更简单的方法(如朴素方法)进行处理。然而,分治法的一个挑战是,如果矩阵尺寸为奇数,我们就无法将其等分为四个等大子矩阵,所以需要特别处理这种情况。

void transpose(int *a, int n, int N) {
if (n <= 32) {
for (int i = 0; i < n; i++)
for (int j = 0; j < i; j++)
swap(a[i * N + j], a[j * N + i]);
} else {
int k = n / 2;

transpose(a, k, N);
transpose(a + k, k, N);
transpose(a + k * N, k, N);
transpose(a + k * N + k, k, N);

for (int i = 0; i < k; i++)
for (int j = 0; j < k; j++)
swap(a[i * N + (j + k)], a[(i + k) * N + j]);

if (n & 1)
for (int i = 0; i < n - 1; i++)
swap(a[i * N + n - 1], a[(n - 1) * N + i]);
}
}

这种算法的I/O复杂度是O(N2B)O(\frac{N^2}{B})(除BB意味着我们一次I/O,就可以处理BB个元素)。在每个合并阶段,我们仅需要操作大约一半的内存块,这意味着在每个阶段,我们的问题会变得更小。

将这段代码进行调整以适应非方形矩阵的情况,留给读者自己完成。

矩阵乘法

接下来,让我们考虑稍微负载一些的东西:矩阵乘法。

Cij=kAikBkjC_{ij} = \sum_{k}A_{ik}B_{kj}

朴素方法知识将其定义转换为算法:

// don't forget to initialize c[][] with zeroes
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
for (int k = 0; k < n; k++)
c[i * n + j] += a[i * n + k] * b[k * n + j];

它需要总共访问O(N3)O(N^3)的块,因为每个标量乘法需要读取一个单独的块。

一个众所周知的优化是首先转置BB

for (int i = 0; i < n; i++)
for (int j = 0; j < i; j++)
swap(b[j][i], b[i][j])
// ^ or use our faster transpose from before

for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
for (int k = 0; k < n; k++)
c[i * n + j] += a[i * n + k] * b[j * n + k]; // <- note the indices

无论是以朴素方式进行转置还是使用我们之前开发的缓存遗忘方法进行转置,只要其中一个矩阵被转置,乘法就能在O(N3/B+N2)O(N^3/B+N^2)的时间复杂度内工作,因为所有的内存访问现在都是顺序的。

看起来我们似乎没有进一步优化空间了,但事实证明我们可以。

算法

缓存遗忘的矩阵乘法基本上和转置依赖相同的技巧。我们需要划分数据,直到它们能够适应最低级别的缓存(即N2MN^2≤M)。对于矩阵乘法而言,这表示我们需要使用以下公式:

(A11A12A21A22)(B11B12B21B22)=(A11B11+A12B21A11B12+A12B22A21B11+A22B21A21B12+A22B22)\bigg(\begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix}\bigg) \bigg(\begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix}\bigg) = \bigg(\begin{matrix} A_{11}B_{11} + A_{12}B_{21} & A_{11}B_{12} + A_{12}B_{22} \\ A_{21}B_{11} + A_{22}B_{21} & A_{21}B_{12} + A_{22}B_{22} \end{matrix}\bigg)

这个实现起来稍微困难一些,因为我们现在总共有8个递归矩阵乘法:

void matmul(const float *a, const float *b, float *c, int n, int N) {
if (n <= 32) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
for (int k = 0; k < n; k++)
c[i * N + j] += a[i * N + k] * b[k * N + j];
} else {
int k = n / 2;

// c11 = a11 b11 + a12 b21
matmul(a, b, c, k, N);
matmul(a + k, b + k * N, c, k, N);

// c12 = a11 b12 + a12 b22
matmul(a, b + k, c + k, k, N);
matmul(a + k, b + k * N + k, c + k, k, N);

// c21 = a21 b11 + a22 b21
matmul(a + k * N, b, c + k * N, k, N);
matmul(a + k * N + k, b + k * N, c + k * N, k, N);

// c22 = a21 b12 + a22 b22
matmul(a + k * N, b + k, c + k * N + k, k, N);
matmul(a + k * N + k, b + k * N + k, c + k * N + k, k, N);

if (n & 1) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
for (int k = (i < n - 1 && j < n - 1) ? n - 1 : 0; k < n; k++)
c[i * N + j] += a[i * N + k] * b[k * N + j];
}
}
}

因为这里涉及许多其它因素,我们并不打算对这个实现进行基准测试,而只是在外部存储器模型中进行其理论性能分析。

分析

算法的计算复杂度保持不变,因为递归式

T(N)=8T(N/2)+Θ(N2)T(N) = 8\cdot T(N/2) + \varTheta(N^2)

的解为T(N)=Θ(N3)T(N) = \varTheta(N^3)(求解原理见递归式求解)。

我们似乎没有解决任何事情,但是让我们考虑一下它的I/O复杂性:

T(N)={O(N2B)NM (we only need to read it)8T(N/2)+O(N2B)otherwiseT(N)=\begin{cases} O(\frac{N^2}{B}) & N \leq\sqrt{M} \text{ (we only need to read it)} \\ 8\cdot T(N/2) + O(\frac{N^2}{B}) &\text{otherwise} \end{cases}

递归主要由O((NM)3)O((\frac{N}{\sqrt{M}})^3) 次基本情况(递归终止情况)主导,这意味着总的复杂度是:

T(N)=O((M)2B(NM)3)=O(N3BM)T(N) = O\bigg( \frac{(\sqrt{M})^2}{B} \cdot {\big( \frac{N}{\sqrt{M}} \big)}^3 \bigg) = O(\frac{N^3}{B\sqrt{M}})

这比仅仅是O(N3B)O(\frac{N^3}{B})要好的多。

Strassen 算法

像Karatsuba算法一样,矩阵乘法也可以分解为7个n2\frac{n}{2}大小的矩阵乘法,将问题等分并应用分治策略可以将时间复杂度降低到O(nlog27)O(n2.81)O(n^{log_2 7})≈O(n^{2.81})。外部存储模型的渐近复杂度也是类似的。

这种被称为Strassen算法的技术类似地将每个矩阵划分为4个子矩阵:

(C11C12C21C22)=(A11A12A21A22)(B11B12B21B22)\bigg(\begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{matrix}\bigg) = \bigg(\begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix}\bigg) \bigg(\begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{matrix}\bigg)

但后来它计算了N2×N2\frac{N}{2}×\frac{N}{2}的中间结果矩阵,并组合它们得到矩阵C:

M1=(A11+A22)(B11+B22)C11=M1+M4M5+M7M2=(A21+A22)B11C12=M3+M5M3=A11(B21B22)C21=M2+M4M4=A22(B21B11)C22=M1M2+M3+M6M5=(A11+A12)B22M6=(A21A11)(B11+B12)M7=(A12A22)(B21+B22)\begin{align*} M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \quad & C_{11} &= M_1+M_4-M_5+M_7 \\ M_2 &= (A_{21} + A_{22})B_{11} \quad & C_{12} &= M_3+M_5 \\ M_3 &= A_{11}(B_{21} - B_{22}) \quad & C_{21} &= M_2+M_4 \\ M_4 &= A_{22}(B_{21} - B_{11}) \quad & C_{22} &= M_1 - M_2 + M_3 + M_6 \\ M_5 &= (A_{11} + A_{12})B_{22}\\ M_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\ M_7 &= (A_{12} - A_{22})(B_{21} + B_{22}) \end{align*}

你可以通过简单的替换来验证这些公式的正确性。

据我所知,没有主流的线性代数优化库使用Strassen算法,尽管有一些原型实现对大于2000的矩阵是有效的。

实际上这种技术已经被多次扩展,通过考虑更多的子矩阵乘积来进一步降低渐近复杂性。截至2020年,目前的世界记录是O(n2.3728596)O(n^{2.3728596})。关于是否能实现以O(n2)O(n^2)或者至少O(n2logkn)O(n^2 log^k n)的复杂度来进行矩阵乘法,目前还没有答案。

拓展阅读

建议读者阅读Erik Demaine的论文《 Cache-Oblivious Algorithms and Data Structures 》以获取更扎实的理论观点。