-- primes.lua (3.7 KB)
  1. local UPPER_BOUND = 5000000
  2. local PREFIX = 32338
  3. local Node = {}
  4. function Node:new()
  5. local obj = {}
  6. setmetatable(obj, self)
  7. self.__index = self
  8. obj.children = {}
  9. obj.terminal = false
  10. return obj
  11. end
  12. local Sieve = {}
  13. function Sieve:new(limit)
  14. local obj = {}
  15. setmetatable(obj, self)
  16. self.__index = self
  17. obj.limit = limit
  18. obj.prime = {}
  19. for i = 0, limit do
  20. obj.prime[i] = false
  21. end
  22. return obj
  23. end
  24. function Sieve:to_list()
  25. local result = {2, 3}
  26. for p = 5, self.limit do
  27. if self.prime[p] then
  28. table.insert(result, p)
  29. end
  30. end
  31. return result
  32. end
  33. function Sieve:omit_squares()
  34. local r = 5
  35. while r * r < self.limit do
  36. if self.prime[r] then
  37. local i = r * r
  38. while i < self.limit do
  39. self.prime[i] = false
  40. i = i + r * r
  41. end
  42. end
  43. r = r + 1
  44. end
  45. return self
  46. end
  47. function Sieve:step1(x, y)
  48. local n = (4 * x * x) + (y * y)
  49. if n <= self.limit and (n % 12 == 1 or n % 12 == 5) then
  50. self.prime[n] = not self.prime[n]
  51. end
  52. end
  53. function Sieve:step2(x, y)
  54. local n = (3 * x * x) + (y * y)
  55. if n <= self.limit and n % 12 == 7 then
  56. self.prime[n] = not self.prime[n]
  57. end
  58. end
  59. function Sieve:step3(x, y)
  60. local n = (3 * x * x) - (y * y)
  61. if x > y and n <= self.limit and n % 12 == 11 then
  62. self.prime[n] = not self.prime[n]
  63. end
  64. end
  65. function Sieve:loop_y(x)
  66. local y = 1
  67. while y * y < self.limit do
  68. self:step1(x, y)
  69. self:step2(x, y)
  70. self:step3(x, y)
  71. y = y + 1
  72. end
  73. end
  74. function Sieve:loop_x()
  75. local x = 1
  76. while x * x < self.limit do
  77. self:loop_y(x)
  78. x = x + 1
  79. end
  80. end
  81. function Sieve:calc()
  82. self:loop_x()
  83. return self:omit_squares()
  84. end
  85. local function generate_trie(l)
  86. local root = Node:new()
  87. for _, el in ipairs(l) do
  88. local head = root
  89. -- attempt to call a nil value (method 'split')
  90. -- how to fix? use string.split
  91. el = tostring(el)
  92. for i=1, #el do
  93. local ch = el:sub(i, i)
  94. if not head.children[ch] then
  95. head.children[ch] = Node:new()
  96. end
  97. head = head.children[ch]
  98. end
  99. head.terminal = true
  100. end
  101. return root
  102. end
  103. local function find(upper_bound, prefix_)
  104. local primes = Sieve:new(upper_bound):calc()
  105. local str_prefix = tostring(prefix_)
  106. local head = generate_trie(primes:to_list())
  107. for i=1, #str_prefix do
  108. local ch = str_prefix:sub(i, i)
  109. head = head.children[ch]
  110. if head == nil then
  111. return nil
  112. end
  113. end
  114. local queue, result = {{head, str_prefix}}, {}
  115. while #queue > 0 do
  116. local tuple = table.remove(queue)
  117. local top, prefix = tuple[1], tuple[2]
  118. if top.terminal then
  119. table.insert(result, tonumber(prefix))
  120. end
  121. for ch, v in pairs(top.children) do
  122. table.insert(queue, 1, {v, prefix .. ch})
  123. end
  124. end
  125. table.sort(result)
  126. return result
  127. end
  128. local function verify()
  129. local left = {2, 23, 29}
  130. local right = find(100, 2)
  131. if #left ~= #right then
  132. print("length not equal")
  133. os.exit(1)
  134. end
  135. for i, v in ipairs(left) do
  136. if v ~= right[i] then
  137. print(string.format("%s != %s", v, right[i]))
  138. os.exit(1)
  139. end
  140. end
  141. end
  142. verify()
  143. local results = find(UPPER_BOUND, PREFIX)
  144. local expected = {323381, 323383, 3233803, 3233809, 3233851, 3233863, 3233873, 3233887, 3233897}
  145. for i, v in ipairs(results) do
  146. if v ~= expected[i] then
  147. print(string.format("%s != %s", v, expected[i]))
  148. os.exit(1)
  149. end
  150. end