Просмотр исходного кода

warn return with arg inside generator function

blueloveTH 2 лет назад
Родитель
Сommit
02a25de8e5
7 измененных файлов с 57 добавлено и 41 удалено
  1. 4 1
      include/pocketpy/expr.h
  2. 1 1
      src/ceval.cpp
  3. 6 12
      src/compiler.cpp
  4. 22 12
      src/expr.cpp
  5. 23 1
      tests/45_yield.py
  6. 0 12
      tests/99_builtin_func.py
  7. 1 2
      tests/99_dis.py

+ 4 - 1
include/pocketpy/expr.h

@@ -43,7 +43,7 @@ struct Expr{
 struct CodeEmitContext{
 struct CodeEmitContext{
     VM* vm;
     VM* vm;
     FuncDecl_ func;     // optional
     FuncDecl_ func;     // optional
-    CodeObject_ co;
+    CodeObject_ co;     // 1 CodeEmitContext <=> 1 CodeObject_
     // some bugs on MSVC (error C2280) when using std::vector<Expr_>
     // some bugs on MSVC (error C2280) when using std::vector<Expr_>
     // so we use stack_no_copy instead
     // so we use stack_no_copy instead
     stack_no_copy<Expr_> s_expr;
     stack_no_copy<Expr_> s_expr;
@@ -55,6 +55,9 @@ struct CodeEmitContext{
     bool is_compiling_class = false;
     bool is_compiling_class = false;
     int for_loop_depth = 0;
     int for_loop_depth = 0;
 
 
+    std::map<void*, int> _co_consts_nonstring_dedup_map;
+    std::map<std::string_view, int> _co_consts_string_dedup_map;
+
     int get_loop() const;
     int get_loop() const;
     CodeBlock* enter_block(CodeBlockType type);
     CodeBlock* enter_block(CodeBlockType type);
     void exit_block();
     void exit_block();

+ 1 - 1
src/ceval.cpp

@@ -630,7 +630,7 @@ __NEXT_STEP:;
         PUSH(_0);
         PUSH(_0);
     } DISPATCH();
     } DISPATCH();
     TARGET(RETURN_VALUE){
     TARGET(RETURN_VALUE){
-        PyObject* _0 = POPX();
+        PyObject* _0 = byte.arg == BC_NOARG ? POPX() : None;
         _pop_frame();
         _pop_frame();
         if(frame.index == base_id){       // [ frameBase<- ]
         if(frame.index == base_id){       // [ frameBase<- ]
             return _0;
             return _0;

+ 6 - 12
src/compiler.cpp

@@ -30,22 +30,13 @@ namespace pkpy{
         // add a `return None` in the end as a guard
         // add a `return None` in the end as a guard
         // previously, we only do this if the last opcode is not a return
         // previously, we only do this if the last opcode is not a return
         // however, this is buggy...since there may be a jump to the end (out of bound) even if the last opcode is a return
         // however, this is buggy...since there may be a jump to the end (out of bound) even if the last opcode is a return
-        ctx()->emit_(OP_LOAD_NONE, BC_NOARG, BC_KEEPLINE);
-        ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, BC_KEEPLINE);
+        ctx()->emit_(OP_RETURN_VALUE, 1, BC_KEEPLINE);
         // some check here
         // some check here
         std::vector<Bytecode>& codes = ctx()->co->codes;
         std::vector<Bytecode>& codes = ctx()->co->codes;
         if(ctx()->co->varnames.size() > PK_MAX_CO_VARNAMES){
         if(ctx()->co->varnames.size() > PK_MAX_CO_VARNAMES){
             SyntaxError("maximum number of local variables exceeded");
             SyntaxError("maximum number of local variables exceeded");
         }
         }
         if(ctx()->co->consts.size() > 65535){
         if(ctx()->co->consts.size() > 65535){
-            // std::map<std::string_view, int> counts;
-            // for(PyObject* c: ctx()->co->consts){
-            //     std::string_view key = obj_type_name(vm, vm->_tp(c)).sv();
-            //     counts[key] += 1;
-            // }
-            // for(auto pair: counts){
-            //     std::cout << pair.first << ": " << pair.second << std::endl;
-            // }
             SyntaxError("maximum number of constants exceeded");
             SyntaxError("maximum number of constants exceeded");
         }
         }
         if(codes.size() > 65535 && ctx()->co->src->mode != JSON_MODE){
         if(codes.size() > 65535 && ctx()->co->src->mode != JSON_MODE){
@@ -63,6 +54,7 @@ namespace pkpy{
                 bc.arg = ctx()->co->_get_block_codei(i).end;
                 bc.arg = ctx()->co->_get_block_codei(i).end;
             }
             }
         }
         }
+        // pre-compute func->is_simple
         FuncDecl_ func = contexts.top().func;
         FuncDecl_ func = contexts.top().func;
         if(func){
         if(func){
             func->is_simple = true;
             func->is_simple = true;
@@ -809,12 +801,14 @@ __EAT_DOTS_END:
             case TK("return"):
             case TK("return"):
                 if (contexts.size() <= 1) SyntaxError("'return' outside function");
                 if (contexts.size() <= 1) SyntaxError("'return' outside function");
                 if(match_end_stmt()){
                 if(match_end_stmt()){
-                    ctx()->emit_(OP_LOAD_NONE, BC_NOARG, kw_line);
+                    ctx()->emit_(OP_RETURN_VALUE, 1, kw_line);
                 }else{
                 }else{
                     EXPR_TUPLE(false);
                     EXPR_TUPLE(false);
+                    // check if it is a generator
+                    if(ctx()->co->is_generator) SyntaxError("'return' with argument inside generator function");
                     consume_end_stmt();
                     consume_end_stmt();
+                    ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, kw_line);
                 }
                 }
-                ctx()->emit_(OP_RETURN_VALUE, BC_NOARG, kw_line);
                 break;
                 break;
             /*************************************************/
             /*************************************************/
             case TK("if"): compile_if_stmt(); break;
             case TK("if"): compile_if_stmt(); break;

+ 22 - 12
src/expr.cpp

@@ -83,21 +83,31 @@ namespace pkpy{
     }
     }
 
 
     int CodeEmitContext::add_const(PyObject* v){
     int CodeEmitContext::add_const(PyObject* v){
-        // simple deduplication, only works for int/float
-        for(int i=0; i<co->consts.size(); i++){
-            if(co->consts[i] == v) return i;
-        }
-        // string deduplication
         if(is_non_tagged_type(v, vm->tp_str)){
         if(is_non_tagged_type(v, vm->tp_str)){
-            const Str& v_str = PK_OBJ_GET(Str, v);
-            for(int i=0; i<co->consts.size(); i++){
-                if(is_non_tagged_type(co->consts[i], vm->tp_str)){
-                    if(PK_OBJ_GET(Str, co->consts[i]) == v_str) return i;
-                }
+            // string deduplication
+            std::string_view key = PK_OBJ_GET(Str, v).sv();
+            auto it = _co_consts_string_dedup_map.find(key);
+            if(it != _co_consts_string_dedup_map.end()){
+                return it->second;
+            }else{
+                co->consts.push_back(v);
+                int index = co->consts.size() - 1;
+                _co_consts_string_dedup_map[key] = index;
+                return index;
+            }
+        }else{
+            // non-string deduplication
+            auto it = _co_consts_nonstring_dedup_map.find(v);
+            if(it != _co_consts_nonstring_dedup_map.end()){
+                return it->second;
+            }else{
+                co->consts.push_back(v);
+                int index = co->consts.size() - 1;
+                _co_consts_nonstring_dedup_map[v] = index;
+                return index;
             }
             }
         }
         }
-        co->consts.push_back(v);
-        return co->consts.size() - 1;
+        PK_UNREACHABLE();
     }
     }
 
 
     int CodeEmitContext::add_func_decl(FuncDecl_ decl){
     int CodeEmitContext::add_func_decl(FuncDecl_ decl){

+ 23 - 1
tests/45_yield.py

@@ -64,4 +64,26 @@ try:
 except ValueError:
 except ValueError:
     pass
     pass
 
 
-assert next(t) == StopIteration
+assert next(t) == StopIteration
+
+def f():
+    yield 1
+    yield 2
+    return
+    yield 3
+
+assert list(f()) == [1, 2]
+
+src = '''
+def g():
+    yield 1
+    yield 2
+    return 3
+    yield 4
+'''
+
+try:
+    exec(src)
+    exit(1)
+except SyntaxError:
+    pass

+ 0 - 12
tests/99_builtin_func.py

@@ -899,18 +899,6 @@ time.sleep(0.1)
 # test time.localtime
 # test time.localtime
 assert type(time.localtime()) is time.struct_time
 assert type(time.localtime()) is time.struct_time
 
 
-# /************ module dis ************/
-import dis
-#       116: 1487:    vm->bind_func<1>(mod, "dis", [](VM* vm, ArgsView args) {
-#     #####: 1488:        CodeObject_ code = get_code(vm, args[0]);
-#     #####: 1489:        vm->_stdout(vm, vm->disassemble(code));
-#     #####: 1490:        return vm->None;
-#     #####: 1491:    });
-# test dis.dis
-def aaa():
-    pass
-assert dis.dis(aaa) is None
-
 # test min/max
 # test min/max
 assert min(1, 2) == 1
 assert min(1, 2) == 1
 assert min(1, 2, 3) == 1
 assert min(1, 2, 3) == 1

+ 1 - 2
tests/81_frontend.py → tests/99_dis.py

@@ -13,5 +13,4 @@ def f(a):
 def g(a):
 def g(a):
     return f([1,2,3] + a)
     return f([1,2,3] + a)
 
 
-# x = _s(g)
-# assert type(x) is str
+assert dis(g) is None