random.c 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. #include "pocketpy/interpreter/vm.h"
  2. #include "pocketpy/pocketpy.h"
  3. #include <time.h>
  4. int64_t time_ns(); // from random.c
  5. /* https://github.com/clibs/mt19937ar
  6. Copyright (c) 2011 Mutsuo Saito, Makoto Matsumoto, Hiroshima
  7. University and The University of Tokyo. All rights reserved.
  8. Redistribution and use in source and binary forms, with or without
  9. modification, are permitted provided that the following conditions are
  10. met:
  11. * Redistributions of source code must retain the above copyright
  12. notice, this list of conditions and the following disclaimer.
  13. * Redistributions in binary form must reproduce the above
  14. copyright notice, this list of conditions and the following
  15. disclaimer in the documentation and/or other materials provided
  16. with the distribution.
  17. * Neither the name of the Hiroshima University nor the names of
  18. its contributors may be used to endorse or promote products
  19. derived from this software without specific prior written
  20. permission.
  21. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  22. "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  23. LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  24. A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  25. OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  26. SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  27. LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  28. DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  29. THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  30. (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  31. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  32. */
  33. /* Period parameters */
  34. #define N 624
  35. #define M 397
  36. #define MATRIX_A 0x9908b0dfUL /* constant vector a */
  37. #define UPPER_MASK 0x80000000UL /* most significant w-r bits */
  38. #define LOWER_MASK 0x7fffffffUL /* least significant r bits */
  39. typedef struct mt19937 {
  40. uint32_t mt[N]; /* the array for the state vector */
  41. int mti; /* mti==N+1 means mt[N] is not initialized */
  42. } mt19937;
  43. /* initializes mt[N] with a seed */
  44. static void mt19937__seed(mt19937* self, uint32_t s) {
  45. self->mt[0] = s & 0xffffffffUL;
  46. for(self->mti = 1; self->mti < N; self->mti++) {
  47. self->mt[self->mti] =
  48. (1812433253UL * (self->mt[self->mti - 1] ^ (self->mt[self->mti - 1] >> 30)) +
  49. self->mti);
  50. /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
  51. /* In the previous versions, MSBs of the seed affect */
  52. /* only MSBs of the array mt[]. */
  53. /* 2002/01/09 modified by Makoto Matsumoto */
  54. self->mt[self->mti] &= 0xffffffffUL;
  55. /* for >32 bit machines */
  56. }
  57. }
  58. static void mt19937__ctor(mt19937* self) { self->mti = N + 1; }
  59. /* generates a random number on [0,0xffffffff]-interval */
  60. static uint32_t mt19937__next_uint32(mt19937* self) {
  61. uint32_t* mt = self->mt;
  62. uint32_t y;
  63. static uint32_t mag01[2] = {0x0UL, MATRIX_A};
  64. /* mag01[x] = x * MATRIX_A for x=0,1 */
  65. if(self->mti >= N) { /* generate N words at one time */
  66. int kk;
  67. if(self->mti == N + 1) { /* if init_genrand() has not been called, */
  68. int64_t seed = time_ns();
  69. mt19937__seed(self, (uint32_t)seed);
  70. // seed(5489UL); /* a default initial seed is used */
  71. }
  72. for(kk = 0; kk < N - M; kk++) {
  73. y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
  74. mt[kk] = mt[kk + M] ^ (y >> 1) ^ mag01[y & 0x1UL];
  75. }
  76. for(; kk < N - 1; kk++) {
  77. y = (mt[kk] & UPPER_MASK) | (mt[kk + 1] & LOWER_MASK);
  78. mt[kk] = mt[kk + (M - N)] ^ (y >> 1) ^ mag01[y & 0x1UL];
  79. }
  80. y = (mt[N - 1] & UPPER_MASK) | (mt[0] & LOWER_MASK);
  81. mt[N - 1] = mt[M - 1] ^ (y >> 1) ^ mag01[y & 0x1UL];
  82. self->mti = 0;
  83. }
  84. y = mt[self->mti++];
  85. /* Tempering */
  86. y ^= (y >> 11);
  87. y ^= (y << 7) & 0x9d2c5680UL;
  88. y ^= (y << 15) & 0xefc60000UL;
  89. y ^= (y >> 18);
  90. return y;
  91. }
  92. static uint64_t mt19937__next_uint64(mt19937* self) {
  93. return (uint64_t)mt19937__next_uint32(self) << 32 | mt19937__next_uint32(self);
  94. }
  95. static double mt19937__random(mt19937* self) {
  96. // from cpython
  97. uint32_t a = mt19937__next_uint32(self) >> 5;
  98. uint32_t b = mt19937__next_uint32(self) >> 6;
  99. return (a * 67108864.0 + b) * (1.0 / 9007199254740992.0);
  100. }
  101. static double mt19937__uniform(mt19937* self, double a, double b) {
  102. if(a > b) { return b + mt19937__random(self) * (a - b); }
  103. return a + mt19937__random(self) * (b - a);
  104. }
  105. /* generates a random number on [a, b]-interval */
  106. int64_t mt19937__randint(mt19937* self, int64_t a, int64_t b) {
  107. uint64_t delta = b - a + 1;
  108. if(delta < 0x80000000UL) {
  109. return a + mt19937__next_uint32(self) % delta;
  110. } else {
  111. return a + mt19937__next_uint64(self) % delta;
  112. }
  113. }
  114. static bool Random__new__(int argc, py_Ref argv) {
  115. mt19937* ud = py_newobject(py_retval(), py_totype(argv), 0, sizeof(mt19937));
  116. mt19937__ctor(ud);
  117. return true;
  118. }
  119. static bool Random__init__(int argc, py_Ref argv) {
  120. if(argc == 1) {
  121. // do nothing
  122. } else if(argc == 2) {
  123. mt19937* ud = py_touserdata(py_arg(0));
  124. if(!py_isnone(&argv[1])) {
  125. PY_CHECK_ARG_TYPE(1, tp_int);
  126. py_i64 seed = py_toint(py_arg(1));
  127. mt19937__seed(ud, (uint32_t)seed);
  128. }
  129. } else {
  130. return TypeError("Random(): expected 1 or 2 arguments, got %d");
  131. }
  132. py_newnone(py_retval());
  133. return true;
  134. }
  135. static bool Random_seed(int argc, py_Ref argv) {
  136. PY_CHECK_ARGC(2);
  137. mt19937* ud = py_touserdata(py_arg(0));
  138. py_i64 seed;
  139. if(py_isnone(&argv[1])) {
  140. seed = time_ns();
  141. } else {
  142. PY_CHECK_ARG_TYPE(1, tp_int);
  143. seed = py_toint(py_arg(1));
  144. }
  145. mt19937__seed(ud, (uint32_t)seed);
  146. py_newnone(py_retval());
  147. return true;
  148. }
  149. static bool Random_random(int argc, py_Ref argv) {
  150. PY_CHECK_ARGC(1);
  151. mt19937* ud = py_touserdata(py_arg(0));
  152. py_f64 res = mt19937__random(ud);
  153. py_newfloat(py_retval(), res);
  154. return true;
  155. }
  156. static bool Random_uniform(int argc, py_Ref argv) {
  157. PY_CHECK_ARGC(3);
  158. mt19937* ud = py_touserdata(py_arg(0));
  159. py_f64 a, b;
  160. if(!py_castfloat(py_arg(1), &a)) return false;
  161. if(!py_castfloat(py_arg(2), &b)) return false;
  162. py_f64 res = mt19937__uniform(ud, a, b);
  163. py_newfloat(py_retval(), res);
  164. return true;
  165. }
  166. static bool Random_shuffle(int argc, py_Ref argv) {
  167. PY_CHECK_ARGC(2);
  168. PY_CHECK_ARG_TYPE(1, tp_list);
  169. mt19937* ud = py_touserdata(py_arg(0));
  170. py_Ref L = py_arg(1);
  171. int length = py_list_len(L);
  172. for(int i = length - 1; i > 0; i--) {
  173. int j = mt19937__randint(ud, 0, i);
  174. py_list_swap(L, i, j);
  175. }
  176. py_newnone(py_retval());
  177. return true;
  178. }
  179. static bool Random_randint(int argc, py_Ref argv) {
  180. PY_CHECK_ARGC(3);
  181. PY_CHECK_ARG_TYPE(1, tp_int);
  182. PY_CHECK_ARG_TYPE(2, tp_int);
  183. mt19937* ud = py_touserdata(py_arg(0));
  184. py_i64 a = py_toint(py_arg(1));
  185. py_i64 b = py_toint(py_arg(2));
  186. if(a > b) return ValueError("randint(a, b): a must be less than or equal to b");
  187. py_i64 res = mt19937__randint(ud, a, b);
  188. py_newint(py_retval(), res);
  189. return true;
  190. }
  191. static bool Random_choice(int argc, py_Ref argv) {
  192. PY_CHECK_ARGC(2);
  193. mt19937* ud = py_touserdata(py_arg(0));
  194. py_TValue* p;
  195. int length = pk_arrayview(py_arg(1), &p);
  196. if(length == -1) return TypeError("choice(): argument must be a list or tuple");
  197. if(length == 0) return IndexError("cannot choose from an empty sequence");
  198. int index = mt19937__randint(ud, 0, length - 1);
  199. py_assign(py_retval(), p + index);
  200. return true;
  201. }
  202. static bool Random_choices(int argc, py_Ref argv) {
  203. mt19937* ud = py_touserdata(py_arg(0));
  204. py_TValue* p;
  205. int length = pk_arrayview(py_arg(1), &p);
  206. if(length == -1) return TypeError("choices(): argument must be a list or tuple");
  207. if(length == 0) return IndexError("cannot choose from an empty sequence");
  208. py_Ref weights = py_arg(2);
  209. if(!py_checktype(py_arg(3), tp_int)) return false;
  210. py_i64 k = py_toint(py_arg(3));
  211. py_f64* cum_weights = PK_MALLOC(sizeof(py_f64) * length);
  212. if(py_isnone(weights)) {
  213. for(int i = 0; i < length; i++)
  214. cum_weights[i] = i + 1;
  215. } else {
  216. py_TValue* w;
  217. int wlen = pk_arrayview(weights, &w);
  218. if(wlen == -1) {
  219. PK_FREE(cum_weights);
  220. return TypeError("choices(): weights must be a list or tuple");
  221. }
  222. if(wlen != length) {
  223. PK_FREE(cum_weights);
  224. return ValueError("len(weights) != len(population)");
  225. }
  226. if(!py_castfloat(&w[0], &cum_weights[0])) {
  227. PK_FREE(cum_weights);
  228. return false;
  229. }
  230. for(int i = 1; i < length; i++) {
  231. py_f64 tmp;
  232. if(!py_castfloat(&w[i], &tmp)) {
  233. PK_FREE(cum_weights);
  234. return false;
  235. }
  236. cum_weights[i] = cum_weights[i - 1] + tmp;
  237. }
  238. }
  239. py_f64 total = cum_weights[length - 1];
  240. if(total <= 0) {
  241. PK_FREE(cum_weights);
  242. return ValueError("total of weights must be greater than zero");
  243. }
  244. py_newlistn(py_retval(), k);
  245. for(int i = 0; i < k; i++) {
  246. py_f64 key = mt19937__random(ud) * total;
  247. int index;
  248. c11__lower_bound(py_f64, cum_weights, length, key, c11__less, &index);
  249. assert(index != length);
  250. py_list_setitem(py_retval(), i, p + index);
  251. }
  252. PK_FREE(cum_weights);
  253. return true;
  254. }
  255. void pk__add_module_random() {
  256. py_Ref mod = py_newmodule("random");
  257. py_Type type = py_newtype("Random", tp_object, mod, NULL);
  258. py_bindmagic(type, __new__, Random__new__);
  259. py_bindmagic(type, __init__, Random__init__);
  260. py_bindmethod(type, "seed", Random_seed);
  261. py_bindmethod(type, "random", Random_random);
  262. py_bindmethod(type, "uniform", Random_uniform);
  263. py_bindmethod(type, "randint", Random_randint);
  264. py_bindmethod(type, "shuffle", Random_shuffle);
  265. py_bindmethod(type, "choice", Random_choice);
  266. py_bind(py_tpobject(type), "choices(self, population, weights=None, k=1)", Random_choices);
  267. py_Ref inst = py_pushtmp();
  268. if(!py_tpcall(type, 0, NULL)) goto __ERROR;
  269. py_assign(inst, py_retval());
  270. #define ADD_INST_BOUNDMETHOD(name) \
  271. if(!py_getattr(inst, py_name(name))) goto __ERROR; \
  272. py_setdict(mod, py_name(name), py_retval());
  273. ADD_INST_BOUNDMETHOD("seed");
  274. ADD_INST_BOUNDMETHOD("random");
  275. ADD_INST_BOUNDMETHOD("uniform");
  276. ADD_INST_BOUNDMETHOD("randint");
  277. ADD_INST_BOUNDMETHOD("shuffle");
  278. ADD_INST_BOUNDMETHOD("choice");
  279. ADD_INST_BOUNDMETHOD("choices");
  280. #undef ADD_INST_BOUNDMETHOD
  281. py_pop(); // pop inst
  282. return;
  283. __ERROR:
  284. py_printexc();
  285. c11__abort("failed to add module random");
  286. }
  287. #undef N
  288. #undef M
  289. #undef MATRIX_A
  290. #undef UPPER_MASK
  291. #undef LOWER_MASK
  292. #undef ADD_INST_BOUNDMETHOD
  293. void py_newRandom(py_OutRef out) {
  294. py_Type type = py_gettype("random", py_name("Random"));
  295. assert(type != 0);
  296. mt19937* ud = py_newobject(out, type, 0, sizeof(mt19937));
  297. mt19937__ctor(ud);
  298. }
  299. void py_Random_seed(py_Ref self, py_i64 seed) {
  300. mt19937* ud = py_touserdata(self);
  301. mt19937__seed(ud, (uint32_t)seed);
  302. }
  303. py_f64 py_Random_random(py_Ref self) {
  304. mt19937* ud = py_touserdata(self);
  305. return mt19937__random(ud);
  306. }
  307. py_f64 py_Random_uniform(py_Ref self, py_f64 a, py_f64 b) {
  308. mt19937* ud = py_touserdata(self);
  309. return mt19937__uniform(ud, a, b);
  310. }
  311. py_i64 py_Random_randint(py_Ref self, py_i64 a, py_i64 b) {
  312. mt19937* ud = py_touserdata(self);
  313. if(a > b) { c11__abort("randint(a, b): a must be less than or equal to b"); }
  314. return mt19937__randint(ud, a, b);
  315. }