瀏覽代碼

move `min/max` to cpp

blueloveTH 1 年之前
父節點
當前提交
022e9c53fb
共有 6 個文件被更改,包括 60 次插入43 次删除
  1. 3 1
      include/pocketpy/vm.h
  2. 0 32
      python/builtins.py
  3. 0 0
      src/_generated.cpp
  4. 8 0
      src/pocketpy.cpp
  5. 9 7
      src/random.cpp
  6. 40 3
      src/vm.cpp

+ 3 - 1
include/pocketpy/vm.h

@@ -190,7 +190,7 @@ public:
     PyObject* py_json(PyObject* obj);
     PyObject* py_iter(PyObject* obj);
 
-    std::pair<PyObject**, int> _cast_array(PyObject* obj);
+    ArgsView _cast_array_view(PyObject* obj);
 
     PyObject* find_name_in_mro(Type cls, StrName name);
     bool isinstance(PyObject* obj, Type base);
@@ -312,6 +312,8 @@ public:
     PyObject* _py_next(const PyTypeInfo*, PyObject*);
     PyObject* _pack_next_retval(unsigned);
     bool py_callable(PyObject* obj);
+
+    PyObject* _minmax_reduce(bool (VM::*op)(PyObject*, PyObject*), PyObject* args, PyObject* key);
     
     /***** Error Reporter *****/
     void _raise(bool re_raise=false);

+ 0 - 32
python/builtins.py

@@ -1,37 +1,5 @@
-from operator import lt as __operator_lt
-from operator import gt as __operator_gt
 from __builtins import next as __builtins_next
 
-def __minmax_reduce(op, args, key):
-    if key is None:
-        if len(args) == 2:
-            return args[0] if op(args[0], args[1]) else args[1]
-    if len(args) == 0:
-        raise TypeError('expected 1 arguments, got 0')
-    if len(args) == 1:
-        args = args[0]
-    args = iter(args)
-    res = __builtins_next(args)
-    if res is StopIteration:
-        raise ValueError('args is an empty sequence')
-    while True:
-        i = __builtins_next(args)
-        if i is StopIteration:
-            break
-        if key is None:
-            if op(i, res):
-                res = i
-        else:
-            if op(key(i), key(res)):
-                res = i
-    return res
-
-def min(*args, key=None):
-    return __minmax_reduce(__operator_lt, args, key)
-
-def max(*args, key=None):
-    return __minmax_reduce(__operator_gt, args, key)
-
 def all(iterable):
     for i in iterable:
         if not i:

文件差異過大導致無法顯示
+ 0 - 0
src/_generated.cpp


+ 8 - 0
src/pocketpy.cpp

@@ -162,6 +162,14 @@ void init_builtins(VM* _vm) {
         return vm->None;
     });
 
