前言

Strassen最早于1968年由Volker Srassen发表于论文《Gaussian Elimination is not Optimal》,将矩阵乘法的算法复杂度首次从O(n3),where log2(8)=3O(n^3),where\space log_2(8)=3降低到O(n2.807),where log2(7)2.807O(n^2.807),where\space log_2(7)≈2.807。后来有不同的方法基于此算法进行改进。

MNN Strassenconv1x1采用的是2008年发表的这边文章中的算法《Memory efficient scheduling of Strassen-Winograd’s matrix multiplication algorithm》,主要针对memory访问上进行优化。

原理

普通矩阵乘

C=ABC=AB为例,假设AABB均为n×nn × n矩阵,则CC中每个元素需要通过以下方式计算:

cij=k=1naikbkjc_{ij}=\displaystyle\sum_{k=1}^na_{ik}b_{kj}

即每个元素计算需要nn次乘法,n1n-1次加法,C共有n2n^2个元素,故一共需要n3n^3次乘法和n2(n1)n^2(n-1)次加法,故算法复杂度为O(n3)O(n^3)

分治法

Strassen算法基于分治的思想,因此我们首先考虑一个简单的分治策略。

假设矩阵AA和矩阵BB都是n×n,(n=2k)n×n,(n=2^k)的方阵,求C=ABC=AB,则每个n×nn×n的矩阵都可以分割为四个(n/2)×(n/2)(n/2) ×( n/2)的矩阵:

A=[A11A12A21A22] B=[B11B12B21B22] C=[C11C12C21C22]A=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \space B=\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix} \space C=\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}

于是C=ABC=AB可以改写为:

[C11C12C21C22]=[A11A12A21A22][B11B12B21B22]\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}=\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}

展开有:

C11=A11B11+A12B21C12=A11B12+A12B22C21=A21B11+A22B21C22=A21B12+A22B22C_{11} = A_{11}B_{11}+A_{12}B_{21} \\ C_{12} = A_{11}B_{12}+A_{12}B_{22} \\ C_{21} = A_{21}B_{11}+A_{22}B_{21} \\ C_{22} = A_{21}B_{12}+A_{22}B_{22}

每个等式需要两次矩阵乘法和一次矩阵加法。若用T(n)T(n)表示两个n×nn×n矩阵之间的乘法,则有如下递归式:

T(n)=8T(n2)+Θ(n2)T(n)=8T(\frac{n}{2})+\varTheta(n^2)

其中:

    1. 8T(n2)8T(\frac{n}{2})表示8次子矩阵乘法,子矩阵规模为(n/2)×(n/2)(n/2) ×( n/2)
    1. Θ(n2)\varTheta(n^2)表示4次子矩阵加法以及合并CC矩阵的时间复杂度。

根据递归式求解,得T(n)=Θ(nlog28)=Θ(n3)T(n)=\varTheta(n^{log_2^8})=\varTheta(n^3),与普通矩阵乘的时间复杂度相同。分治法并不能起到加速的效果。

原始版本Strassen

分治法包含了8次矩阵相乘和4次矩阵相加,相比矩阵加法,矩阵乘法是非常慢的。这8次矩阵相乘正是瓶颈的来源。于是我们想到能不能减少矩阵相乘的次数,哪怕代价是更多的矩阵相加。Strassen正是利用了这一点。

    1. 同样还是每个矩阵分成4份,然后创建如下10个中间矩阵(10次矩阵加法10×n2×n210×\frac{n}{2}×\frac{n}{2},时间复杂度为Θ(n2)\varTheta(n^2)):

S1=B12B22S2=A11+A12S3=A21+A22S4=B21B11S5=A11+A22S6=B11+B22S7=A12A22S8=B21+B22S9=A11A21S10=B11+B12S_1=B_{12}-B_{22} \\ S_2=A_{11}+A_{12} \\ S_3=A_{21}+A_{22} \\ S_4=B_{21}-B_{11} \\ S_5=A_{11}+A_{22} \\ S_6=B_{11}+B_{22} \\ S_7=A_{12}-A_{22} \\ S_8=B_{21}+B_{22} \\ S_9=A_{11}-A_{21} \\ S_{10}=B_{11}+B_{12}

    1. 而后是执行7次矩阵乘法(时间复杂度7T(n2)=Θ(nlog27)=Θ(n2.807)7T(\frac{n}{2})=\varTheta(n^{log_2^7})=\varTheta(n^{2.807})):

