random.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  #include "pocketpy/random.h"
  #include <cmath>
  2. namespace pkpy{
  3. struct Random{
  4. PY_CLASS(Random, random, Random)
  5. std::mt19937 gen;
  6. Random(){
  7. gen.seed(std::chrono::high_resolution_clock::now().time_since_epoch().count());
  8. }
  9. static void _register(VM* vm, PyObject* mod, PyObject* type){
  10. vm->bind_default_constructor<Random>(type);
  11. vm->bind_method<1>(type, "seed", [](VM* vm, ArgsView args) {
  12. Random& self = _CAST(Random&, args[0]);
  13. self.gen.seed(CAST(i64, args[1]));
  14. return vm->None;
  15. });
  16. vm->bind_method<2>(type, "randint", [](VM* vm, ArgsView args) {
  17. Random& self = _CAST(Random&, args[0]);
  18. i64 a = CAST(i64, args[1]);
  19. i64 b = CAST(i64, args[2]);
  20. if (a > b) vm->ValueError("randint(a, b): a must be less than or equal to b");
  21. std::uniform_int_distribution<i64> dis(a, b);
  22. return VAR(dis(self.gen));
  23. });
  24. vm->bind_method<0>(type, "random", [](VM* vm, ArgsView args) {
  25. Random& self = _CAST(Random&, args[0]);
  26. std::uniform_real_distribution<f64> dis(0.0, 1.0);
  27. return VAR(dis(self.gen));
  28. });
  29. vm->bind_method<2>(type, "uniform", [](VM* vm, ArgsView args) {
  30. Random& self = _CAST(Random&, args[0]);
  31. f64 a = CAST(f64, args[1]);
  32. f64 b = CAST(f64, args[2]);
  33. std::uniform_real_distribution<f64> dis(a, b);
  34. return VAR(dis(self.gen));
  35. });
  36. vm->bind_method<1>(type, "shuffle", [](VM* vm, ArgsView args) {
  37. Random& self = _CAST(Random&, args[0]);
  38. List& L = CAST(List&, args[1]);
  39. std::shuffle(L.begin(), L.end(), self.gen);
  40. return vm->None;
  41. });
  42. vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) {
  43. Random& self = _CAST(Random&, args[0]);
  44. auto [data, size] = vm->_cast_array(args[1]);
  45. if(size == 0) vm->IndexError("cannot choose from an empty sequence");
  46. std::uniform_int_distribution<i64> dis(0, size - 1);
  47. return data[dis(self.gen)];
  48. });
  49. vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) {
  50. Random& self = _CAST(Random&, args[0]);
  51. auto [data, size] = vm->_cast_array(args[1]);
  52. if(size == 0) vm->IndexError("cannot choose from an empty sequence");
  53. pod_vector<f64> cum_weights(size);
  54. if(args[2] == vm->None){
  55. for(int i = 0; i < size; i++) cum_weights[i] = i + 1;
  56. }else{
  57. auto [weights, weights_size] = vm->_cast_array(args[2]);
  58. if(weights_size != size) vm->ValueError(_S("len(weights) != ", size));
  59. cum_weights[0] = CAST(f64, weights[0]);
  60. for(int i = 1; i < size; i++){
  61. cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]);
  62. }
  63. }
  64. if(cum_weights[size - 1] <= 0) vm->ValueError("total of weights must be greater than zero");
  65. int k = CAST(i64, args[3]);
  66. List result(k);
  67. for(int i = 0; i < k; i++){
  68. f64 r = std::uniform_real_distribution<f64>(0.0, cum_weights[size - 1])(self.gen);
  69. int idx = std::lower_bound(cum_weights.begin(), cum_weights.end(), r) - cum_weights.begin();
  70. result[i] = data[idx];
  71. }
  72. return VAR(std::move(result));
  73. });
  74. }
  75. };
  76. void add_module_random(VM* vm){
  77. PyObject* mod = vm->new_module("random");
  78. Random::register_class(vm, mod);
  79. PyObject* instance = vm->heap.gcnew<Random>(Random::_type(vm));
  80. mod->attr().set("seed", vm->getattr(instance, "seed"));
  81. mod->attr().set("random", vm->getattr(instance, "random"));
  82. mod->attr().set("uniform", vm->getattr(instance, "uniform"));
  83. mod->attr().set("randint", vm->getattr(instance, "randint"));
  84. mod->attr().set("shuffle", vm->getattr(instance, "shuffle"));
  85. mod->attr().set("choice", vm->getattr(instance, "choice"));
  86. mod->attr().set("choices", vm->getattr(instance, "choices"));
  87. }
  88. } // namespace pkpy