蒙哥马利乘法

不出所料,模运算中的大部分计算通常用于取模运算,其速度与一般的整数除法一样慢,通常需要15-20个周期,具体取决于操作数的大小。

解决这个问题的最好方法是完全避免使用取模运算,通过使用分支预测来推迟或代替它,例如在计算模之和时就可以这样做:

const int M = 1e9 + 7;

// input: array of n integers in the [0, M) range
// output: sum modulo M
int slow_sum(int *a, int n) {
int s = 0;
for (int i = 0; i < n; i++)
s = (s + a[i]) % M;
return s;
}

int fast_sum(int *a, int n) {
int s = 0;
for (int i = 0; i < n; i++) {
s += a[i]; // s < 2 * M
s = (s >= M ? s - M : s); // will be replaced with cmov
}
return s;
}

int faster_sum(int *a, int n) {
long long s = 0; // 64-bit integer to handle overflow
for (int i = 0; i < n; i++)
s += a[i]; // will be vectorized
return s % M;
}

然而,有时你需要进行一连串的模乘法(指对两个数进行乘法运算,然后取结果的模)运算,这时没有好的方法可以避免取模运算,除了使用一些整数除法技巧(这要求模值为常数)和一些预计算。

但是,这里有一种专为模运算设计的技术,称为蒙哥马利乘法(Montgomery multiplication)。

蒙哥马利空间

蒙哥马利乘法首先将乘数转换到蒙哥马利空间(Montgomery space),在这个空间中可以以较低的代价执行模乘法,然后当需要实际的值时,再将它们转换回去。与常规的整数除法方法不同,对于只进行一次模运算而言,蒙哥马利乘法并不高效,只有在执行一系列模操作时才值得使用。

这个空间是由模数nn和一个与nn互质的正整数rnr≥n定义的。算法涉及到对r的模运算和除法,所以在实践中,rr通常被选择为2322^{32}2642^{64},这样这些操作可以分别通过右移和位与操作来完成。

定义:一个数xx在蒙哥马利空间中的表示x\overline{x}被定义为

x=xrmodn\overline{x} = x \cdot r \mod n

这种转换的计算包括一次乘法和一次取模运算——这是一种我们最初想要优化的代价昂贵的操作——这就是为什么我们只在数字与蒙哥马利空间表示相互转换的额外开销值得时才使用这种方法,而不是用于一般的模乘法。

在蒙哥马利空间中,加法、减法和相等性检查与平常相同:

xr+yr(x+y)rmodnx \cdot r + y \cdot r ≡ (x + y) \cdot r \mod n

但是,乘法却并非如此。我们将蒙哥马利空间中的乘法表示为 * ,将普通的乘法表示 \cdot ,我们期望的结果是:

xy=xy=(xy)rmodn\overline{x} * \overline{y} = \overline{x\cdot y} = (x \cdot y) \cdot r \mod n

但是对于蒙哥马利空间中的普通乘法有:

xy=(xy)rrmodn\overline{x} \cdot \overline{y} = (x \cdot y) \cdot r \cdot r \mod n

因此,蒙哥马利空间中的乘法被定义为:

xy=xyr1modn\overline{x} * \overline{y} = \overline{x} \cdot \overline{y} \cdot r^{-1} \mod n

这意味着,在蒙哥马利空间中正常乘两个数后,我们需要通过乘以r1r^{−1}来“减小”结果并取模 —— 并且有一种有效的方式来执行这个特定的“减小”操作。

蒙哥马利模余法(Montgomery reduction)

假设r=232r = 2^{32},模数nn是32位的,我们需要“减小”的数xx是64位的(两个32位数的乘积),我们的目标是计算y=xr1modny = x \cdot r^{-1} \mod n

由于rrnn互质,我们知道在[0,n)[0,n)范围内有两个数字r1r^{-1}nn'满足:

rr1+nn=1r \cdot r^{-1} + n \cdot n' = 1

其中r1r^{-1}nn'都是可以计算的,使用拓展欧几里得算法

