写在前面

大数运算这个实验本来是在 AES 之前就发布了的,结果这篇博客却是在 AES 之后才写出来,这自然是因为我没有做出来

显然这个实验不可能只用一般的模拟手工运算的方法就能通过所有样例,需要自己去学习实现更高效的算法。我亲眼见证助教把时限从 30000ms 放宽到 60000ms,然后又把最后一个样例删掉了。

尽管如此放宽,但我实现的更优算法的代码根本得不到正确的结果,花费了大量时间进行调试,中间一度放弃。终于在今天(2024 年 11 月 29 日)下午,在删掉一个下划线,把一个变量变成另一个之后,我的代码终于跑通了!怀着激动的心情,写下这篇博客。

实验描述

完整的实验描述在实验平台上,完整代码也已经放到Github上了。

本次的实验的前置知识很简单,只要知道模和Zp\mathbb{Z}_p的概念就行了。

处理输入

样例输入给出的大数仍然是以十进制给出的,这对于我们人看非常友好,但是对于计算机来说可就不友好了。在Zp\mathbb{Z}_p下,由于pp是不高于 2048bit 的,所以a,b<22048a, b<2^{2048}。没有任何一种基本类型能直接存储这么大的数,只能够当做字符串读入。

进一步,我们人类计算时,一般习惯从低位到高位。而大数作为字符串一读进来,就是从高位到低位的。所以我们还需要把字符串反转一下,这样才能方便我们进行运算。

对于我的第一版代码就只考虑到了这里,就开始实现了。也即整个大数,当成一个字符串来处理。但是这种存储形式有非常大的局限性,无法用一些更高效的算法。我也考虑过以(十的幂次)为基的压位高精度。经过一系列搜索之后,为了与后续选用的算法配合,我最终决定将大数转化为二进制,然后以 32 位为一个进行划分,将整个大数存储在一个uint32_t数组中。

其实实验描述中也有提示,pp 的范围是以二进制下的位数给出的,这引导我们考虑以二进制为基的存储方式。所以,对于本实验的数据,一个大数最多用2048/32=642048/32=64uint32_t就能存下了,为了便于存储运算的中间结果,最终我们统一采用uint32_t[128]来存储一个大数。

所以主函数大概就是下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
int main(void)
{
int n;
cin >> n >> p;
reverse(p.begin(), p.end());
str2bi(P, p);
P_bits = getBits(P);

pre_cal();

while (n--)
{
string a, b;
cin >> a >> b;
reverse(a.begin(), a.end());
reverse(b.begin(), b.end());

uint32_t A[128] = {0}, B[128] = {0};
str2bi(A, a);
str2bi(B, b);

uint32_t res1[128] = {0};
mod_add(res1, A, B);
cout << bi2str(res1) << '\n';

uint32_t res2[128] = {0};
mod_sub(res2, A, B);
cout << bi2str(res2) << '\n';

uint32_t res3[128] = {0};
mod_mul(res3, A, B);
cout << bi2str(res3) << '\n';

uint32_t x[128] = {0}, y[128] = {0};
inv_exculid(A, P, x, y);
cout << bi2str(x) << '\n';

uint32_t res5[128] = {0};
mod_pow(res5, A, B);
cout << bi2str(res5) << '\n';

if (n)
cout << '\n';
}
}

以字符串形式读入大数之后,先反转,然后用str2bi()这个函数转为二进制:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
uint32_t pow2[32] = {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
16384, 32768, 65536, 131072, 262144, 524288,1048576, 2097152,
4194304, 8388608, 16777216, 33554432,67108864, 134217728,
268435456, 536870912,1073741824,2147483648};

void str2bi(uint32_t res[128], string &s)
{
int count = 0;
while (s != "0")
{
if ((s[0] - '0') & 1)
res[count / 32] += pow2[count % 32];
str_div2(s);
count++;
}
}

其逻辑是,每次看最低位是奇数还是偶数,如果是奇数,就在对应的“字”上累加上相应的值,然后将整个大数除以 2。这里对字符串形式的大数做除法运算,我单独写了一个函数str_div2()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void remove_leading_zeros(string &s)
{
size_t end = s.find_last_not_of('0');
if (end != string::npos)
s.erase(end + 1);
else
s = "0";
}

void str_div2(string &s)
{
int carry = 0;
for (int i = s.length() - 1; i >= 0; i--)
{
int x = s[i] - '0';
s[i] = (x + carry * 10) / 2 + '0';
carry = x % 2;
}

remove_leading_zeros(s);
}

由于除法是会产生前导零的,所以我们需要一个函数来去掉这些前导零。记得如果结果就是 0,要保留这个零。