+    _vm->bind(_vm->builtins, "max(*args, key=None)", [](VM* vm, ArgsView args){
+        return vm->_minmax_reduce(&VM::py_gt, args[0], args[1]);
+    });
+
+    _vm->bind(_vm->builtins, "min(*args, key=None)", [](VM* vm, ArgsView args){
+        return vm->_minmax_reduce(&VM::py_lt, args[0], args[1]);
+    });
+
     _vm->bind_func<1>(_vm->builtins, "id", [](VM* vm, ArgsView args) {
         PyObject* obj = args[0];
         if(is_tagged(obj)) return vm->None;

+ 9 - 7
src/random.cpp

@@ -178,22 +178,24 @@ struct Random{
 
         vm->bind_method<1>(type, "choice", [](VM* vm, ArgsView args) {
             Random& self = PK_OBJ_GET(Random, args[0]);
-            auto [data, size] = vm->_cast_array(args[1]);
-            if(size == 0) vm->IndexError("cannot choose from an empty sequence");
-            int index = self.gen.randint(0, size-1);
-            return data[index];
+            ArgsView view = vm->_cast_array_view(args[1]);
+            if(view.empty()) vm->IndexError("cannot choose from an empty sequence");
+            int index = self.gen.randint(0, view.size()-1);
+            return view[index];
         });
 
         vm->bind(type, "choices(self, population, weights=None, k=1)", [](VM* vm, ArgsView args) {
             Random& self = PK_OBJ_GET(Random, args[0]);
-            auto [data, size] = vm->_cast_array(args[1]);
+            ArgsView view = vm->_cast_array_view(args[1]);
+            PyObject** data = view.begin();
+            int size = view.size();
             if(size == 0) vm->IndexError("cannot choose from an empty sequence");
             pod_vector<f64> cum_weights(size);
             if(args[2] == vm->None){
                 for(int i = 0; i < size; i++) cum_weights[i] = i + 1;
             }else{
-                auto [weights, weights_size] = vm->_cast_array(args[2]);
-                if(weights_size != size) vm->ValueError(_S("len(weights) != ", size));
+                ArgsView weights = vm->_cast_array_view(args[2]);
+                if(weights.size() != size) vm->ValueError(_S("len(weights) != ", size));
                 cum_weights[0] = CAST(f64, weights[0]);
                 for(int i = 1; i < size; i++){
                     cum_weights[i] = cum_weights[i - 1] + CAST(f64, weights[i]);

+ 40 - 3
src/vm.cpp

@@ -113,13 +113,13 @@ namespace pkpy{
         return nullptr;
     }
 
-    std::pair<PyObject**, int> VM::_cast_array(PyObject* obj){
+    ArgsView VM::_cast_array_view(PyObject* obj){
         if(is_type(obj, VM::tp_list)){
             List& list = PK_OBJ_GET(List, obj);
-            return {list.data(), list.size()};
+            return ArgsView(list.begin(), list.end());
         }else if(is_type(obj, VM::tp_tuple)){
             Tuple& tuple = PK_OBJ_GET(Tuple, obj);
-            return {tuple.data(), tuple.size()};
+            return ArgsView(tuple.begin(), tuple.end());
         }
         TypeError(_S("expected list or tuple, got ", _type_name(this, _tp(obj)).escape()));
         PK_UNREACHABLE();
@@ -270,6 +270,43 @@ namespace pkpy{
         return vm->find_name_in_mro(cls, __call__) != nullptr;
     }
 
+    PyObject* VM::_minmax_reduce(bool (VM::*op)(PyObject*, PyObject*), PyObject* args, PyObject* key){
+        auto _lock = heap.gc_scope_lock();
+        const Tuple& args_tuple = PK_OBJ_GET(Tuple, args);  // from *args, it must be a tuple
+        if(key==vm->None && args_tuple.size()==2){
+            // fast path
+            PyObject* a = args_tuple[0];
+            PyObject* b = args_tuple[1];
+            return (this->*op)(a, b) ? a : b;
+        }
+
+        if(args_tuple.size() == 0) TypeError("expected at least 1 argument, got 0");
+        
+        ArgsView view(nullptr, nullptr);
+        if(args_tuple.size()==1){
+            view = _cast_array_view(args_tuple[0]);
+        }else{
+            view = ArgsView(args_tuple);
+        }
+
+        if(view.empty()) ValueError("arg is an empty sequence");
+        PyObject* res = view[0];
+
+        if(key == vm->None){
+            for(int i=1; i<view.size(); i++){
+                if((this->*op)(view[i], res)) res = view[i];
+            }
+        }else{
+            auto _lock = heap.gc_scope_lock();
+            for(int i=1; i<view.size(); i++){
+                PyObject* a = call(key, view[i]);
+                PyObject* b = call(key, res);
+                if((this->*op)(a, b)) res = view[i];
+            }
+        }
+        return res;
+    }
+
     PyObject* VM::py_import(Str path, bool throw_err){
         if(path.empty()) vm->ValueError("empty module name");
         static auto f_join = [](const pod_vector<std::string_view>& cpnts){

部分文件因文件數量過多而無法顯示