Ver Fonte

`yield from` can return value

blueloveTH há 1 ano atrás
pai
commit
571a080127

+ 1 - 0
include/pocketpy/interpreter/vm.h

@@ -115,6 +115,7 @@ py_Type pk_range__register();
 py_Type pk_range_iterator__register();
 py_Type pk_BaseException__register();
 py_Type pk_Exception__register();
+py_Type pk_StopIteration__register();
 py_Type pk_super__register();
 py_Type pk_property__register();
 py_Type pk_staticmethod__register();

+ 1 - 0
include/pocketpy/xmacros/opcodes.h

@@ -83,6 +83,7 @@ OPCODE(UNARY_INVERT)
 /**************************/
 OPCODE(GET_ITER)
 OPCODE(FOR_ITER)
+OPCODE(FOR_ITER_YIELD_VALUE)
 /**************************/
 OPCODE(IMPORT_PATH)
 OPCODE(POP_IMPORT_STAR)

+ 34 - 26
src/compiler/compiler.c

@@ -1494,19 +1494,12 @@ static Error* pop_context(Compiler* self) {
         int codes_length = func->code.codes.length;
 
         for(int i = 0; i < codes_length; i++) {
-            if(codes[i].op == OP_YIELD_VALUE) {
+            if(codes[i].op == OP_YIELD_VALUE || codes[i].op == OP_FOR_ITER_YIELD_VALUE) {
                 func->type = FuncType_GENERATOR;
-                for(int j = 0; j < codes_length; j++) {
-                    if(codes[j].op == OP_RETURN_VALUE && codes[j].arg == BC_NOARG) {
-                        Error* err =
-                            SyntaxError(self, "'return' with argument inside generator function");
-                        err->lineno = c11__at(BytecodeEx, &func->code.codes_ex, j)->lineno;
-                        return err;
-                    }
-                }
                 break;
             }
         }
+
         if(func->type == FuncType_UNSET) {
             bool is_simple = true;
             if(func->kwargs.length > 0) is_simple = false;
@@ -2035,6 +2028,20 @@ static Error* compile_for_loop(Compiler* self) {
     return NULL;
 }
 
+static Error* compile_yield_from(Compiler* self, int kw_line) {
+    Error* err;
+    if(self->contexts.length <= 1) return SyntaxError(self, "'yield from' outside function");
+    check(EXPR_TUPLE(self));
+    Ctx__s_emit_top(ctx());
+    Ctx__emit_(ctx(), OP_GET_ITER, BC_NOARG, kw_line);
+    Ctx__enter_block(ctx(), CodeBlockType_FOR_LOOP);
+    Ctx__emit_(ctx(), OP_FOR_ITER_YIELD_VALUE, ctx()->curr_iblock, kw_line);
+    Ctx__emit_(ctx(), OP_LOOP_CONTINUE, Ctx__get_loop(ctx()), kw_line);
+    Ctx__exit_block(ctx());
+    // StopIteration.value will be pushed onto the stack
+    return NULL;
+}
+
 Error* try_compile_assignment(Compiler* self, bool* is_assign) {
     Error* err;
     switch(curr()->type) {
@@ -2074,15 +2081,24 @@ Error* try_compile_assignment(Compiler* self, bool* is_assign) {
             return NULL;
         }
         case TK_ASSIGN: {
+            consume(TK_ASSIGN);
             int n = 0;
-            while(match(TK_ASSIGN)) {
-                check(EXPR_TUPLE(self));
-                n += 1;
+
+            if(match(TK_YIELD_FROM)) {
+                check(compile_yield_from(self, prev()->line));
+                n = 1;
+            } else {
+                do {
+                    check(EXPR_TUPLE(self));
+                    n += 1;
+                } while(match(TK_ASSIGN));
+
+                // stack size is n+1
+                Ctx__s_emit_top(ctx());
+                for(int j = 1; j < n; j++)
+                    Ctx__emit_(ctx(), OP_DUP_TOP, BC_NOARG, BC_KEEPLINE);
             }
-            // stack size is n+1
-            Ctx__s_emit_top(ctx());
-            for(int j = 1; j < n; j++)
-                Ctx__emit_(ctx(), OP_DUP_TOP, BC_NOARG, BC_KEEPLINE);
+
             for(int j = 0; j < n; j++) {
                 if(Ctx__s_top(ctx())->vt->is_starred)
                     return SyntaxError(self, "can't use starred expression here");
@@ -2488,16 +2504,8 @@ static Error* compile_stmt(Compiler* self) {
             consume_end_stmt();
             break;
         case TK_YIELD_FROM:
-            if(self->contexts.length <= 1)
-                return SyntaxError(self, "'yield from' outside function");
-            check(EXPR_TUPLE(self));
-            Ctx__s_emit_top(ctx());
-            Ctx__emit_(ctx(), OP_GET_ITER, BC_NOARG, kw_line);
-            Ctx__enter_block(ctx(), CodeBlockType_FOR_LOOP);
-            Ctx__emit_(ctx(), OP_FOR_ITER, ctx()->curr_iblock, kw_line);
-            Ctx__emit_(ctx(), OP_YIELD_VALUE, BC_NOARG, kw_line);
-            Ctx__emit_(ctx(), OP_LOOP_CONTINUE, Ctx__get_loop(ctx()), kw_line);
-            Ctx__exit_block(ctx());
+            check(compile_yield_from(self, kw_line));
+            Ctx__emit_(ctx(), OP_POP_TOP, BC_NOARG, kw_line);
             consume_end_stmt();
             break;
         case TK_RETURN:

+ 15 - 0
src/interpreter/ceval.c

@@ -781,10 +781,25 @@ FrameResult VM__run_top_frame(VM* self) {
                     PUSH(py_retval());
                     DISPATCH();
                 } else {
+                    assert(self->last_retval.type == tp_StopIteration);
                     int target = Frame__prepare_loop_break(frame, &self->stack);
                     DISPATCH_JUMP_ABSOLUTE(target);
                 }
             }
+            case OP_FOR_ITER_YIELD_VALUE: {
+                int res = py_next(TOP());
+                if(res == -1) goto __ERROR;
+                if(res) {
+                    return RES_YIELD;
+                } else {
+                    assert(self->last_retval.type == tp_StopIteration);
+                    py_ObjectRef value = py_getslot(&self->last_retval, 0);
+                    int target = Frame__prepare_loop_break(frame, &self->stack);
+                    if(py_isnil(value)) value = py_None();
+                    PUSH(value);
+                    DISPATCH_JUMP_ABSOLUTE(target);
+                }
+            }
             ////////
             case OP_IMPORT_PATH: {
                 py_Ref path_object = c11__at(py_TValue, &frame->co->consts, byte.arg);

+ 4 - 1
src/interpreter/generator.c

@@ -67,7 +67,10 @@ static bool generator__next__(int argc, py_Ref argv) {
     } else {
         assert(res == RES_RETURN);
         ud->state = 2;
-        return StopIteration();
+        // raise StopIteration(<retval>)
+        bool ok = py_tpcall(tp_StopIteration, 1, py_retval());
+        if(!ok) return false;
+        return py_raise(py_retval());
     }
 }
 

+ 4 - 1
src/interpreter/vm.c

@@ -147,7 +147,10 @@ void VM__ctor(VM* self) {
     INJECT_BUILTIN_EXC(SystemExit, tp_BaseException);
     INJECT_BUILTIN_EXC(KeyboardInterrupt, tp_BaseException);
 
-    INJECT_BUILTIN_EXC(StopIteration, tp_Exception);
+    // INJECT_BUILTIN_EXC(StopIteration, tp_Exception);
+    validate(tp_StopIteration, pk_StopIteration__register());
+    py_setdict(&self->builtins, py_name("StopIteration"), py_tpobject(tp_StopIteration));
+    
     INJECT_BUILTIN_EXC(SyntaxError, tp_Exception);
     INJECT_BUILTIN_EXC(StackOverflowError, tp_Exception);
     INJECT_BUILTIN_EXC(IOError, tp_Exception);

+ 5 - 1
src/public/internal.c

@@ -255,4 +255,8 @@ bool pk_callmagic(py_Name name, int argc, py_Ref argv) {
     return py_call(tmp, argc, argv);
 }
 
-bool StopIteration() { return py_exception(tp_StopIteration, ""); }
+bool StopIteration() {
+    bool ok = py_tpcall(tp_StopIteration, 0, NULL);
+    if(!ok) return false;
+    return py_raise(py_retval());
+}

+ 2 - 1
src/public/modules.c

@@ -247,7 +247,8 @@ static bool builtins_next(int argc, py_Ref argv) {
     int res = py_next(argv);
     if(res == -1) return false;
     if(res) return true;
-    return py_exception(tp_StopIteration, "");
+    // StopIteration stored in py_retval()
+    return py_raise(py_retval());
 }
 
 static bool builtins_hash(int argc, py_Ref argv) {

+ 17 - 3
src/public/py_exception.c

@@ -96,6 +96,17 @@ static bool BaseException_args(int argc, py_Ref argv){
     return true;
 }
 
+static bool StopIteration_value(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    py_Ref arg = py_getslot(argv, 0);
+    if(py_isnil(arg)) {
+        py_newnone(py_retval());
+    }else{
+        py_assign(py_retval(), arg);
+    }
+    return true;
+}
+
 py_Type pk_BaseException__register() {
     py_Type type = pk_newtype("BaseException", tp_object, NULL, BaseException__dtor, false, false);
 
@@ -112,6 +123,12 @@ py_Type pk_Exception__register() {
     return type;
 }
 
+py_Type pk_StopIteration__register() {
+    py_Type type = pk_newtype("StopIteration", tp_Exception, NULL, NULL, false, false);
+    py_bindproperty(type, "value", StopIteration_value, NULL);
+    return type;
+}
+
 //////////////////////////////////////////////////
 bool py_checkexc(bool ignore_handled) {
     VM* vm = pk_current_vm;
@@ -134,13 +151,10 @@ bool py_matchexc(py_Type type) {
 
 void py_clearexc(py_StackRef p0) {
     VM* vm = pk_current_vm;
-    vm->last_retval = *py_NIL();
     vm->curr_exception = *py_NIL();
     vm->is_curr_exc_handled = false;
-
     /* Don't clear this, because StopIteration() may corrupt the class defination */
     // vm->__curr_class = NULL;
-
     vm->__curr_function = NULL;
     if(p0) vm->stack.sp = p0;
 }

+ 1 - 0
src/public/py_ops.c

@@ -77,6 +77,7 @@ int py_next(py_Ref val) {
     }
     if(py_call(tmp, 1, val)) return 1;
     if(vm->curr_exception.type == tp_StopIteration) {
+        vm->last_retval = vm->curr_exception;
         py_clearexc(NULL);
         return 0;
     }

+ 16 - 5
tests/51_yield.py

@@ -99,16 +99,27 @@ def f():
 
 assert list(f()) == [1, 2]
 
-src = '''
 def g():
     yield 1
     yield 2
     return 3
     yield 4
-'''
+
+assert StopIteration().value == None
+assert StopIteration(3).value == 3
 
 try:
-    exec(src)
+    iter = g()
+    assert next(iter) == 1
+    assert next(iter) == 2
+    next(iter)  # raises StopIteration
+    print('UNREACHABLE!!')
     exit(1)
-except SyntaxError:
-    pass
+except StopIteration as e:
+    assert e.value == 3
+
+def f():
+    a = yield from g()
+    yield a
+
+assert list(f()) == [1, 2, 3]