最后输出是十进制的,所以还需要一个bi2str()函数来转换回去,以及相应的辅助函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
void str_mul2(string &s)
{
int carry = 0;
for (std::string::size_type i = 0; i < s.length(); i++)
{
int x = s[i] - '0';
s[i] = (x * 2 + carry) % 10 + '0';
carry = (x * 2 + carry) / 10;
}

if (carry)
s = s + "1";
}

void str_add1(string &s)
{
int carry = 1;
for (std::string::size_type i = 0; i < s.length(); i++)
{
int x = s[i] - '0';
s[i] = (x + carry) % 10 + '0';
carry = (x + carry) / 10;
}
if (carry)
s = s + "1";
}

string bi2str(uint32_t res[128])
{
string s = "0";
for (int i = 127; i >= 0; i--)
for (int j = 31; j >= 0; j--)
{
str_mul2(s);
if (res[i] & POW2[j])
str_add1(s);
}

reverse(s.begin(), s.end());
return s;
}

把获得大数位数的函数也一并介绍了:

1
2
3
4
5
6
7
8
9
10
11
12
int getBits(const uint32_t a[128])
{
for (int i = 127; i >= 0; i--)
{
if (a[i])
for (int j = 31; j >= 0; j--)
if (a[i] & POW2[j])
return i * 32 + j;
}

return 0;
}

这里返回的位数比实际的位数少 1,是为了方便直接计算出下标。

普通加法

压位高精度的思想,逐字相加即可。

这里一定要强调的是,写函数的时候要想好输入的数据范围,记住你的代码干了什么,对怎样的输入会有怎样的输出

写这种有很多函数的代码的时候,应当遵循单测的思想,也即一个个的函数测试,而不是一次性写完所有函数再测试。要保证底层的函数都是“正确”的,然后再去写上层的函数。

这里的正确打引号是因为,函数总是有局限的,不可能覆盖所有的情况。有时候在泛用性和健壮性之间会做出一些取舍。

