blueloveTH 2 tahun lalu
induk
melakukan
611a4282aa
1 mengubah file dengan 107 tambahan dan 10 penghapusan
  1. 107 10
      python/_long.py

+ 107 - 10
python/_long.py

@@ -1,9 +1,14 @@
 from c import sizeof
 from c import sizeof
 
 
+# https://www.cnblogs.com/liuchanglc/p/14203783.html
 if sizeof('void_p') == 4:
 if sizeof('void_p') == 4:
-    PyLong_SHIFT = 28//2
+    PyLong_SHIFT = 28//2 - 1
+    PyLong_NTT_P = 12289
+    PyLong_NTT_PR = 11
 elif sizeof('void_p') == 8:
 elif sizeof('void_p') == 8:
-    PyLong_SHIFT = 60//2
+    PyLong_SHIFT = 60//2 - 1
+    PyLong_NTT_P = 1004535809   # PyLong_NTT_P**2 should not overflow
+    PyLong_NTT_PR = 3
 else:
 else:
     raise NotImplementedError
     raise NotImplementedError
 
 
@@ -12,6 +17,64 @@ PyLong_MASK = PyLong_BASE - 1
 PyLong_DECIMAL_SHIFT = 4
 PyLong_DECIMAL_SHIFT = 4
 PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
 PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
 
 
+assert PyLong_NTT_P > PyLong_BASE
+
+#----------------------------------------------------------------------------#
+#                                                                            #
+#                         Number Theoretic Transform                         #
+#                                                                            #
+#----------------------------------------------------------------------------#
+
+def ibin(n, bits):
+    assert type(bits) is int and bits >= 0
+    return bin(n)[2:].rjust(bits, "0")
+
+def _number_theoretic_transform(a: list, p, pr, inverse):
+    n = len(a)
+    assert n&(n - 1) == 0
+
+    a = [x % p for x in a]
+    b = n.bit_length() - 1
+
+    for i in range(1, n):
+        j = int(ibin(i, b)[::-1], 2)
+        if i < j:
+            a[i], a[j] = a[j], a[i]
+
+    rt = pow(pr, (p - 1) // n, p)
+    if inverse:
+        rt = pow(rt, p - 2, p)
+
+    w = [1]*(n // 2)
+    for i in range(1, n // 2):
+        w[i] = w[i - 1]*rt % p
+
+    h = 2
+    while h <= n:
+        hf, ut = h // 2, n // h
+        for i in range(0, n, h):
+            for j in range(hf):
+                u = a[i + j]
+                v = a[i + j + hf] * w[ut * j] % p
+                a[i + j] = (u + v) % p
+                a[i + j + hf] = (u - v + p) % p
+        h *= 2
+
+    if inverse:
+        rv = pow(n, p - 2, p)
+        a = [x*rv % p for x in a]
+
+    return a
+
+
+def ntt(a, p, pr):
+    return _number_theoretic_transform(a, p, pr, False)
+
+def intt(a, p, pr):
+    return _number_theoretic_transform(a, p, pr, True)
+
+##############################################################
+
 def ulong_fromint(x: int):
 def ulong_fromint(x: int):
     # return a list of digits and sign
     # return a list of digits and sign
     if x == 0: return [0], 1
     if x == 0: return [0], 1
@@ -109,16 +172,50 @@ def ulong_muli(a: list, b: int):
     return res
     return res
 
 
 def ulong_mul(a: list, b: list):
 def ulong_mul(a: list, b: list):
-    res = [0] * (len(a) + len(b))
-    for i in range(len(a)):
+    N = len(a) + len(b)
+    if False:
+        # use grade-school multiplication
+        res = [0] * N
+        for i in range(len(a)):
+            carry = 0
+            for j in range(len(b)):
+                carry += res[i+j] + a[i] * b[j]
+                res[i+j] = carry & PyLong_MASK
+                carry >>= PyLong_SHIFT
+            res[i+len(b)] = carry
+        ulong_unpad_(res)
+        return res
+    else:
+        # use fast number-theoretic transform
+        limit = 1
+        while limit < N:
+            limit <<= 1
+        a += [0]*(limit - len(a))
+        b += [0]*(limit - len(b))
+        # print(a, b)
+        a = ntt(a, PyLong_NTT_P, PyLong_NTT_PR)
+        b = ntt(b, PyLong_NTT_P, PyLong_NTT_PR)
+        # print(a, b)
+        c = [0] * limit
+        for i in range(limit):
+            c[i] = (a[i] * b[i]) % PyLong_NTT_P
+
+        # print(c)
+        c = intt(c, PyLong_NTT_P, PyLong_NTT_PR)
+        # print(c)
+
+        # handle carry
         carry = 0
         carry = 0
-        for j in range(len(b)):
-            carry += res[i+j] + a[i] * b[j]
-            res[i+j] = carry & PyLong_MASK
+        for i in range(limit-1):
+            carry += c[i]
+            c[i] = carry & PyLong_MASK
             carry >>= PyLong_SHIFT
             carry >>= PyLong_SHIFT
-        res[i+len(b)] = carry
-    ulong_unpad_(res)
-    return res
+        if carry > 0:
+            c[limit-1] = carry
+        # print(c)
+        ulong_unpad_(c)     # should we use this?
+        # print(c)
+        return c
 
 
 def ulong_powi(a: list, b: int):
 def ulong_powi(a: list, b: int):
     # b >= 0
     # b >= 0