blueloveTH 2 gadi atpakaļ
vecāks
revīzija
0912e88ac7
5 mainītis faili ar 49 papildinājumiem un 36 dzēšanām
  1. 15 25
      src/ceval.h
  2. 1 2
      src/pocketpy.h
  3. 1 0
      src/str.h
  4. 6 9
      src/vm.h
  5. 26 0
      tests/28_iter.py

+ 15 - 25
src/ceval.h

@@ -408,22 +408,15 @@ __NEXT_STEP:;
     /*****************************************/
     /*****************************************/
     TARGET(GET_ITER)
     TARGET(GET_ITER)
         TOP() = asIter(TOP());
         TOP() = asIter(TOP());
-        check_type(TOP(), tp_iterator);
         DISPATCH();
         DISPATCH();
-    TARGET(FOR_ITER) {
-#if DEBUG_EXTRA_CHECK
-        BaseIter* it = PyIter_AS_C(TOP());
-#else
-        BaseIter* it = _PyIter_AS_C(TOP());
-#endif
-        PyObject* obj = it->next();
-        if(obj != StopIteration){
-            PUSH(obj);
+    TARGET(FOR_ITER)
+        _0 = PyIterNext(TOP());
+        if(_0 != StopIteration){
+            PUSH(_0);
         }else{
         }else{
-            int target = co_blocks[byte.block].end;
-            frame->jump_abs_break(target);
+            frame->jump_abs_break(co_blocks[byte.block].end);
         }
         }
-    } DISPATCH();
+        DISPATCH();
     /*****************************************/
     /*****************************************/
     TARGET(IMPORT_NAME) {
     TARGET(IMPORT_NAME) {
         StrName name(byte.arg);
         StrName name(byte.arg);
@@ -459,12 +452,10 @@ __NEXT_STEP:;
     /*****************************************/
     /*****************************************/
     TARGET(UNPACK_SEQUENCE)
     TARGET(UNPACK_SEQUENCE)
     TARGET(UNPACK_EX) {
     TARGET(UNPACK_EX) {
-        // asIter or iter->next may run bytecode, accidential gc may happen
         auto _lock = heap.gc_scope_lock();  // lock the gc via RAII!!
         auto _lock = heap.gc_scope_lock();  // lock the gc via RAII!!
-        PyObject* obj = asIter(POPX());
-        BaseIter* iter = PyIter_AS_C(obj);
+        PyObject* iter = asIter(POPX());
         for(int i=0; i<byte.arg; i++){
         for(int i=0; i<byte.arg; i++){
-            PyObject* item = iter->next();
+            PyObject* item = PyIterNext(iter);
             if(item == StopIteration) ValueError("not enough values to unpack");
             if(item == StopIteration) ValueError("not enough values to unpack");
             PUSH(item);
             PUSH(item);
         }
         }
@@ -472,23 +463,22 @@ __NEXT_STEP:;
         if(byte.op == OP_UNPACK_EX){
         if(byte.op == OP_UNPACK_EX){
             List extras;
             List extras;
             while(true){
             while(true){
-                PyObject* item = iter->next();
+                PyObject* item = PyIterNext(iter);
                 if(item == StopIteration) break;
                 if(item == StopIteration) break;
                 extras.push_back(item);
                 extras.push_back(item);
             }
             }
             PUSH(VAR(extras));
             PUSH(VAR(extras));
         }else{
         }else{
-            if(iter->next() != StopIteration) ValueError("too many values to unpack");
+            if(PyIterNext(iter) != StopIteration) ValueError("too many values to unpack");
         }
         }
     } DISPATCH();
     } DISPATCH();
     TARGET(UNPACK_UNLIMITED) {
     TARGET(UNPACK_UNLIMITED) {
         auto _lock = heap.gc_scope_lock();  // lock the gc via RAII!!
         auto _lock = heap.gc_scope_lock();  // lock the gc via RAII!!
-        PyObject* obj = asIter(POPX());
-        BaseIter* iter = PyIter_AS_C(obj);
-        obj = iter->next();
-        while(obj != StopIteration){
-            PUSH(obj);
-            obj = iter->next();
+        PyObject* iter = asIter(POPX());
+        _0 = PyIterNext(iter);
+        while(_0 != StopIteration){
+            PUSH(_0);
+            _0 = PyIterNext(iter);
         }
         }
     } DISPATCH();
     } DISPATCH();
     /*****************************************/
     /*****************************************/

+ 1 - 2
src/pocketpy.h

@@ -166,8 +166,7 @@ inline void init_builtins(VM* _vm) {
     });
     });
 
 
     _vm->bind_builtin_func<1>("next", [](VM* vm, ArgsView args) {
     _vm->bind_builtin_func<1>("next", [](VM* vm, ArgsView args) {
-        BaseIter* iter = vm->PyIter_AS_C(args[0]);
-        return iter->next();
+        return vm->PyIterNext(args[0]);
     });
     });
 
 
     _vm->bind_builtin_func<1>("dir", [](VM* vm, ArgsView args) {
     _vm->bind_builtin_func<1>("dir", [](VM* vm, ArgsView args) {

+ 1 - 0
src/str.h

@@ -383,6 +383,7 @@ const StrName __class__ = StrName::get("__class__");
 const StrName __base__ = StrName::get("__base__");
 const StrName __base__ = StrName::get("__base__");
 const StrName __new__ = StrName::get("__new__");
 const StrName __new__ = StrName::get("__new__");
 const StrName __iter__ = StrName::get("__iter__");
 const StrName __iter__ = StrName::get("__iter__");
+const StrName __next__ = StrName::get("__next__");
 const StrName __str__ = StrName::get("__str__");
 const StrName __str__ = StrName::get("__str__");
 const StrName __repr__ = StrName::get("__repr__");
 const StrName __repr__ = StrName::get("__repr__");
 const StrName __getitem__ = StrName::get("__getitem__");
 const StrName __getitem__ = StrName::get("__getitem__");

+ 6 - 9
src/vm.h

@@ -321,15 +321,12 @@ public:
         return heap.gcnew<P>(tp_iterator, std::forward<P>(value));
         return heap.gcnew<P>(tp_iterator, std::forward<P>(value));
     }
     }
 
 
-    BaseIter* PyIter_AS_C(PyObject* obj)
-    {
-        check_type(obj, tp_iterator);
-        return static_cast<BaseIter*>(obj->value());
-    }
-
-    BaseIter* _PyIter_AS_C(PyObject* obj)
-    {
-        return static_cast<BaseIter*>(obj->value());
+    PyObject* PyIterNext(PyObject* obj){
+        if(is_non_tagged_type(obj, tp_iterator)){
+            BaseIter* iter = static_cast<BaseIter*>(obj->value());
+            return iter->next();
+        }
+        return call_method(obj, __next__);
     }
     }
     
     
     /***** Error Reporter *****/
     /***** Error Reporter *****/

+ 26 - 0
tests/28_iter.py

@@ -10,3 +10,29 @@ while True:
     total += obj
     total += obj
 
 
 assert total == 6
 assert total == 6
+
+class Task:
+    def __init__(self, n):
+        self.n = n
+
+    def __iter__(self):
+        self.i = 0
+        return self
+
+    def __next__(self):
+        if self.i == self.n:
+            return StopIteration
+        self.i += 1
+        return self.i
+
+a = Task(3)
+assert sum(a) == 6
+
+i = iter(Task(5))
+assert next(i) == 1
+assert next(i) == 2
+assert next(i) == 3
+assert next(i) == 4
+assert next(i) == 5
+assert next(i) == StopIteration
+assert next(i) == StopIteration