例如,我这里的普通加法add(),因为循环到i=127,所以理论上支持两个uint32_t[128]相加,但是最后的进位会被舍去。所以,如果后面有用到这个函数的地方,要考虑到这一点,可能这就是引起 bug 的原因。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void add(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint64_t carry = 0;
uint32_t temp[128] = {0};
for (int i = 0; i < 128; i++)
{
uint64_t sum = carry + a[i] + b[i];
temp[i] = sum & 0xffffffff;
carry = sum >> 32;
}

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

普通减法

这里的实现的普通减法,必须满足a>ba>b,支持两个uint32_t[128]相减。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
const uint64_t BASE = 0x100000000;

void sub(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint64_t borrow = 0;
uint32_t temp[128] = {0};
for (int i = 0; i < 128; i++)
{
uint64_t temp1 = b[i] + borrow;
if (a[i] < temp1)
{
temp[i] = BASE + a[i] - temp1;
borrow = 1;
}
else
{
temp[i] = a[i] - temp1;
borrow = 0;
}
}

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

这里的BASE是我们的基,也即2322^{32}

普通乘法

这里的普通乘法,最多支持两个uint32_t[64]相乘,结果可以存储在uint32_t[128]中。

让我调试了很久的 BUG 之一,就是用这个函数去计算结果可能超过uint32_t[128]的乘法。溢出的部分直接抛弃了,结果自然不对,被自己蠢哭。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void mul(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t temp[256] = {0};

for (int i = 0; i < 128; i++)
{
uint64_t carry = 0;
for (int j = 0; j < 128; j++)
{
uint64_t sum = (uint64_t)a[i] * b[j] + temp[i + j] + carry;
temp[i + j] = sum & 0xffffffff;
carry = sum >> 32;
}
if (carry)
temp[i + 128] += carry;
}

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

普通除法

求逆元的时候,我们需要用到扩展欧几里得算法;而拓展欧几里得算法的基础之一是取模;取模运算的一种做法就是直接“求余数”,自然要用到普通除法。

也许有针对于大模数的快速取模算法,但我没有去搜索,最终选择实现一个除法来支持后面的取模等操作。

模拟手工除法自然是可行的,个人认为写起来还是有点麻烦了。经过一番搜索,找到了这篇文章RSA 大数运算实现 Knuth 除法(这位博主的 RSA 大数运算实现这一系列的文章都给了我很多帮助),其中提到了Knuth 除法

Knuth 除法的基本思想是估商法,其基本思想是不直接地去做除法,而是去计算一个容易计算,但不一定精确的估计值,然后再去修正这个估计值。这种思想我们还会遇到。


接下来正式介绍 Knuth 除法。

首先考虑 b 进制下,被除数和除数分别为(unun1u0)b(u_n u_{n-1}\cdots u_0)_b(vn1vn2v0)(v_{n-1}v_{n-2}\cdots v_0),且u/v<bu/v<b

因为u/v<bu/v<b等价于u/b<vu/b<v,等价于(unun1u1)b<(vn1vn2v0)b(u_n u_{n-1} \cdots u_1)_b<(v_{n-1}v_{n-2}\cdots v_0)_b。由此我们知道,在做竖式除法的时候,第一次要上 0。第二次上数,也即最后一次上的数,就是我们想要的商。我们可以这样估计这个数:

q^=unb+un1vn1\hat{q} = \left\lfloor\frac{u_n b+u_{n-1}}{v_{n-1}}\right\rfloor

也即根据uu的前两位和vv的前一位,估计出商的前一位。又因为商只有一位数,所以q^b1\hat{q}\leqslant b-1,故我们有:

q^=min{unb+un1vn1,b1}\hat{q}=\min\left\{\left\lfloor\frac{u_n b+u_{n-1}}{v_{n-1}}\right\rfloor, b-1\right\}

下面我们考虑这个估计值有多接近真实值。首先证明q^>q\hat{q}>q,其中qq是真实的商。

因为qb1q\leqslant b-1,若q^=b1\hat{q}=b-1,则原命题显然成立。否则我们有q^=unb+un1vn1\hat{q}=\left\lfloor\frac{u_n b+u_{n-1}}{v_{n-1}}\right\rfloor,进一步有:

q^+1=unb+un1+vn1vn1unb+un1+kvn1=unb+un1+kvn1unb+un1+1vn1\hat{q}+1=\left\lfloor\frac{u_n b+u_{n-1}+v_{n-1}}{v_{n-1}}\right\rfloor\geqslant \left\lfloor\frac{u_n b+u_{n-1}+k}{v_{n-1}}\right\rfloor=\frac{u_n b+u_{n-1}+k}{v_{n-1}}\geqslant \frac{u_n b+u_{n-1}+1}{v_{n-1}}

这等价于q^vn1unb+un1vn1+1\hat{q}v_{n-1}\geqslant u_n b+u_{n-1}-v_{n-1}+1

又因为uq^vuq^vn1bn1unbn++u0(unbn+un1bn1vn1bn1+bn1)=un2bn2++u0bn1+vn1bn1<vn1bn1vu-\hat{q}v\leqslant u-\hat{q}v_{n-1}b^{n-1}\leqslant u_nb^n+\cdots+u_0-(u_nb^n+u_{n-1}b^{n-1}-v_{n-1}b^{n-1}+b^{n-1})=u_{n-2}b^{n-2}+\cdots+u_0-b^{n-1}+v_{n-1}b^{n-1}<v_{n-1}b^{n-1}\leqslant v

所以uq^v<vu-\hat{q}v<v。如果q^<q\hat{q}<q,则q^q1\hat{q}\leqslant q-1,则uq^vu(q1)v=uqv+vvu-\hat{q}v\geqslant u-(q-1)v=u-qv+v\leqslant v,矛盾。所以必有q^q\hat{q}\geqslant q


继续,我们假设q^q+3\hat{q}\leqslant q+3。因为:

qunb+un1vn1=unbn+un1bn1vn1bn1uvn1bn1<uvbn1q\leqslant \frac{u_nb+u_{n-1}}{v_{n-1}}=\frac{u_nb^n+u_{n-1}b^{n-1}}{v_{n-1}b^{n-1}}\leqslant \frac{u}{v_{n-1}b^{n-1}}<\frac{u}{v-b^{n-1}}

这里vbn10v-b^{n-1}\neq 0,否则v=bn1v=b^{n-1},此时q^=q=b1\hat{q}=q=b-1,矛盾。

又因为(q+1)v>u(q+1)v>u,也即q>u/v1q>u/v-1,所以有:

3q^q<uvbn1(uv+1)=uv(bn1vbn1)+13\leqslant \hat{q}-q < \frac{u}{v-b^{n-1}} - (\frac u v+1)=\frac u v\left( \frac{b^{n-1}}{v-b^{n-1}} \right)+1

所以:

uv>2(vbn1bn1)2(vn11)\frac u v>2\left( \frac{v-b^{n-1}}{b^{n-1}} \right)\geqslant 2(v_{n-1}-1)

b1q^1b-1\geqslant \hat{q}-1,所以b4q^3q=u/v2(vn11)b-4\geqslant \hat{q}-3\geqslant q=\lfloor u/v\rfloor\geqslant 2(v_{n-1}-1)

所以b2(vn11)+4=2vn1+2b\geqslant 2(v_{n-1}-1)+4=2v_{n-1}+2,最后可得vn1<b/2v_{n-1}<\lfloor b/2\rfloor

于是当vn1b/2v_{n-1}\geqslant \lfloor b/2\rfloor时,有q^q+2\hat{q}\leqslant q+2,结合上面的结论,可知估计出的解满足q^2qq^\hat{q}-2\leqslant q\leqslant \hat{q}注:这里我没有详细推导,为什么可以直接反着用,有待补充进一步的证明


简而言之,如果我们能使除数的vn1v_{n-1}满足vn1b/2v_{n-1}\geqslant \lfloor b/2\rfloor,就可以只通过非常简单的计算,得到一个相对精确的商。

使vn1b/2v_{n-1}\geqslant \lfloor b/2\rfloor的过程称为规格化,可以通过让被除数和除数同时乘以bvn1+1\lfloor\frac{b}{v_{n-1}+1}\rfloor做到。这是因为:

vn1bvn1+1vn1b2b2v_{n-1}\lfloor\frac{b}{v_{n-1}+1}\rfloor \geqslant v_{n-1}\lfloor\frac{b}{2}\rfloor\geqslant \lfloor \frac b 2\rfloor

而且:

vn1bvn1+1<vn1bvn1+1=bbvn1+1<bv_{n-1}\lfloor\frac{b}{v_{n-1}+1}\rfloor < v_{n-1}\frac{b}{v_{n-1}+1} = b-\frac{b}{v_{n-1}+1} < b

除数也不会进位。


上面的结论与bb的大小无关,我们可以直接取b=232b=2^{32},于是可以写出下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
void div(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t temp[128] = {0};
int n = getBits(b) / 32;
int m = getBits(a) / 32 - n; // iterate times

uint32_t d[128] = {0};
d[0] = BASE / (b[n] + (uint64_t)1);
uint32_t u_[129] = {0}, v_[128] = {0};
uint64_t carry = 0;
// scale a
for (int i = 0; i < 128; i++)
{
uint64_t temp = (uint64_t)a[i] * d[0] + carry;
u_[i] = temp & 0xffffffff;
carry = temp >> 32;
}
if (carry)
u_[128] = carry;
// scale b
mul(v_, b, d);

int j = m;
while (j >= 0)
{
uint32_t tem[129] = {0};
for (int i = 0; i <= n + 1; i++)
tem[i] = u_[i + j];
uint64_t tem2 = (tem[n + 1] * BASE + tem[n]) / v_[n];
tem2 = min(tem2, BASE - 1);

uint32_t q_hat = static_cast<uint32_t>(tem2 & 0xffffffff);
uint32_t qv[128] = {0};
for (int i = 0; i < 128; i++)
{
uint64_t temp = (uint64_t)v_[i] * q_hat + carry;
qv[i] = temp & 0xffffffff;
carry = temp >> 32;
}

while (bigger(qv, tem))
{
q_hat--;
sub(qv, qv, v_);
}

sub(tem, tem, qv);
for (int i = 0; i <= n + 1; i++)
u_[i + j] = tem[i];

temp[j] = q_hat;

j--;
}

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

主体过程和手工除法一样,只是“上数”的时候进行了估商。估商相比原来每次上数都要做一次uint32_t[64]的乘法,代价下降了很多。

同时要注意,只说了除数不会进位,没说被除数不会进位。所以规格化被除数的时候不能直接调用mul()

模、模加、模减、模乘、模幂

有了除法,就有了模:

1
2
3
4
5
6
7
void mod(uint32_t res[128], uint32_t a[128])
{
uint32_t temp[128] = {0};
div(temp, a, P);
mul(temp, temp, P);
sub(res, a, temp);
}

于是天堑变坦途:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
void mod_add(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
add(res, a, b);
mod(res, res);
}

void mod_sub(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
if (bigger(a, b))
{
sub(res, a, b);
mod(res, res);
}
else
{
sub(res, b, a);
mod(res, res);
sub(res, P, res);
}
}

void mod_mul_trivial(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
mul(res, a, b);
mod(res, res);
}

void mod_div(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
div(res, a, b);
mod(res, res);
}

uint32_t ONE[128] = {1};
void mod_pow_trivial(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t temp[128] = {1};
uint32_t base[128] = {0};
for (int i = 0; i < 128; i++)
base[i] = a[i];

for (int i = 0; i < 128; i++)
{
if (bigger(temp, ONE) && b[i] == 0)
break;
for (int j = 0; j < 32; j++)
{
if (b[i] & POW2[j])
mod_mul(temp, temp, base);

mod_mul(base, base, base);
}
}
for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

这里mod_pow()仍然用了快速幂的思想。模乘和模幂之所以带上了trivial,当然是因为后面有更好的方法。

求逆

有了之前的铺垫,求逆也不难。因为模总是素数,所以可以利用费马小定理:

1
2
3
4
5
6
7
8
uint32_t TWO[128] = {2};

void inv_fermat(uint32_t res[128], uint32_t a[128])
{
uint32_t temp[128];
sub(temp, P, TWO);
mod_pow(res, a, temp);
}

也可以用拓展欧几里得法。原理不加以介绍,代码参考了OI Wiki 乘法逆元,短小精悍:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
void inv_exculid(uint32_t a[128], uint32_t b[128], uint32_t x[128], uint32_t y[128])
{
if (equal(b, ZERO))
{
x[0] = 1;
y[0] = 0;
return;
}

uint32_t temp[128] = {0};
div(temp, a, b);
mul(temp, b, temp);
sub(temp, a, temp);

inv_exculid(b, temp, y, x);

uint32_t temp2[128] = {0};
div(temp2, a, b);
mul(temp2, temp2, x);

if (bigger(y, temp2))
{
sub(temp2, y, temp2);
mod(temp2, temp2);
}
else
{
sub(temp2, temp2, y);
mod(temp2, temp2);
sub(temp2, P, temp2);
}

for (int i = 0; i < 128; i++)
y[i] = temp2[i];
}

最后x就是逆元。


查资料时,看到RSA 大数运算实现 快速求逆元还提到了一种算法,称比扩展欧几里得快一倍。这里也顺便介绍一下。

如果要求a1modpa^{-1}\mod p,考虑同余方程xyamodpx\equiv ya\mod p。这个方程的两个平凡解是(a,1)(a, 1)(p,0)(p, 0)

显然有:

  • 如果(x1,y1)(x_1, y_1)(x2,y2)(x_2, y_2)是方程的解,则(x1x2,y1y2)(x_1-x_2, y_1-y_2)也是方程的解。
  • 如果(2x1,y1)(2x_1, y_1)是方程的解,则(x1,y1/2)(x_1, y_1/2)也是方程的解。
  • 如果(1,y)(1, y)是方程的解,则ya1modpy\equiv a^{-1} \mod p

于是我们可以迭代地去求解:

  • 一开始,我们有两个平凡解(x1,y1)(x_1, y_1)(x2,y2)(x_2, y_2)
  • 对于任意一组解(xi,yi)(x_i, y_i),如果xi,yix_i, y_i均为偶数,则令(xi,yi)(xi/2,yi/2)(x_i, y_i)\leftarrow(x_i/2, y_i/2);如果只有xix_i为偶数,由于模pp是奇素数(p=2p=2的情况太简单,不考虑),可以令(xi,yi)(xi/2,(pyi)/2)(x_i, y_i)\leftarrow(x_i/2, (p-y_i)/2)
  • x1>x2x_1>x_2,令(x1,y1)(x1x2,y1y2)(x_1, y_1)\leftarrow(x_1-x_2, y_1-y_2),反之亦然
  • x1=1x_1=1或者x2=1x_2=1时,算法结束;否则回到第二步

显然,x1x_1x2x_2的值会快速下降,最终都会到 1,算法是可行的。可以写出代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
void inv_exculid2(uint32_t res[128], uint32_t a[128])
{
uint32_t x1[128] = {0}, y1[128] = {1}, x2[128] = {0}, y2[128] = {0};
for (int i = 0; i < 128; i++)
x1[i] = a[i];
for (int i = 0; i < 128; i++)
x2[i] = P[i];

while (!(equal(x1, ONE) || equal(x2, ONE)))
{
if ((x1[0] & 1) == 0 && (y1[0] & 1) == 0)
{
div(x1, x1, TWO);
div(y1, y1, TWO);
}
else if ((x1[0] & 1) == 0 && (y1[0] & 1) == 1)
{
div(x1, x1, TWO);
add(y1, y1, P);
div(y1, y1, TWO);
}

if ((x2[0] & 1) == 0 && (y2[0] & 1) == 0)
{
div(x2, x2, TWO);
div(y2, y2, TWO);
}
else if ((x2[0] & 1) == 0 && (y2[0] & 1) == 1)
{
div(x2, x2, TWO);
add(y2, y2, P);
div(y2, y2, TWO);
}

if (bigger(x1, x2))
{
mod_sub(x1, x1, x2);
mod_sub(y1, y1, y2);
}
else
{
mod_sub(x2, x2, x1);
mod_sub(y2, y2, y1);
}
}

if (equal(x1, ONE))
for (int i = 0; i < 128; i++)
res[i] = y1[i];
else
for (int i = 0; i < 128; i++)
res[i] = y2[i];
}

由于我并没有实现移位,反倒是调用了这么多除法(33000ms)导致比原来(21000ms)还慢。最终还是用的拓展欧几里得法。

蒙哥马利算法

蒙哥马利算法的基本思想是,把要运算的数转化为蒙哥马利形式再进行运算,最后要结果的时候,再转换回去。由于蒙哥马利形式下进行的蒙哥马利约简比较高效,所以如果要进行少量数之间要进行大量的模乘运算,会取得很好的效果。

这一部分参考的文章包括RSA 大数运算实现 蒙哥马利模幂Montgomery 模乘。但是看了那么多之后,我建议直接看这篇论文Analyzing and Comparing Montgomery Multiplication Algorithms,这篇论文非常详细严谨,看这篇就够了。


首先介绍蒙哥马利形式。我们引入一个与pp互质的数RR,对于任意的aa,我们把aRmodpaR \mod p称为其蒙哥马利形式。

这里的RRpp互素就行了。一般,我们会选取RR为 2 的幂次,这样关于RR的除法和模运算就可以通过简单的位运算来实现。进一步,我们取RR为恰好比pp大的 2 的幂次(后面证明会用到)。


蒙哥马利约简是一个二元运算,我们记为:

REDC(a,b)=abR1modp\text{REDC}(a, b)=abR^{-1} \mod p

其中R1R^{-1}RR在模pp意义下的逆元。

REDC 之后,原来的模乘可以通过四次 REDC 完成:

1
2
3
4
5
6
7
8
9
10
11
12
void mod_mul_mont(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t x_[128] = {0}, y_[128] = {0};

REDC(x_, a, R2);
REDC(y_, b, R2);
REDC(x_, x_, y_);
REDC(x_, x_, ONE);

for (int i = 0; i < 128; i++)
res[i] = x_[i];
}

不过更适合它发挥的地方是模运算。因为只涉及base的自乘和与当前数temp相乘。只要一开始转换到蒙哥马利形式下,中间的乘法可以都可以换成 REDC:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
void mod_pow_mont(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t temp[128] = {0};
uint32_t base[128] = {0};
REDC(temp, ONE, R2);
REDC(base, a, R2);

for (int i = 0; i < 128; i++)
{
if (bigger(temp, ONE) && b[i] == 0)
break;
for (int j = 0; j < 32; j++)
{
if (b[i] & POW2[j])
REDC(temp, temp, base);

REDC(base, base, base);
}
}

REDC(res, temp, ONE);
}

最后来看一下 REDC 怎么实现。如果要计算t=REDC(a, b)t=\text{REDC(a, b)}

  • 预先计算(p)1pmodR(-p)^{-1}\equiv p' \mod R
  • 计算T=abT=ab
  • 计算m=((TmodR)p)modRm=((T\mod R)p')\mod R
  • 计算t=(T+mp)/Rt=(T+mp)/R
  • 如果tpt\geqslant p,则返回tpt-p;否则返回tt

然后我们来证明t=abR1modpt=abR^{-1}\mod p。首先证明tt是整数,也即证明RR整除t+mpt+mp。这是因为:

T+mpT+((TmodR)p)pT+TppT(1+pp)0modRT+mp\equiv T+((T\mod R)p')p\equiv T+Tp'p\equiv T(1+p'p)\equiv 0\mod R

然后有:

t=(T+mp)/R(ab+mN)/R((p1)(p1)+(R1)p)/R<(Rp+Rp)/R=2pt=(T+mp)/R\equiv(ab+mN)/R\leqslant((p-1)(p-1)+(R-1)p)/R<(Rp+Rp)/R=2p

所以至多只要再做一次减法,就能得到正确的结果。

可以发现 REDC 中耗时较长的除法和取模都是关于 R 的,这可以通过简单的位运算实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
void mod_R(uint32_t res[128], uint32_t a[128])
{
uint32_t temp[128] = {0};
for (int i = 0; i < R_bits / 32; i++)
temp[i] = a[i];
for (int i = 0; i < R_bits % 32; i++)
if (a[R_bits / 32] & POW2[i])
temp[R_bits / 32] += POW2[i];

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

void div_R(uint32_t res[128], uint32_t a[129])
{
uint32_t temp[128] = {0};

int word_shift = R_bits / 32;
int bits_shift = R_bits % 32;
for (int i = 0; i + word_shift < 129; i++)
temp[i] = a[i + word_shift];

if (bits_shift != 0)
{
for (int i = 0; i < 128 - 1; i++)
temp[i] = temp[i] >> bits_shift | (temp[i + 1] << (32 - bits_shift));
temp[127] = temp[127] >> bits_shift;
}

for (int i = 0; i < 128; i++)
res[i] = temp[i];
}

于是最终 REDC 的代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
void REDC(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t T[128] = {0};

mul(T, a, b);

uint32_t m[128] = {0};
uint32_t t[128] = {0};
mod_R(m, T); // m=T mod R

mul(m, m, P_); // m=(T mod R) * P_
mod_R(m, m); // m=((T mod R) * P_) mod R
mul(t, m, P); // t=m*P

uint64_t carry = 0;
uint32_t te[129] = {0};
for (int i = 0; i < 128; i++)
{
uint64_t sum = carry + t[i] + T[i];
te[i] = sum & 0xffffffff;
carry = sum >> 32;
}
if (carry)
te[128] = carry;

div_R(t, te);

if (bigger(t, P) || equal(t, P))
sub(t, t, P);

for (int i = 0; i < 128; i++)
res[i] = t[i];
}

采用这种实现的 REDC 交上去(10603ms),比纯 trivial 的模乘(6307ms)还慢?仔细一看,原来是模幂还没用蒙哥马利算法。改成蒙哥马利模幂后,耗时来到了 4030ms。


还能更快吗?

知乎文章和论文中都提到了一种称为 SOS(Seperated Operand Scanning)的算法,用这种算法计算 REDC 时,无需先求出大整数mm,再算T+mpT+mp,而是可以变算mm边算T+mpT+mp

考虑T+m1p0mod2wT+m_1p\equiv 0 \mod 2^w,其中ww是字的长度。则有m1=Tpmod2wm_1=Tp' \mod 2^wm1m_1相当于TT的第一个字与pp'的第一个字相乘。

进一步,考虑T+m1p+m22wp0mod22wT+m_1p+m_2 2^w p\equiv 0 \mod 2^{2w}。则m22w=(T+m1p)pmod22wm_2 2^w=(T+m_1p)p' \mod 2^{2w}。由于T+m1p0mod2wT+m_1p\equiv 0 \mod 2^w,所以其低ww位为 0。所以m2m_2就等于T+m1pT+m_1p的第二个字乘以pp'的第一个字。

以此类推,就有了论文中的伪代码。我根据伪代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
void REDC(uint32_t res[128], uint32_t a[128], uint32_t b[128])
{
uint32_t T[128] = {0};
mul(T, a, b);

int s = P_words + 1;
uint32_t T_[129] = {0};
for (int i = 0; i < 128; i++)
T_[i] = T[i];

for (int i = 0; i < s; i++)
{
uint64_t carry = 0;
uint64_t mi = ((uint64_t)T_[i] * P_[0]) & 0xffffffff;
for (int j = 0; j < s; j++)
{
uint64_t temp = mi * P[j] + (uint64_t)T_[i + j] + carry;
T_[i + j] = temp & 0xffffffff;
carry = temp >> 32;
}

int count = 0;
while (carry)
{
uint64_t temp = (uint64_t)T_[i + s + count] + carry;
T_[i + s + count] = temp & 0xffffffff;
carry = temp >> 32;
count++;
}
}

for (int i = 0; i <= s; i++)
res[i] = T_[i + s];

if (bigger(res, P) || equal(res, P))
sub(res, res, P);
}

这种实现交上去后,耗时降低到了 3145 ms。

写在最后

至此,整个大数运算的实验就到此结束了。也是把模幂改成蒙哥马利模幂之后,拿下了目前最快的成绩。

调试过程至少耗费了二十个小时起步,大部分时间浪费在了定位问题上。这次调试的经验也让我在写 AES 的时候,一开始就对着样例一步步调试,确保下面的函数正确之后,后续几乎是“自然而然”就都对了。

为了方便调试,我还编写了两个脚本。gen_input.py用于生成输入到input.txt,然后gen_output.py用于读取input.txt的数据,生成相应的标准答案到output.txt。这样自己调试起来就很方便了。

以上。

后记

把代码放到 Github 之上,没想到引来助教亲自过来指导。他提到,可以改用uint64_t[],来存储大整数。中间结果,可以用 GCC 编译器支持的拓展类型__int128unsigned __int128来暂存。(事实上,虽然GCC 的文档里没写,但是可以用__uint128_t来代替unsigned __int128,两者完全等价。这篇回答里也提到了这一点。)

于是我着手将代码改写。同时,还将字长和大整数需要最多的字抽象为全局变量W以及MAXLEN。另外,我还只保留了效率最高的函数,精简代码,最终得到了bigint_v2.cpp

bigint_v2.cpp通过所有样例仅需 3000ms,比以uint32_t[]实现的版本(11000ms)快了许多。同时,上面在蒙哥马利算法中讨论的耗时,都是uint64_t[]版跑出的数据,仅提供大小关系的参考。


__uint128_t有几点需要注意的地方,由于其只支持基本的运算,没有min,输入输出或者比较运算的重载,使用的时候要特别小心。

例如改写sub()函数的时候,里面涉及到一处__uint128_tuint64_t的比较,这是一个未定义行为,这个 UB 导致我的模乘和模幂出错了。由于我的 SOS 版的 REDC 实现没有错,一开始我都没有怀疑到precal()里计算的P_,也即pmodR-p\mod R上。

后来我才发现 SOS 版的 REDC 实现中,只用到了P_的第一个字,恰巧,只有第一个字是对的。最后,我才定位到了sub的头上,才觉得可能是这个比较搞得鬼。

助教也提示我可以用 GCC/Clang 提供的检查溢出的内建函数Integer Overflow Builtins (Using the GNU Compiler Collection (GCC))。这些内建函数的行为是对第一个和第二个参数提升到无限精度做运算,然后将得到的结果(由于是无限精度,一定是正确的)与放到第三个参数对应的变量之后的结果进行比较,如果发生了溢出,两者就不相同,会返回true;反之,说明没有溢出,结果准确,返回false

利用这些内建函数,可以重写add()sub()如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
void add(uint64_t res[MAXLEN], uint64_t a[MAXLEN], uint64_t b[MAXLEN])
{
uint64_t carry = 0;
uint64_t temp[MAXLEN] = {0};
for (int i = 0; i < MAXLEN; i++)
{
uint64_t sum;
if (__builtin_add_overflow(a[i], carry, &sum) || __builtin_add_overflow(sum, b[i], &sum))
{
temp[i] = ((__uint128_t)a[i] + b[i] + carry) & MASK;
carry = 1;
}
else
{
temp[i] = sum;
carry = 0;
}
}

for (int i = 0; i < MAXLEN; i++)
res[i] = temp[i];
}

void sub(uint64_t res[MAXLEN], uint64_t a[MAXLEN], uint64_t b[MAXLEN])
{
uint64_t borrow = 0;
uint64_t temp[MAXLEN] = {0};
for (int i = 0; i < MAXLEN; i++)
{
uint64_t gap;
if (__builtin_sub_overflow(a[i], borrow, &gap) || __builtin_sub_overflow(gap, b[i], &gap))
{
temp[i] = (BASE + a[i] - borrow - b[i]) & MASK;
borrow = 1;
}
else
{
temp[i] = gap;
borrow = 0;
}
}

for (int i = 0; i < MAXLEN; i++)
res[i] = temp[i];
}

另外值得一提的是,蒙哥马利算法的RR本来是去恰好比pp大的二的幂。为了方便(实际上也正是论文中的说法),如果p,a,bp, a, b最多需要SS个字来表示,那么就取R=2SWR=2^{SW}。这样初始化、模RR和除RR都会方便许多:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void pre_cal()
{
P_words = P_bits / W;
R_words = P_words + 1;
R[R_words] = 1;
R_bits = getBits(R);

// ...
}

void mod_R(uint64_t res[MAXLEN], uint64_t a[MAXLEN])
{
uint64_t temp[MAXLEN] = {0};
for (int i = 0; i < R_words; i++)
temp[i] = a[i];
for (int i = 0; i < MAXLEN; i++)
res[i] = temp[i];
}

void div_R(uint64_t res[MAXLEN], uint64_t a[MAXLEN + 1])
{
uint64_t temp[MAXLEN] = {0};
for (int i = 0; i + R_words < MAXLEN + 1; i++)
temp[i] = a[i + R_words];
for (int i = 0; i < MAXLEN; i++)
res[i] = temp[i];
}

还有一点,文档中提到:

There is no support in GCC for expressing an integer constant of type __int128 for targets with long long integer less than 128 bits wide.

如果我们需要一个 128 位的整数常量,例如基BASE,其值是 2642^{64},我们可以这样写:

1
const __uint128_t BASE = (static_cast<__uint128_t>(1) << 64);

最后,再次感谢助教的指点。(要说还有能改进的地方,就是蒙哥马利约简有常数更小的实现,不过有机会再写吧 qwq)