Quellcode durchsuchen

fix a bug of closure for generator

blueloveTH vor 1 Jahr
Ursprung
Commit
ff2cd96c95

+ 1 - 1
include/pocketpy/interpreter/generator.h

@@ -8,6 +8,6 @@ typedef struct Generator{
     int state;
 } Generator;
 
-void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length);
+void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* begin, py_TValue* end);
 
 void Generator__dtor(Generator* ud);

+ 6 - 5
src/interpreter/generator.c

@@ -5,19 +5,19 @@
 #include "pocketpy/pocketpy.h"
 #include <stdbool.h>
 
-void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length) {
+void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* begin, py_TValue* end) {
     Generator* ud = py_newobject(out, tp_generator, 1, sizeof(Generator));
     ud->frame = frame;
     ud->state = 0;
     py_Ref tmp = py_getslot(out, 0);
     py_newlist(tmp);
-    for(int i = 0; i < backup_length; i++) {
-        py_list_append(tmp, &backup[i]);
+    for(py_TValue* p = begin; p != end; p++) {
+        py_list_append(tmp, p);
     }
 }
 
 void Generator__dtor(Generator* ud) {
-    if(ud->frame) { Frame__delete(ud->frame); }
+    if(ud->frame) Frame__delete(ud->frame);
 }
 
 static bool generator__next__(int argc, py_Ref argv) {
@@ -28,8 +28,9 @@ static bool generator__next__(int argc, py_Ref argv) {
     if(ud->state == 2) return StopIteration();
 
     // reset frame->p0
+    int locals_offset = ud->frame->locals - ud->frame->p0;
     ud->frame->p0 = py_peek(0);
-    ud->frame->locals = py_peek(0);
+    ud->frame->locals = ud->frame->p0 + locals_offset;
 
     // restore the context
     py_Ref backup = py_getslot(argv, 0);

+ 6 - 3
src/interpreter/vm.c

@@ -494,9 +494,12 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
             case FuncType_GENERATOR: {
                 bool ok = prepare_py_call(self->__vectorcall_buffer, argv, p1, kwargc, fn->decl);
                 if(!ok) return RES_ERROR;
-                Frame* frame = Frame__new(co, &fn->module, p0, argv, false);
-                pk_newgenerator(py_retval(), frame, self->__vectorcall_buffer, co->nlocals);
-                self->stack.sp = p0;
+                // copy buffer back to stack
+                self->stack.sp = argv + co->nlocals;
+                memcpy(argv, self->__vectorcall_buffer, co->nlocals * sizeof(py_TValue));
+                Frame* frame = Frame__new(co, &fn->module, p0, argv, true);
+                pk_newgenerator(py_retval(), frame, p0, self->stack.sp);
+                self->stack.sp = p0;    // reset the stack
                 return RES_RETURN;
             }
             default: c11__unreachedable();

+ 27 - 1
tests/43_closure.py

@@ -34,4 +34,30 @@ def f(n):
         return g(x+1)
     return g(0)
 
-assert f(10) == 10
+assert f(10) == 10
+
+# class closure
+class A:
+    def g(self, x):
+        def f(y):
+            return x + y
+        return f
+    
+assert A().g(1)(2) == 3
+
+# closure with yield
+def g(x):
+    def fx(y):
+        yield x
+        yield y
+        return x + y
+    return fx
+    
+gen = g(1)(2)
+assert next(gen) == 1
+assert next(gen) == 2
+try:
+    next(gen)
+    assert False
+except StopIteration as e:
+    assert e.value == 3