利用这一特性,我们可以将rr1r \cdot r^{-1}表示为(1nn)(1 - n \cdot n'),并将xr1x \cdot r^{-1}写为

xr1=xrr1/r=x(1nn)/r=(xxnn)/r(xxnn+krn)/r(modn)(for any integer k)(x(xnkr)n)/r(modn)\begin{align} x \cdot r^{-1} &= x \cdot r \cdot r^{-1} / r \\ &= x \cdot (1 - n \cdot n') / r \\ &= (x - x \cdot n \cdot n') / r \\ &≡ (x - x \cdot n \cdot n' + k \cdot r \cdot n) / r \quad (\mod n) \quad (for \space any \space integer \space k) \\ &≡ (x - (x \cdot n' - k \cdot r)\cdot n) /r \quad (\mod n ) \end{align}

现在,如果我们令kkxn/r⌊x⋅n′/r⌋(乘积xnx \cdot n'的高32位(原文中是64位)),就能进行约简,krxnk \cdot r - x \cdot n'就等于xnmodrx \cdot n' \mod rxnx \cdot n'的低32位),有:

xr1(xxnmodrn)/rx \cdot r^{-1} ≡ (x - x \cdot n' \mod r \cdot n)/r

算法本身就是在计算这个公式,执行两次乘法来计算$q=x⋅n’ \mod r $和 m=qnm=q⋅n,然后从xx中减去结果,然后通过右移执行除以rr操作。

唯一需要注意的是,结果可能不在[0,n)[0,n)范围内,但是由于

x<nn<rnx/r<nx < n \cdot n < r \cdot n \Longrightarrow x/r < n

m=qn<rnm/r<nm = q \cdot n < r * n \Longrightarrow m /r < n

这就能保证

n<(xm)/r<n-n < (x-m)/r < n

因此,我们可以简单地检查结果是否为负,若为负,则加上n,有以下算法:

typedef __uint32_t u32;
typedef __uint64_t u64;

const u32 n = 1e9 + 7, nr = inverse(n, 1ull << 32);

u32 reduce(u64 x) {
u32 q = u32(x) * nr; // q = x * n' mod r
u64 m = (u64) q * n; // m = q * n
u32 y = (x - m) >> 32; // y = (x - m) / r
return x < m ? y + n : y; // if y < 0, add n to make it be in the [0, n) range
}

最后一次的检查相对便宜,但仍然在关键路径上。如果我们可以接受结果在[0,2n2][0,2n-2]范围内,而不是[0,n)[0,n)范围内,我们可以移除这个检查,并无条件地将nn添加到结果中。

u32 reduce(u64 x) {
u32 q = u32(x) * nr;
u64 m = (u64) q * n;
u32 y = (x - m) >> 32;
return y + n
}

我们也可以将>>32操作在计算图中提前一步,计算x/rm/r⌊x/r⌋−⌊m/r⌋,而不是计算(xm)/r(x-m)/r。这样做是没问题的,因为xxmm的低32位在任何情况下都是相等的,因为有

m=xnnx(modr)m = x \cdot n' \cdot n ≡ x (\mod r)

为什么我们会主动选择进行两次右移,而不是只进行一次呢?这样做是有利的,因为对于((u64) q * n) >> 32,我们需要执行一个32位乘以32位的乘法,并取结果的上32位(x86的mul指令已经将其写入到一个单独的寄存器 ,所以这不会有任何额外的代价),而另一个右移 x >> 32 不在关键路径上。

u32 reduce(u64 x) {
u32 q = u32(x) * nr;
u32 m = ((u64) q * n) >> 32;
return (x >> 32) + n - m;
}

蒙哥马利乘法相比其他模余方法的主要优势之一是它不需要非常大的数据类型:它只需要一个r×rr × r乘法,该乘法提取结果的低rr位和高rr位,这在大多数硬件上都有特定的支持,也使得它容易推广到SIMD和更大的数据类型。

typedef __uint128_t u128;

u64 reduce(u128 x) const {
u64 q = u64(x) * nr;
u64 m = ((u128) q * n) >> 64;
return (x >> 64) + n - m;
}

请注意,一般的整数除法技巧无法实现128位对64位的模除运算:编译器会退化成调用一个慢速的长整数运算库函数来支持它。

更快的逆元变换

Montgomery乘法本身很快,但它需要一些预计算:

  • nnrr求逆元以计算nn'
  • 将一个数转换到Montgomery空间,
  • 将一个数从Montgomery空间转换出来。

上面实现的reduce方法已经可以有效地执行最后一步操作,但是前两步还可以稍微优化一下。

计算逆元

计算n=n1modrn'=n^{-1} \mod r,有比使用扩展欧几里得算法更快的方法,这是因为rr22的幂,可以利用以下的特性:

ax1mod2kax(2ax)1mod22ka \cdot x ≡ 1 \mod 2^k \Longrightarrow a \cdot x \cdot (2 - a \cdot x) ≡ 1 \mod 2^{2k}

证明如下:

ax(2ax)=2ax(ax)2=2(1+m2k)(1+m2k)2=2+2m2k12m2km222k=1m222k1mod22k\begin{align} a \cdot x \cdot (2 - a \cdot x) &= 2 \cdot a \cdot x - (a \cdot x)^2 \\ &=2 \cdot (1 + m \cdot 2^k)-(1 + m \cdot 2^k)^2 \\ &=2 + 2\cdot m \cdot 2^k - 1 - 2 \cdot m \cdot 2^k - m^2 \cdot 2^{2k} \\ &=1 - m^2 \cdot 2^{2k} \\ &≡ 1 \mod 2^{2k} \end{align}

我们一开始可以用x=1x=1作为amod21a \mod 2^1的逆(因为a1=a212=1a^{-1} = a^{2^1-2} = 1),然后应用上面这个等式log2rlog_2r次,每次都会将逆中的bit数翻倍 - 这有点类似于牛顿法

将一个数转换到蒙哥马利空间

可以通过将其乘以rr并进行取模运算来实现,但我们也可以利用下面这个等式:

x=xrmodn=xr2\overline{x} = x \cdot r \mod n = x * r^2

将一个数字转换到蒙哥马利空间只需要乘以r2r^2。因此,我们可以预先计算r2modnr^2 \mod n,然后执行乘法和模余操作,这样做速度可能会更快,也可能不会更快,因为将一个数字乘以r=2kr=2^k可以用左移位实现,而将一个数乘以r2modnr^2 \mod n则无法利用左移位。

计算实现

将所有内容封装到单个constexpr结构体:

struct Montgomery {
u32 n, nr;

constexpr Montgomery(u32 n) : n(n), nr(1) {
// log(2^32) = 5
for (int i = 0; i < 5; i++)
nr *= 2 - n * nr;
}

u32 reduce(u64 x) const {
u32 q = u32(x) * nr;
u32 m = ((u64) q * n) >> 32;
return (x >> 32) + n - m;
// returns a number in the [0, 2 * n - 2] range
// (add a "x < n ? x : x - n" type of check if you need a proper modulo)
}

u32 multiply(u32 x, u32 y) const {
return reduce((u64) x * y);
}

u32 transform(u32 x) const {
return (u64(x) << 32) % n;
// can also be implemented as multiply(x, r^2 mod n)
}
};

为了测试其性能,我们可以将蒙哥马利乘法插入到二进制幂运算中:

constexpr Montgomery space(M);

int inverse(int _a) {
u64 a = space.transform(_a);
u64 r = space.transform(1);

#pragma GCC unroll(30)
for (int l = 0; l < 30; l++) {
if ( (M - 2) >> l & 1 )
r = space.multiply(r, a);
a = space.multiply(a, a);
}

return space.reduce(r);
}

编译器生成的普通二进制幂运算,即使是使用快速模运算技巧,每次inverse也需要大约170纳秒,而这个实现只需要大约166纳秒,如果我们忽略transformreduce(一个合理的用例是用inverse作为更大的模运算中的子过程),这个时间可以降低到大约158纳秒。这是一个小的改进,但对于SIMD应用程序和更大的数据类型,蒙哥马利乘法变得更有优势。

练习题:实现高效的模矩阵乘法