Просмотр исходного кода

add shortcut for `__next__`

Update py_dict.c
blueloveTH 8 месяцев назад
Родитель
Сommit
9b8f706010

+ 11 - 0
include/pocketpy/interpreter/bindings.h

@@ -0,0 +1,11 @@
+#pragma once
+
+#include "pocketpy/pocketpy.h"
+
+bool generator__next__(int argc, py_Ref argv);
+bool array2d_like_iterator__next__(int argc, py_Ref argv);
+bool list_iterator__next__(int argc, py_Ref argv);
+bool tuple_iterator__next__(int argc, py_Ref argv);
+bool dict_items__next__(int argc, py_Ref argv);
+bool range_iterator__next__(int argc, py_Ref argv);
+bool str_iterator__next__(int argc, py_Ref argv);

+ 1 - 1
src/interpreter/generator.c

@@ -20,7 +20,7 @@ void Generator__dtor(Generator* ud) {
     if(ud->frame) Frame__delete(ud->frame);
 }
 
-static bool generator__next__(int argc, py_Ref argv) {
+bool generator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     Generator* ud = py_touserdata(argv);
     py_StackRef p0 = py_peek(0);

+ 1 - 1
src/modules/array2d.c

@@ -800,7 +800,7 @@ static void register_array2d_like(py_Ref mod) {
     }
 }
 
-static bool array2d_like_iterator__next__(int argc, py_Ref argv) {
+bool array2d_like_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     c11_array2d_like_iterator* self = py_touserdata(argv);
     if(self->j >= self->array->n_rows) return StopIteration();

+ 2 - 2
src/public/py_array.c

@@ -57,7 +57,7 @@ bool pk_arraycontains(py_Ref self, py_Ref val) {
     return true;
 }
 
-static bool list_iterator__next__(int argc, py_Ref argv) {
+bool list_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     list_iterator* ud = py_touserdata(argv);
     if(ud->index < ud->vec->length) {
@@ -69,7 +69,7 @@ static bool list_iterator__next__(int argc, py_Ref argv) {
     return StopIteration();
 }
 
-static bool tuple_iterator__next__(int argc, py_Ref argv) {
+bool tuple_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     tuple_iterator* ud = py_touserdata(argv);
     if(ud->index < ud->length) {

+ 1 - 1
src/public/py_dict.c

@@ -610,7 +610,7 @@ py_Type pk_dict__register() {
 }
 
 //////////////////////////
-static bool dict_items__next__(int argc, py_Ref argv) {
+bool dict_items__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     DictIterator* iter = py_touserdata(py_arg(0));
     if(DictIterator__modified(iter)) return RuntimeError("dictionary modified during iteration");

+ 33 - 5
src/public/py_ops.c

@@ -1,4 +1,5 @@
 #include "pocketpy/interpreter/typeinfo.h"
+#include "pocketpy/interpreter/bindings.h"
 #include "pocketpy/interpreter/vm.h"
 #include "pocketpy/objects/base.h"
 #include "pocketpy/pocketpy.h"
@@ -77,12 +78,39 @@ bool py_iter(py_Ref val) {
 
 int py_next(py_Ref val) {
     VM* vm = pk_current_vm;
-    py_Ref tmp = py_tpfindmagic(val->type, __next__);
-    if(!tmp) {
-        TypeError("'%t' object is not an iterator", val->type);
-        return -1;
+
+    switch(val->type) {
+        case tp_generator:
+            if(generator__next__(1, val)) return 1;
+            break;
+        case tp_array2d_like_iterator:
+            if(array2d_like_iterator__next__(1, val)) return 1;
+            break;
+        case tp_list_iterator:
+            if(list_iterator__next__(1, val)) return 1;
+            break;
+        case tp_tuple_iterator:
+            if(tuple_iterator__next__(1, val)) return 1;
+            break;
+        case tp_dict_iterator:
+            if(dict_items__next__(1, val)) return 1;
+            break;
+        case tp_range_iterator:
+            if(range_iterator__next__(1, val)) return 1;
+            break;
+        case tp_str_iterator:
+            if(str_iterator__next__(1, val)) return 1;
+            break;
+        default: {
+            py_Ref tmp = py_tpfindmagic(val->type, __next__);
+            if(!tmp) {
+                TypeError("'%t' object is not an iterator", val->type);
+                return -1;
+            }
+            if(py_call(tmp, 1, val)) return 1;
+            break;
+        }
     }
-    if(py_call(tmp, 1, val)) return 1;
     if(vm->curr_exception.type == tp_StopIteration) {
         vm->last_retval = vm->curr_exception;
         py_clearexc(NULL);

+ 2 - 2
src/public/py_range.c

@@ -68,9 +68,9 @@ static bool range_iterator__new__(int argc, py_Ref argv) {
     return true;
 }
 
-static bool range_iterator__next__(int argc, py_Ref argv) {
+bool range_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
-    RangeIterator* ud = py_touserdata(py_arg(0));
+    RangeIterator* ud = py_touserdata(argv);
     if(ud->range.step > 0) {
         if(ud->current >= ud->range.stop) return StopIteration();
     } else {

+ 1 - 1
src/public/py_str.c

@@ -513,7 +513,7 @@ py_Type pk_str__register() {
     return type;
 }
 
-static bool str_iterator__next__(int argc, py_Ref argv) {
+bool str_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     int* ud = py_touserdata(&argv[0]);
     int size;