blueloveTH 2 лет назад
Родитель
Сommit
ba248ae0f3
1 измененных файлов с 9 добавлено и 102 удалено
  1. 9 102
      python/_long.py

+ 9 - 102
python/_long.py

@@ -3,13 +3,8 @@ from c import sizeof
 # https://www.cnblogs.com/liuchanglc/p/14203783.html
 # https://www.cnblogs.com/liuchanglc/p/14203783.html
 if sizeof('void_p') == 4:
 if sizeof('void_p') == 4:
     PyLong_SHIFT = 28//2 - 1
     PyLong_SHIFT = 28//2 - 1
-    PyLong_NTT_P = 12289
-    PyLong_NTT_PR = 11
 elif sizeof('void_p') == 8:
 elif sizeof('void_p') == 8:
     PyLong_SHIFT = 60//2 - 1
     PyLong_SHIFT = 60//2 - 1
-    # 998244353 can not be compiled in 32-bit platform (even it is not used)
-    PyLong_NTT_P = 998244353   # PyLong_NTT_P**2 should not overflow
-    PyLong_NTT_PR = 3
 else:
 else:
     raise NotImplementedError
     raise NotImplementedError
 
 
@@ -18,62 +13,6 @@ PyLong_MASK = PyLong_BASE - 1
 PyLong_DECIMAL_SHIFT = 4
 PyLong_DECIMAL_SHIFT = 4
 PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
 PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
 
 
-assert PyLong_NTT_P > PyLong_BASE
-
-#----------------------------------------------------------------------------#
-#                                                                            #
-#                         Number Theoretic Transform                         #
-#                                                                            #
-#----------------------------------------------------------------------------#
-
-def ibin(n, bits):
-    assert type(bits) is int and bits >= 0
-    return bin(n)[2:].rjust(bits, "0")
-
-def _number_theoretic_transform(a: list, p, pr, inverse):
-    n = len(a)
-    assert n&(n - 1) == 0
-
-    a = [x % p for x in a]
-    b = n.bit_length() - 1
-
-    for i in range(1, n):
-        j = int(ibin(i, b)[::-1], 2)
-        if i < j:
-            a[i], a[j] = a[j], a[i]
-
-    rt = pow(pr, (p - 1) // n, p)
-    if inverse:
-        rt = pow(rt, p - 2, p)
-
-    w = [1]*(n // 2)
-    for i in range(1, n // 2):
-        w[i] = w[i - 1]*rt % p
-
-    h = 2
-    while h <= n:
-        hf, ut = h // 2, n // h
-        for i in range(0, n, h):
-            for j in range(hf):
-                u = a[i + j]
-                v = a[i + j + hf] * w[ut * j] % p
-                a[i + j] = (u + v) % p
-                a[i + j + hf] = (u - v + p) % p
-        h *= 2
-
-    if inverse:
-        rv = pow(n, p - 2, p)
-        a = [x*rv % p for x in a]
-
-    return a
-
-
-def ntt(a, p, pr):
-    return _number_theoretic_transform(a, p, pr, False)
-
-def intt(a, p, pr):
-    return _number_theoretic_transform(a, p, pr, True)
-
 ##############################################################
 ##############################################################
 
 
 def ulong_fromint(x: int):
 def ulong_fromint(x: int):
@@ -174,49 +113,17 @@ def ulong_muli(a: list, b: int):
 
 
 def ulong_mul(a: list, b: list):
 def ulong_mul(a: list, b: list):
     N = len(a) + len(b)
     N = len(a) + len(b)
-    if False:
-        # use grade-school multiplication
-        res = [0] * N
-        for i in range(len(a)):
-            carry = 0
-            for j in range(len(b)):
-                carry += res[i+j] + a[i] * b[j]
-                res[i+j] = carry & PyLong_MASK
-                carry >>= PyLong_SHIFT
-            res[i+len(b)] = carry
-        ulong_unpad_(res)
-        return res
-    else:
-        # use fast number-theoretic transform
-        limit = 1
-        while limit < N:
-            limit <<= 1
-        a += [0]*(limit - len(a))
-        b += [0]*(limit - len(b))
-        # print(a, b)
-        a = ntt(a, PyLong_NTT_P, PyLong_NTT_PR)
-        b = ntt(b, PyLong_NTT_P, PyLong_NTT_PR)
-        # print(a, b)
-        c = [0] * limit
-        for i in range(limit):
-            c[i] = (a[i] * b[i]) % PyLong_NTT_P
-
-        # print(c)
-        c = intt(c, PyLong_NTT_P, PyLong_NTT_PR)
-        # print(c)
-
-        # handle carry
+    # use grade-school multiplication
+    res = [0] * N
+    for i in range(len(a)):
         carry = 0
         carry = 0
-        for i in range(limit-1):
-            carry += c[i]
-            c[i] = carry & PyLong_MASK
+        for j in range(len(b)):
+            carry += res[i+j] + a[i] * b[j]
+            res[i+j] = carry & PyLong_MASK
             carry >>= PyLong_SHIFT
             carry >>= PyLong_SHIFT
-        if carry > 0:
-            c[limit-1] = carry
-        # print(c)
-        ulong_unpad_(c)     # should we use this?
-        # print(c)
-        return c
+        res[i+len(b)] = carry
+    ulong_unpad_(res)
+    return res
 
 
 def ulong_powi(a: list, b: int):
 def ulong_powi(a: list, b: int):
     # b >= 0
     # b >= 0