Quellcode durchsuchen

add `random.choices`

blueloveTH vor 2 Jahren
Ursprung
Commit
7095db428c
5 geänderte Dateien mit 71 neuen und 4 gelöschten Zeilen
  1. 1 0
      include/pocketpy/tuplelist.h
  2. 2 0
      include/pocketpy/vm.h
  3. 32 3
      src/random.cpp
  4. 12 0
      src/vm.cpp
  5. 24 1
      tests/70_random.py

+ 1 - 0
include/pocketpy/tuplelist.h

@@ -31,6 +31,7 @@ struct Tuple {
 
     PyObject** begin() const { return _args; }
     PyObject** end() const { return _args + _size; }
+    PyObject** data() const { return _args; }
 };
 
 // a lightweight view for function args, it does not own the memory

+ 2 - 0
include/pocketpy/vm.h

@@ -184,6 +184,8 @@ public:
     PyObject* py_json(PyObject* obj);
     PyObject* py_iter(PyObject* obj);
 
+    std::pair<PyObject**, int> _cast_array(PyObject* obj);
+
     PyObject* find_name_in_mro(Type cls, StrName name);
     bool isinstance(PyObject* obj, Type base);
     bool issubclass(Type cls, Type base);

+ 32 - 3
src/random.cpp

@@ -23,6 +23,7 @@ struct Random{
             Random& self = _CAST(Random&, args[0]);
             i64 a = CAST(i64, args[1]);
             i64 b = CAST(i64, args[2]);
+            if (a > b) vm->ValueError("randint(a, b): a must be less than or equal to b");
             std::uniform_int_distribution<i64> dis(a, b);
             return VAR(dis(self.gen));
         });
@@ -50,9 +51,36 @@ struct Random{
 
         vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) {
             Random& self = _CAST(Random&, args[0]);
-            const List& L = CAST(List&, args[1]);
-            std::uniform_int_distribution<i64> dis(0, L.size() - 1);
-            return L[dis(self.gen)];
+            auto [data, size] = vm->_cast_array(args[1]);
+            if(size == 0) vm->IndexError("cannot choose from an empty sequence");
+            std::uniform_int_distribution<i64> dis(0, size - 1);
+            return data[dis(self.gen)];
+        });
+
+        vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) {
+            Random& self = _CAST(Random&, args[0]);
+            auto [data, size] = vm->_cast_array(args[1]);
+            if(size == 0) vm->IndexError("cannot choose from an empty sequence");
+            std::vector<f64> cum_weights(size);
+            if(args[2] == vm->None){
+                for(int i = 0; i < size; i++) cum_weights[i] = i + 1;
+            }else{
+                auto [weights, weights_size] = vm->_cast_array(args[2]);
+                if(weights_size != size) vm->ValueError(_S("len(weights) != ", size));
+                cum_weights[0] = CAST(f64, weights[0]);
+                for(int i = 1; i < size; i++){
+                    cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]);
+                }
+            }
+            if(cum_weights[size - 1] <= 0) vm->ValueError("total of weights must be greater than zero");
+            int k = CAST(i64, args[3]);
+            List result(k);
+            for(int i = 0; i < k; i++){
+                f64 r = std::uniform_real_distribution<f64>(0.0, cum_weights[size - 1])(self.gen);
+                int idx = std::lower_bound(cum_weights.begin(), cum_weights.end(), r) - cum_weights.begin();
+                result[i] = data[idx];
+            }
+            return VAR(std::move(result));
         });
     }
 };
@@ -67,6 +95,7 @@ void add_module_random(VM* vm){
     mod->attr().set("randint", vm->getattr(instance, "randint"));
     mod->attr().set("shuffle", vm->getattr(instance, "shuffle"));
     mod->attr().set("choice", vm->getattr(instance, "choice"));
+    mod->attr().set("choices", vm->getattr(instance, "choices"));
 }
 
 }   // namespace pkpy

+ 12 - 0
src/vm.cpp

@@ -117,6 +117,18 @@ namespace pkpy{
         return nullptr;
     }
 
+    std::pair<PyObject**, int> VM::_cast_array(PyObject* obj){
+        if(is_non_tagged_type(obj, VM::tp_list)){
+            List& list = PK_OBJ_GET(List, obj);
+            return {list.data(), list.size()};
+        }else if(is_non_tagged_type(obj, VM::tp_tuple)){
+            Tuple& tuple = PK_OBJ_GET(Tuple, obj);
+            return {tuple.data(), tuple.size()};
+        }
+        TypeError(_S("expected list or tuple, got ", _type_name(this, _tp(obj)).escape()));
+        PK_UNREACHABLE();
+    }
+
     FrameId VM::top_frame(){
 #if PK_DEBUG_EXTRA_CHECK
         if(callstack.empty()) PK_FATAL_ERROR();

+ 24 - 1
tests/70_random.py

@@ -14,8 +14,31 @@ for _ in range(100):
 a = [1, 2, 3, 4]
 r.shuffle(a)
 
-for i in range(100):
+for i in range(10):
     assert r.choice(a) in a
 
+for i in range(10):
+    assert r.choice(tuple(a)) in a
 
+for i in range(10):
+    assert r.randint(1, 1) == 1
 
+# test choices
+x = (1,)
+res = r.choices(x, k=4)
+assert (res == [1, 1, 1, 1]), res
+
+w = (1, 2, 3)
+assert r.choices([1, 2, 3], (0.0, 0.0, 0.5)) == [3]
+
+try:
+    r.choices([1, 2, 3], (0.0, 0.0, 0.5, 0.5))
+    exit(1)
+except ValueError:
+    pass
+
+try:
+    r.choices([])
+    exit(1)
+except IndexError:
+    pass