Sfoglia il codice sorgente

fix https://github.com/pocketpy/pocketpy/issues/376

blueloveTH 8 mesi fa
parent
commit
fd6f0d76b2

+ 2 - 2
include/pocketpy/interpreter/vm.h

@@ -113,7 +113,6 @@ const char* pk_opname(Opcode op);
 
 int pk_arrayview(py_Ref self, py_TValue** p);
 bool pk_wrapper__arrayequal(py_Type type, int argc, py_Ref argv);
-bool pk_arrayiter(py_Ref val);
 bool pk_arraycontains(py_Ref self, py_Ref val);
 
 bool pk_loadmethod(py_StackRef self, py_Name name);
@@ -139,7 +138,8 @@ py_Type pk_dict__register();
 py_Type pk_dict_items__register();
 py_Type pk_list__register();
 py_Type pk_tuple__register();
-py_Type pk_array_iterator__register();
+py_Type pk_list_iterator__register();
+py_Type pk_tuple_iterator__register();
 py_Type pk_slice__register();
 py_Type pk_function__register();
 py_Type pk_nativefunc__register();

+ 15 - 0
include/pocketpy/objects/iterator.h

@@ -0,0 +1,15 @@
+#pragma once
+
+#include "pocketpy/objects/base.h"
+#include "pocketpy/common/vector.h"
+
+typedef struct tuple_iterator {
+    py_TValue* p;
+    int length;
+    int index;
+} tuple_iterator;
+
+typedef struct list_iterator {
+    c11_vector* vec;
+    int index;
+} list_iterator;

+ 2 - 1
include/pocketpy/pocketpy.h