P1=A11S1P2=S2B22P3=S3B11P4=A22S4p5=S5S6P6=S7S8P7=S9S10P_1=A_{11}S_1 \\ P_2=S_2B_{22} \\ P_3=S_3B_{11} \\ P_4=A_{22}S_4 \\ p_5=S_5S_6 \\ P_6=S_7S_8 \\ P_7=S_9S_{10}

    1. 最后通过这7个矩阵计算得到C矩阵(8次矩阵加法8×n2×n28×\frac{n}{2}×\frac{n}{2},时间复杂度Θ(n2)\varTheta(n^2)):

C11=P5+P4P2+P6C12=P1+P2C21=P3+P4C22=P5+P1P3P7C_{11}=P_5+P_4-P_2+P_6 \\ C_{12}=P_1+P_2 \\ C_{21}=P_3+P_4 \\ C_{22}=P_5+P_1-P_3-P_7

综合可得如下递归式:

T(n)={Θ(1)if n=17T(n2)Θ(n2)if n>1T(n)=\begin{cases} \varTheta(1) &\text{if } n=1 \\ 7T(\frac{n}{2})\varTheta(n^2) &\text{if } n>1 \end{cases}

进而总的时间复杂度为T(n)=Θ(nlog27)=Θ(n2.807)T(n)=\varTheta(n^{log_2^7})=\varTheta(n^{2.807})

改进版本Strassen

    1. 8次矩阵加法:

S1=A21+A22T1=B12B11S2=S1A11T2=B22T1S3=A11A21T3=B22B12S4=A12S2T4=T2B21\begin{align*} & S_1=A_{21}+A_{22} \quad T_1=B_{12}-B_{11} \\ & S_2=S_1-A_{11} \quad\enspace T_2=B_{22}-T_1 \\ & S_3=A_{11}-A_{21} \quad T_3=B_{22}-B_{12} \\ & S_4=A_{12}-S_2 \quad\enspace T_4=T_2-B_{21} \end{align*}

    1. 7次矩阵乘法:

P1=A11B11P5=S1T1P2=A12B21P6=S2T2P3=S4B22P7=S3T3P4=A22T4\begin{align*} & P_1=A_{11}B_{11} \quad & P_5=S_1T_1 \\ & P_2=A_{12}B_{21} \quad & P_6=S_2T_2 \\ & P_3=S_4B_{22} \quad & P_7=S_3T_3 \\ & P_4=A_{22}T_4 \end{align*}

    1. 7次矩阵加法:

U1=P1+P2U5=U4+P3U2=P1+P6U6=U3P4U3=U2+P7U7=U3+P5U4=U2+P5\begin{align*} &U_1=P_1+P_2 \quad U_5=U_4+P_3 \\ &U_2=P_1+P_6 \quad U_6=U_3-P_4 \\ &U_3=U_2+P_7 \quad U_7=U_3+P_5 \\ &U_4=U_2+P_5 \end{align*}

    1. 最终结果:

C=[U1U5U6U7]C=\begin{bmatrix} U_1 & U_5 \\ U_6 & U_7 \end{bmatrix}

改进版本的Strassen相比原始版本矩阵加法次数由18次降低到15次。同时该算法同时分析各中间矩阵块之间的依赖关系,充分考虑到了计算中的内存复用,每次递归只需要动态分配两块临时内存X1X_1X2X_2即可,C11C_{11}C12C_{12}C21C_{21}C22C_{22}都可以在计算中复用。

计算流程图如下:

1

计算流程表如下:

2

总结

Strassen优化原理是减少矩阵乘法,额外增加的是更多矩阵加法,矩阵乘法的时间复杂度O(n3)O(n^3),而矩阵加法的复杂度是O(n2)O(n^2)。总体上乘加运算次数是减少的,来达到加速目的。

下图展示了普通矩阵乘法和Strassen算法的性能差异,nn越大,Strassen算法节约的时间越多。

3

实际使用中,考虑到每次递归计算分块矩阵的访存问题,会与普通矩阵乘法的做一个cost比较,满足一定条件下切换到普通矩阵乘法,终止递归。

参考

致谢

文章主体框架参考自东哥的MNN源码解读的内部分享,加上了自己的一些看法。有幸被看到的话,希望能给点个赞~~