Explorar o código

fix generator

blueloveTH hai 1 ano
pai
achega
cd3c28fdd8

+ 2 - 2
include/pocketpy/interpreter/frame.h

@@ -36,7 +36,7 @@ typedef struct Frame {
     const Bytecode* ip;
     const CodeObject* co;
     py_GlobalRef module;
-    py_StackRef function;  // a function object or NULL (global scope)
+    bool has_function;     // is p0 a function?
     py_StackRef p0;        // unwinding base
     py_StackRef locals;    // locals base
     UnwindTarget* uw_list;
@@ -44,7 +44,7 @@ typedef struct Frame {
 
 Frame* Frame__new(const CodeObject* co,
                   py_GlobalRef module,
-                  py_StackRef function,
+                  bool has_function,
                   py_StackRef p0,
                   py_StackRef locals);
 void Frame__delete(Frame* self);

+ 4 - 1
include/pocketpy/interpreter/generator.h

@@ -1,10 +1,13 @@
 #pragma once
 
 #include "pocketpy/interpreter/frame.h"
+#include "pocketpy/pocketpy.h"
 
 typedef struct Generator{
     Frame* frame;
     int state;
 } Generator;
 
-void pk_newgenerator(py_Ref out, Frame* frame);
+void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length);
+
+void Generator__dtor(Generator* ud);

+ 3 - 1
include/pocketpy/pocketpy.h

@@ -19,6 +19,8 @@ typedef int16_t py_Type;
 typedef int64_t py_i64;
 typedef double py_f64;
 
+typedef void (*py_Dtor)(void*);
+
 #define PY_RAISE  // mark a function that can raise an exception
 
 typedef struct c11_sv {
@@ -137,7 +139,7 @@ c11_sv py_name2sv(py_Name);
 /// @param base base type.
 /// @param module module where the type is defined. Use `NULL` for built-in types.
 /// @param dtor destructor function. Use `NULL` if not needed.
-py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, void (*dtor)(void*));
+py_Type py_newtype(const char* name, py_Type base, const py_GlobalRef module, py_Dtor dtor);
 
 /// Create a new object.
 /// @param out output reference.

+ 3 - 3
src/interpreter/ceval.c

@@ -285,7 +285,7 @@ FrameResult VM__run_top_frame(VM* self) {
             case OP_STORE_FAST: frame->locals[byte.arg] = POPX(); DISPATCH();
             case OP_STORE_NAME: {
                 py_Name name = byte.arg;
-                if(frame->function) {
+                if(frame->has_function) {
                     py_Ref slot = Frame__f_locals_try_get(frame, name);
                     if(slot != NULL) {
                         *slot = *TOP();  // store in locals if possible
@@ -346,7 +346,7 @@ FrameResult VM__run_top_frame(VM* self) {
             }
             case OP_DELETE_NAME: {
                 py_Name name = byte.arg;
-                if(frame->function) {
+                if(frame->has_function) {
                     py_TValue* slot = Frame__f_locals_try_get(frame, name);
                     if(slot) {
                         py_newnil(slot);
@@ -977,7 +977,7 @@ FrameResult VM__run_top_frame(VM* self) {
         py_BaseException__stpush(&self->curr_exception,
                                  frame->co->src,
                                  lineno < 0 ? Frame__lineno(frame) : lineno,
-                                 frame->function ? frame->co->name->data : NULL);
+                                 frame->has_function ? frame->co->name->data : NULL);
 
         int target = Frame__prepare_jump_exception_handler(frame, &self->stack);
         if(target >= 0) {

+ 4 - 4
src/interpreter/frame.c

@@ -37,7 +37,7 @@ void UnwindTarget__delete(UnwindTarget* self) { free(self); }
 
 Frame* Frame__new(const CodeObject* co,
                   py_GlobalRef module,
-                  py_StackRef function,
+                  bool has_function,
                   py_StackRef p0,
                   py_StackRef locals) {
     static_assert(sizeof(Frame) <= kPoolFrameBlockSize, "!(sizeof(Frame) <= kPoolFrameBlockSize)");
@@ -46,7 +46,7 @@ Frame* Frame__new(const CodeObject* co,
     self->ip = (Bytecode*)co->codes.data - 1;
     self->co = co;
     self->module = module;
-    self->function = function;
+    self->has_function = has_function;
     self->p0 = p0;
     self->locals = locals;
     self->uw_list = NULL;
@@ -131,8 +131,8 @@ void Frame__set_unwind_target(Frame* self, py_TValue* sp) {
 }
 
 py_TValue* Frame__f_closure_try_get(Frame* self, py_Name name) {
-    if(self->function == NULL) return NULL;
-    Function* ud = py_touserdata(self->function);
+    if(!self->has_function) return NULL;
+    Function* ud = py_touserdata(self->p0);
     if(ud->closure == NULL) return NULL;
     return NameDict__try_get(ud->closure, name);
 }

+ 14 - 3
src/interpreter/generator.c

@@ -5,11 +5,19 @@
 #include "pocketpy/pocketpy.h"
 #include <stdbool.h>
 
-void pk_newgenerator(py_Ref out, Frame* frame) {
+void pk_newgenerator(py_Ref out, Frame* frame, py_TValue* backup, int backup_length) {
     Generator* ud = py_newobject(out, tp_generator, 1, sizeof(Generator));
     ud->frame = frame;
     ud->state = 0;
-    py_newlist(py_getslot(out, 0));
+    py_Ref tmp = py_getslot(out, 0);
+    py_newlist(tmp);
+    for(int i = 0; i < backup_length; i++) {
+        py_list_append(tmp, &backup[i]);
+    }
+}
+
+void Generator__dtor(Generator* ud) {
+    if(ud->frame) { Frame__delete(ud->frame); }
 }
 
 static bool generator__next__(int argc, py_Ref argv) {
@@ -33,6 +41,7 @@ static bool generator__next__(int argc, py_Ref argv) {
 
     // push frame
     VM__push_frame(vm, ud->frame);
+    ud->frame = NULL;
 
     FrameResult res = VM__run_top_frame(vm);
 
@@ -51,17 +60,19 @@ static bool generator__next__(int argc, py_Ref argv) {
         for(py_StackRef p = ud->frame->p0; p != vm->stack.sp; p++) {
             py_list_append(backup, p);
         }
+        vm->stack.sp = ud->frame->p0;
         vm->top_frame = vm->top_frame->f_back;
         ud->state = 1;
         return true;
     } else {
+        assert(res == RES_RETURN);
         ud->state = 2;
         return StopIteration();
     }
 }
 
 py_Type pk_generator__register() {
-    py_Type type = pk_newtype("generator", tp_object, NULL, NULL, false, true);
+    py_Type type = pk_newtype("generator", tp_object, NULL, (py_Dtor)Generator__dtor, false, true);
 
     py_bindmagic(type, __iter__, pk_wrapper__self);
     py_bindmagic(type, __next__, generator__next__);

+ 5 - 5
src/interpreter/vm.c

@@ -427,7 +427,7 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
                 memcpy(argv, self->__vectorcall_buffer, co->nlocals * sizeof(py_TValue));
                 // submit the call
                 if(!fn->cfunc) {
-                    VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv));
+                    VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv));
                     return opcall ? RES_CALL : VM__run_top_frame(self);
                 } else {
                     bool ok = fn->cfunc(co->nlocals, argv);
@@ -451,13 +451,13 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
                 // initialize local variables to py_NIL
                 memset(p1, 0, (char*)self->stack.sp - (char*)p1);
                 // submit the call
-                VM__push_frame(self, Frame__new(co, &fn->module, p0, p0, argv));
+                VM__push_frame(self, Frame__new(co, &fn->module, true, p0, argv));
                 return opcall ? RES_CALL : VM__run_top_frame(self);
             case FuncType_GENERATOR: {
                 bool ok = prepare_py_call(self->__vectorcall_buffer, argv, p1, kwargc, fn->decl);
                 if(!ok) return RES_ERROR;
-                Frame* frame = Frame__new(co, &fn->module, p0, p0, argv);
-                pk_newgenerator(py_retval(), frame);
+                Frame* frame = Frame__new(co, &fn->module, false, p0, argv);
+                pk_newgenerator(py_retval(), frame, self->__vectorcall_buffer, co->nlocals);
                 self->stack.sp = p0;
                 return RES_RETURN;
             }
@@ -592,7 +592,7 @@ void ManagedHeap__mark(ManagedHeap* self) {
 }
 
 void pk_print_stack(VM* self, Frame* frame, Bytecode byte) {
-    // return;
+    return;
     if(frame == NULL || py_isnil(&self->main)) return;
 
     py_TValue* sp = self->stack.sp;

+ 1 - 1
src/public/internal.c

@@ -93,7 +93,7 @@ bool py_exec(const char* source, const char* filename, enum py_CompileMode mode,
 
     if(!module) module = &vm->main;
 
-    Frame* frame = Frame__new(&co, module, NULL, vm->stack.sp, vm->stack.sp);
+    Frame* frame = Frame__new(&co, module, false, vm->stack.sp, vm->stack.sp);
     VM__push_frame(vm, frame);
     FrameResult res = VM__run_top_frame(vm);
     CodeObject__dtor(&co);

+ 2 - 2
src/public/modules.c

@@ -468,8 +468,8 @@ static bool super__new__(int argc, py_Ref argv) {
     py_Ref self_arg = NULL;
     if(argc == 1) {
         // super()
-        if(frame->function) {
-            Function* func = py_touserdata(frame->function);
+        if(frame->has_function) {
+            Function* func = py_touserdata(frame->p0);
             *class_arg = *(py_Type*)PyObject__userdata(func->clazz);
             if(frame->co->nlocals > 0) self_arg = &frame->locals[0];
         }

+ 3 - 1
src/public/py_ops.c

@@ -81,7 +81,9 @@ int py_next(py_Ref val) {
         py_clearexc(p0);
         vm->is_stopiteration = true;
     }
-    return vm->is_stopiteration ? 0 : -1;
+    int retval = vm->is_stopiteration ? 0 : -1;
+    vm->is_stopiteration = false;
+    return retval;
 }
 
 bool py_getattr(py_Ref self, py_Name name) {

+ 2 - 2
src/public/stack_ops.c

@@ -60,8 +60,8 @@ void py_setslot(py_Ref self, int i, py_Ref val) {
 
 py_StackRef py_inspect_currentfunction(){
     Frame* frame = pk_current_vm->top_frame;
-    if(!frame) return NULL;
-    return frame->function;
+    if(!frame || !frame->has_function) return NULL;
+    return frame->p0;
 }
 
 void py_assign(py_Ref dst, py_Ref src) { *dst = *src; }

+ 7 - 1
tests/51_yield.py

@@ -6,6 +6,12 @@ a = g()
 assert next(a) == 1
 assert next(a) == 2
 
+try:
+    next(a)
+    exit(1)
+except StopIteration:
+    pass
+
 def f(n):
     for i in range(n):
         yield i
@@ -50,7 +56,7 @@ assert a == [1, 2, 3]
 def f():
     for i in range(5):
         yield str(i)
-assert '|'.join(f()) == '0|1|2|3|4'
+assert '|'.join(list(f())) == '0|1|2|3|4'
 
 
 def f(n):