@@ -762,7 +762,8 @@ enum py_PredefinedType {
     tp_str_iterator,
     tp_list,   // c11_vector
     tp_tuple,  // N slots
-    tp_array_iterator,
+    tp_list_iterator,   // 1 slot
+    tp_tuple_iterator,  // 1 slot
     tp_slice,  // 3 slots (start, stop, step)
     tp_range,
     tp_range_iterator,

+ 2 - 1
src/interpreter/vm.c

@@ -109,7 +109,8 @@ void VM__ctor(VM* self) {
 
     validate(tp_list, pk_list__register());
     validate(tp_tuple, pk_tuple__register());
-    validate(tp_array_iterator, pk_array_iterator__register());
+    validate(tp_list_iterator, pk_list_iterator__register());
+    validate(tp_tuple_iterator, pk_tuple_iterator__register());
 
     validate(tp_slice, pk_slice__register());
     validate(tp_range, pk_range__register());

+ 25 - 28
src/public/py_array.c

@@ -1,13 +1,8 @@
 #include "pocketpy/pocketpy.h"
 #include "pocketpy/objects/object.h"
+#include "pocketpy/objects/iterator.h"
 #include "pocketpy/interpreter/vm.h"
 
-typedef struct array_iterator {
-    py_TValue* p;
-    int length;
-    int index;
-} array_iterator;
-
 int pk_arrayview(py_Ref self, py_TValue** p) {
     if(self->type == tp_list) {
         *p = py_list_data(self);
@@ -46,18 +41,6 @@ bool pk_wrapper__arrayequal(py_Type type, int argc, py_Ref argv) {
     return true;
 }
 
-bool pk_arrayiter(py_Ref val) {
-    py_TValue* p;
-    int length = pk_arrayview(val, &p);
-    if(length == -1) return TypeError("expected list or tuple, got %t", val->type);
-    array_iterator* ud = py_newobject(py_retval(), tp_array_iterator, 1, sizeof(array_iterator));
-    ud->p = p;
-    ud->length = length;
-    ud->index = 0;
-    py_setslot(py_retval(), 0, val);  // keep a reference to the object
-    return true;
-}
-
 bool pk_arraycontains(py_Ref self, py_Ref val) {
     py_TValue* p;
     int length = pk_arrayview(self, &p);
@@ -74,25 +57,39 @@ bool pk_arraycontains(py_Ref self, py_Ref val) {
     return true;
 }
 
-static bool array_iterator__iter__(int argc, py_Ref argv) {
+static bool list_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
-    *py_retval() = *argv;
-    return true;
+    list_iterator* ud = py_touserdata(argv);
+    if(ud->index < ud->vec->length) {
+        py_TValue* res = c11__at(py_TValue, ud->vec, ud->index);
+        py_assign(py_retval(), res);
+        ud->index++;
+        return true;
+    }
+    return StopIteration();
 }
 
-static bool array_iterator__next__(int argc, py_Ref argv) {
+static bool tuple_iterator__next__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
-    array_iterator* ud = py_touserdata(argv);
+    tuple_iterator* ud = py_touserdata(argv);
     if(ud->index < ud->length) {
-        *py_retval() = ud->p[ud->index++];
+        py_assign(py_retval(), ud->p + ud->index);
+        ud->index++;
         return true;
     }
     return StopIteration();
 }
 
-py_Type pk_array_iterator__register() {
-    py_Type type = pk_newtype("array_iterator", tp_object, NULL, NULL, false, true);
-    py_bindmagic(type, __iter__, array_iterator__iter__);
-    py_bindmagic(type, __next__, array_iterator__next__);
+py_Type pk_list_iterator__register() {
+    py_Type type = pk_newtype("list_iterator", tp_object, NULL, NULL, false, true);
+    py_bindmagic(type, __iter__, pk_wrapper__self);
+    py_bindmagic(type, __next__, list_iterator__next__);
+    return type;
+}
+
+py_Type pk_tuple_iterator__register() {
+    py_Type type = pk_newtype("tuple_iterator", tp_object, NULL, NULL, false, true);
+    py_bindmagic(type, __iter__, pk_wrapper__self);
+    py_bindmagic(type, __next__, tuple_iterator__next__);
     return type;
 }

+ 6 - 1
src/public/py_list.c

@@ -3,6 +3,7 @@
 #include "pocketpy/common/utils.h"
 #include "pocketpy/interpreter/vm.h"
 #include "pocketpy/interpreter/types.h"
+#include "pocketpy/objects/iterator.h"
 #include "pocketpy/common/sstream.h"
 
 void py_newlist(py_OutRef out) {
@@ -394,7 +395,11 @@ static bool list_sort(int argc, py_Ref argv) {
 
 static bool list__iter__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
-    return pk_arrayiter(argv);
+    list_iterator* ud = py_newobject(py_retval(), tp_list_iterator, 1, sizeof(list_iterator));
+    ud->vec = py_touserdata(argv);
+    ud->index = 0;
+    py_setslot(py_retval(), 0, argv);  // keep a reference to the object
+    return true;
 }
 
 static bool list__contains__(int argc, py_Ref argv) {

+ 7 - 1
src/public/py_tuple.c

@@ -3,6 +3,7 @@
 #include "pocketpy/common/utils.h"
 #include "pocketpy/common/sstream.h"
 #include "pocketpy/objects/object.h"
+#include "pocketpy/objects/iterator.h"
 #include "pocketpy/interpreter/vm.h"
 
 py_ObjectRef py_newtuple(py_OutRef out, int n) {
@@ -144,7 +145,12 @@ static bool tuple__lt__(int argc, py_Ref argv) {
 
 static bool tuple__iter__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
-    return pk_arrayiter(argv);
+    tuple_iterator* ud = py_newobject(py_retval(), tp_tuple_iterator, 1, sizeof(tuple_iterator));
+    ud->p = py_tuple_data(argv);
+    ud->length = py_tuple_len(argv);
+    ud->index = 0;
+    py_setslot(py_retval(), 0, argv);  // keep a reference to the object
+    return true;
 }
 
 static bool tuple__contains__(int argc, py_Ref argv) {

+ 13 - 1
tests/95_bugs.py

@@ -157,4 +157,16 @@ assert a == 5
 a, b, c = (1, 2, 3) if True else (4, 5, 6)
 assert a == 1
 assert b == 2
-assert c == 3
+assert c == 3
+
+# https://github.com/pocketpy/pocketpy/issues/376
+xs = [0]
+res = []
+for x in xs:
+    res.append(x)
+    if x == 100:
+        break
+    xs.append(x+1)
+
+assert res == list(range(101))
+assert xs == res