# _long.py
  1. from c import sizeof
  2. # https://www.cnblogs.com/liuchanglc/p/14203783.html
  3. if sizeof('void_p') == 4:
  4. PyLong_SHIFT = 28//2 - 1
  5. PyLong_NTT_P = 12289
  6. PyLong_NTT_PR = 11
  7. elif sizeof('void_p') == 8:
  8. PyLong_SHIFT = 60//2 - 1
  9. # 998244353 can not be compiled in 32-bit platform (even it is not used)
  10. PyLong_NTT_P = 998244353 # PyLong_NTT_P**2 should not overflow
  11. PyLong_NTT_PR = 3
  12. else:
  13. raise NotImplementedError
  14. PyLong_BASE = 2 ** PyLong_SHIFT
  15. PyLong_MASK = PyLong_BASE - 1
  16. PyLong_DECIMAL_SHIFT = 4
  17. PyLong_DECIMAL_BASE = 10 ** PyLong_DECIMAL_SHIFT
  18. assert PyLong_NTT_P > PyLong_BASE
  19. #----------------------------------------------------------------------------#
  20. # #
  21. # Number Theoretic Transform #
  22. # #
  23. #----------------------------------------------------------------------------#
  24. def ibin(n, bits):
  25. assert type(bits) is int and bits >= 0
  26. return bin(n)[2:].rjust(bits, "0")
  27. def _number_theoretic_transform(a: list, p, pr, inverse):
  28. n = len(a)
  29. assert n&(n - 1) == 0
  30. a = [x % p for x in a]
  31. b = n.bit_length() - 1
  32. for i in range(1, n):
  33. j = int(ibin(i, b)[::-1], 2)
  34. if i < j:
  35. a[i], a[j] = a[j], a[i]
  36. rt = pow(pr, (p - 1) // n, p)
  37. if inverse:
  38. rt = pow(rt, p - 2, p)
  39. w = [1]*(n // 2)
  40. for i in range(1, n // 2):
  41. w[i] = w[i - 1]*rt % p
  42. h = 2
  43. while h <= n:
  44. hf, ut = h // 2, n // h
  45. for i in range(0, n, h):
  46. for j in range(hf):
  47. u = a[i + j]
  48. v = a[i + j + hf] * w[ut * j] % p
  49. a[i + j] = (u + v) % p
  50. a[i + j + hf] = (u - v + p) % p
  51. h *= 2
  52. if inverse:
  53. rv = pow(n, p - 2, p)
  54. a = [x*rv % p for x in a]
  55. return a
  56. def ntt(a, p, pr):
  57. return _number_theoretic_transform(a, p, pr, False)
  58. def intt(a, p, pr):
  59. return _number_theoretic_transform(a, p, pr, True)
  60. ##############################################################
  61. def ulong_fromint(x: int):
  62. # return a list of digits and sign
  63. if x == 0: return [0], 1
  64. sign = 1 if x > 0 else -1
  65. if sign < 0: x = -x
  66. res = []
  67. while x:
  68. res.append(x & PyLong_MASK)
  69. x >>= PyLong_SHIFT
  70. return res, sign
  71. def ulong_cmp(a: list, b: list) -> int:
  72. # return 1 if a>b, -1 if a<b, 0 if a==b
  73. if len(a) > len(b): return 1
  74. if len(a) < len(b): return -1
  75. for i in range(len(a)-1, -1, -1):
  76. if a[i] > b[i]: return 1
  77. if a[i] < b[i]: return -1
  78. return 0
  79. def ulong_pad_(a: list, size: int):
  80. # pad leading zeros to have `size` digits
  81. delta = size - len(a)
  82. if delta > 0:
  83. a.extend([0] * delta)
  84. def ulong_unpad_(a: list):
  85. # remove leading zeros
  86. while len(a)>1 and a[-1]==0:
  87. a.pop()
  88. def ulong_add(a: list, b: list) -> list:
  89. res = [0] * max(len(a), len(b))
  90. ulong_pad_(a, len(res))
  91. ulong_pad_(b, len(res))
  92. carry = 0
  93. for i in range(len(res)):
  94. carry += a[i] + b[i]
  95. res[i] = carry & PyLong_MASK
  96. carry >>= PyLong_SHIFT
  97. if carry > 0:
  98. res.append(carry)
  99. return res
  100. def ulong_sub(a: list, b: list) -> list:
  101. # a >= b
  102. res = []
  103. borrow = 0
  104. for i in range(len(b)):
  105. tmp = a[i] - b[i] - borrow
  106. if tmp < 0:
  107. tmp += PyLong_BASE
  108. borrow = 1
  109. else:
  110. borrow = 0
  111. res.append(tmp)
  112. for i in range(len(b), len(a)):
  113. tmp = a[i] - borrow
  114. if tmp < 0:
  115. tmp += PyLong_BASE
  116. borrow = 1
  117. else:
  118. borrow = 0
  119. res.append(tmp)
  120. ulong_unpad_(res)
  121. return res
  122. def ulong_divmodi(a: list, b: int):
  123. # b > 0
  124. res = []
  125. carry = 0
  126. for i in range(len(a)-1, -1, -1):
  127. carry <<= PyLong_SHIFT
  128. carry += a[i]
  129. res.append(carry // b)
  130. carry %= b
  131. res.reverse()
  132. ulong_unpad_(res)
  133. return res, carry
  134. def ulong_floordivi(a: list, b: int):
  135. # b > 0
  136. return ulong_divmodi(a, b)[0]
  137. def ulong_muli(a: list, b: int):
  138. # b >= 0
  139. res = [0] * len(a)
  140. carry = 0
  141. for i in range(len(a)):
  142. carry += a[i] * b
  143. res[i] = carry & PyLong_MASK
  144. carry >>= PyLong_SHIFT
  145. if carry > 0:
  146. res.append(carry)
  147. return res
  148. def ulong_mul(a: list, b: list):
  149. N = len(a) + len(b)
  150. if False:
  151. # use grade-school multiplication
  152. res = [0] * N
  153. for i in range(len(a)):
  154. carry = 0
  155. for j in range(len(b)):
  156. carry += res[i+j] + a[i] * b[j]
  157. res[i+j] = carry & PyLong_MASK
  158. carry >>= PyLong_SHIFT
  159. res[i+len(b)] = carry
  160. ulong_unpad_(res)
  161. return res
  162. else:
  163. # use fast number-theoretic transform
  164. limit = 1
  165. while limit < N:
  166. limit <<= 1
  167. a += [0]*(limit - len(a))
  168. b += [0]*(limit - len(b))
  169. # print(a, b)
  170. a = ntt(a, PyLong_NTT_P, PyLong_NTT_PR)
  171. b = ntt(b, PyLong_NTT_P, PyLong_NTT_PR)
  172. # print(a, b)
  173. c = [0] * limit
  174. for i in range(limit):
  175. c[i] = (a[i] * b[i]) % PyLong_NTT_P
  176. # print(c)
  177. c = intt(c, PyLong_NTT_P, PyLong_NTT_PR)
  178. # print(c)
  179. # handle carry
  180. carry = 0
  181. for i in range(limit-1):
  182. carry += c[i]
  183. c[i] = carry & PyLong_MASK
  184. carry >>= PyLong_SHIFT
  185. if carry > 0:
  186. c[limit-1] = carry
  187. # print(c)
  188. ulong_unpad_(c) # should we use this?
  189. # print(c)
  190. return c
  191. def ulong_powi(a: list, b: int):
  192. # b >= 0
  193. if b == 0: return [1]
  194. res = [1]
  195. while b:
  196. if b & 1:
  197. res = ulong_mul(res, a)
  198. a = ulong_mul(a, a)
  199. b >>= 1
  200. return res
  201. def ulong_repr(x: list) -> str:
  202. res = []
  203. while len(x)>1 or x[0]>0: # non-zero
  204. x, r = ulong_divmodi(x, PyLong_DECIMAL_BASE)
  205. res.append(str(r).zfill(PyLong_DECIMAL_SHIFT))
  206. res.reverse()
  207. s = ''.join(res)
  208. if len(s) == 0: return '0'
  209. if len(s) > 1: s = s.lstrip('0')
  210. return s
  211. def ulong_fromstr(s: str):
  212. if s[-1] == 'L':
  213. s = s[:-1]
  214. res, base = [0], [1]
  215. if s[0] == '-':
  216. sign = -1
  217. s = s[1:]
  218. else:
  219. sign = 1
  220. s = s[::-1]
  221. for c in s:
  222. c = ord(c) - 48
  223. assert 0 <= c <= 9
  224. res = ulong_add(res, ulong_muli(base, c))
  225. base = ulong_muli(base, 10)
  226. return res, sign
  227. class long:
  228. def __init__(self, x):
  229. if type(x) is tuple:
  230. self.digits, self.sign = x
  231. elif type(x) is int:
  232. self.digits, self.sign = ulong_fromint(x)
  233. elif type(x) is float:
  234. self.digits, self.sign = ulong_fromint(int(x))
  235. elif type(x) is str:
  236. self.digits, self.sign = ulong_fromstr(x)
  237. elif type(x) is long:
  238. self.digits, self.sign = x.digits.copy(), x.sign
  239. else:
  240. raise TypeError('expected int or str')
  241. def __add__(self, other):
  242. if type(other) is int:
  243. other = long(other)
  244. elif type(other) is not long:
  245. return NotImplemented
  246. if self.sign == other.sign:
  247. return long((ulong_add(self.digits, other.digits), self.sign))
  248. else:
  249. cmp = ulong_cmp(self.digits, other.digits)
  250. if cmp == 0:
  251. return long(0)
  252. if cmp > 0:
  253. return long((ulong_sub(self.digits, other.digits), self.sign))
  254. else:
  255. return long((ulong_sub(other.digits, self.digits), other.sign))
  256. def __radd__(self, other):
  257. return self.__add__(other)
  258. def __sub__(self, other):
  259. if type(other) is int:
  260. other = long(other)
  261. elif type(other) is not long:
  262. return NotImplemented
  263. if self.sign != other.sign:
  264. return long((ulong_add(self.digits, other.digits), self.sign))
  265. cmp = ulong_cmp(self.digits, other.digits)
  266. if cmp == 0:
  267. return long(0)
  268. if cmp > 0:
  269. return long((ulong_sub(self.digits, other.digits), self.sign))
  270. else:
  271. return long((ulong_sub(other.digits, self.digits), -other.sign))
  272. def __rsub__(self, other):
  273. if type(other) is int:
  274. other = long(other)
  275. elif type(other) is not long:
  276. return NotImplemented
  277. return other.__sub__(self)
  278. def __mul__(self, other):
  279. if type(other) is int:
  280. return long((
  281. ulong_muli(self.digits, abs(other)),
  282. self.sign * (1 if other >= 0 else -1)
  283. ))
  284. elif type(other) is long:
  285. return long((
  286. ulong_mul(self.digits, other.digits),
  287. self.sign * other.sign
  288. ))
  289. return NotImplemented
  290. def __rmul__(self, other):
  291. return self.__mul__(other)
  292. #######################################################
  293. def __divmod__(self, other):
  294. if type(other) is int:
  295. assert type(other) is int and other > 0
  296. assert self.sign == 1
  297. q, r = ulong_divmodi(self.digits, other)
  298. return long((q, 1)), r
  299. raise NotImplementedError
  300. def __floordiv__(self, other: int):
  301. return self.__divmod__(other)[0]
  302. def __mod__(self, other: int):
  303. return self.__divmod__(other)[1]
  304. def __pow__(self, other: int):
  305. assert type(other) is int and other >= 0
  306. if self.sign == -1 and other & 1:
  307. sign = -1
  308. else:
  309. sign = 1
  310. return long((ulong_powi(self.digits, other), sign))
  311. def __lshift__(self, other: int):
  312. assert type(other) is int and other >= 0
  313. x = self.digits.copy()
  314. q, r = divmod(other, PyLong_SHIFT)
  315. x = [0]*q + x
  316. for _ in range(r): x = ulong_muli(x, 2)
  317. return long((x, self.sign))
  318. def __rshift__(self, other: int):
  319. assert type(other) is int and other >= 0
  320. x = self.digits.copy()
  321. q, r = divmod(other, PyLong_SHIFT)
  322. x = x[q:]
  323. if not x: return long(0)
  324. for _ in range(r): x = ulong_floordivi(x, 2)
  325. return long((x, self.sign))
  326. def __neg__(self):
  327. return long((self.digits, -self.sign))
  328. def __cmp__(self, other):
  329. if type(other) is int:
  330. other = long(other)
  331. else:
  332. assert type(other) is long
  333. if self.sign > other.sign:
  334. return 1
  335. elif self.sign < other.sign:
  336. return -1
  337. else:
  338. return ulong_cmp(self.digits, other.digits)
  339. def __eq__(self, other):
  340. return self.__cmp__(other) == 0
  341. def __lt__(self, other):
  342. return self.__cmp__(other) < 0
  343. def __le__(self, other):
  344. return self.__cmp__(other) <= 0
  345. def __gt__(self, other):
  346. return self.__cmp__(other) > 0
  347. def __ge__(self, other):
  348. return self.__cmp__(other) >= 0
  349. def __repr__(self):
  350. prefix = '-' if self.sign < 0 else ''
  351. return prefix + ulong_repr(self.digits) + 'L'