二进制指数法

在模数运算和计算代数中,你经常需要把一个数增大到它的nn次方。这在进行模数除法,进行素数测试,或计算一些组合值时非常常见。而且你通常希望在Θ(n)次操作内完成计算。

二进制指数计算,也被称为通过平方进行指数运算,是一种用O(logn)O(\log n)次乘法计算nn次幂的方法。这种方法基于:

a2k=(ak)2a2k+1=(ak)2a\begin{align} a^{2k} &= (a^k)^2 \\ a^{2k+1} &= (a^k)^2 \cdot a \end{align}

为了计算ana^n,我们可以递归地计算an/2a^{⌊n/2⌋},然后对其平方,如果nn为奇数,则再乘上aa,对应如下递归计算:

an=f(a,n)={1,n=0f(a,n2)2,2  nf(a,n1)a,2na^n = f(a,n) = \begin{cases} 1,& n = 0\\ f(a,\frac{n}{2})^2, & 2\space |\space n \\ f(a,n-1) \cdot a, & 2 ∤ n \end{cases}

由于每进行2次递归,nn至少减少一半,所以递归深度和总的乘法次数将至多为O(logn)O(\log n)

递归实现

我们已经得到了一个递归关系,自然会将算法实现为一个case匹配的递归函数:

const int M = 1e9 + 7; // modulo
typedef unsigned long long u64;

u64 binpow(u64 a, u64 n) {
if (n == 0)
return 1;
if (n % 2 == 1)
return binpow(a, n - 1) * a % M;
else {
u64 b = binpow(a, n / 2);
return b * b % M;
}
}

在我们的基准测试中,我们设定 n=m2n=m-2,这样我们就能计算出模 mmaa乘法逆元

u64 inverse(u64 a) {
return binpow(a, M - 2);
}

我们使用m=109+7m=10^9+7,这是在编程竞赛中常用的模值,用于在组合问题中计算校验和 - 因为它是质数(允许通过二进制指数运算计算逆元),足够大,在加法中不会溢出int型,在乘法中不会溢出long long型,并且可以方便地以1e9 + 7的形式输入。

由于我们在代码中将其作为编译时常量使用,编译器可以通过替换为乘法来优化模运算(即使它不是编译时常量,手动计算一次魔数并用于快速reduction的代价仍然更低)。

执行路径,以及因此产生的运行时间,取决于nn的值。对于这个特定的nn,baseline实现每次调用大约花费330纳秒。由于递归引入了一些开销,因此将实现展开成迭代过程是有意义的。

迭代实现

ana^n的结果可以视为aa的2的kk次方幂的乘积,其中kk对应nn的二进制展开式中每一位(如果该位为1)的位置(从右向左以0开始计)。例如,如果n=42=32+8+2n = 42 = 32+8+2,则

a42=a32+8+2=a32a8a2a^{42} = a^{32 + 8 + 2} = a^{32} \cdot a^{8} \cdot a^{2}

为了计算这个乘积,我们可以遍历 nn 的每一位,同时维护两个变量:a2ka^{2^k} 的值和已遍历过 nn 的最低 kk 位后的当前乘积。在每一步,如果 nn 的第 kk 位为 1,我们就把当前乘积乘上 a2ka^{2^k},且每一步都要把 aka^k 取平方得到 a2k2=a2k+1a^{2^k⋅2}=a^{2^{k+1}},用于下一次的迭代。

u64 binpow(u64 a, u64 n) {
u64 r = 1;

while (n) {
if (n & 1)
r = res * a % M;
a = a * a % M;
n >>= 1;
}

return r;
}

迭代实现每次调用大约需要180纳秒。重度计算部分是一样的;提升主要来自于降低了依赖链:在循环能够继续之前,需要完成a = a * a % M的计算,现在它可以与r = res * a % M并行执行。

n是常量对于性能也是有益的,这使得所有分支可预测 ,让调度器知道需要提前执行什么。然而,编译器并没有利用这一点,没有展开while(n) n >>= 1循环。我们可以将它重写为执行恒定30次迭代的for循环:

u64 inverse(u64 a) {
u64 r = 1;

#pragma GCC unroll(30)
for (int l = 0; l < 30; l++) {
if ( (M - 2) >> l & 1 )
r = r * a % M;
a = a * a % M;
}

return r;
}

这迫使编译器只生成我们需要的指令,可以再减少10ns,使总运行时间达到170ns。

请注意,性能并不仅仅取决于nn的二进制长度,还取决于二进制中11的数量。如果nn2302^{30},由于我们不必进行任何非关键路径的乘法运算,所以将节省大约20纳秒的时间。