Просмотр исходного кода

add optimized opcodes for `FOR_ITER`s

blueloveTH 1 год назад
Родитель
Коммит
e86baa2e2f
8 изменённых файлов: 77 добавлений и 15 удалений
  1. 1 1
      docs/retype.yml
  2. 1 1
      include/pocketpy/common.h
  3. 2 0
      include/pocketpy/expr.h
  4. 3 0
      include/pocketpy/opcodes.h
  5. 36 3
      src/ceval.cpp
  6. 4 6
      src/compiler.cpp
  7. 27 1
      src/expr.cpp
  8. 3 3
      src/vm.cpp

+ 1 - 1
docs/retype.yml

@@ -3,7 +3,7 @@ output: .retype
 url: https://pocketpy.dev
 branding:
   title: pocketpy
-  label: v1.4.3
+  label: v1.4.4
   logo: "./static/logo.png"
 favicon: "./static/logo.png"
 meta:

+ 1 - 1
include/pocketpy/common.h

@@ -21,7 +21,7 @@
 #include <typeinfo>
 #include <initializer_list>
 
-#define PK_VERSION				"1.4.3"
+#define PK_VERSION				"1.4.4"
 
 #include "config.h"
 #include "export.h"

+ 2 - 0
include/pocketpy/expr.h

@@ -105,6 +105,7 @@ struct CodeEmitContext{
     void exit_block();
     void emit_expr();   // clear the expression stack and generate bytecode
     int emit_(Opcode opcode, uint16_t arg, int line, bool is_virtual=false);
+    void revert_last_emit_();
     int emit_int(i64 value, int line);
     void patch_jump(int index);
     bool add_label(StrName name);
@@ -113,6 +114,7 @@ struct CodeEmitContext{
     int add_const_string(std::string_view);
     int add_func_decl(FuncDecl_ decl);
     void emit_store_name(NameScope scope, StrName name, int line);
+    void try_merge_for_iter_store(int);
 };
 
 struct NameExpr: Expr{

+ 3 - 0
include/pocketpy/opcodes.h

@@ -133,6 +133,9 @@ OPCODE(UNARY_INVERT)
 /**************************/
 OPCODE(GET_ITER)
 OPCODE(FOR_ITER)
+OPCODE(FOR_ITER_STORE_FAST)
+OPCODE(FOR_ITER_STORE_GLOBAL)
+OPCODE(FOR_ITER_YIELD_VALUE)
 /**************************/
 OPCODE(IMPORT_PATH)
 OPCODE(POP_IMPORT_STAR)

+ 36 - 3
src/ceval.cpp

@@ -719,9 +719,34 @@ __NEXT_STEP:;
         if(_0 != StopIteration){
             PUSH(_0);
         }else{
-            frame->jump_abs_break(&s_data, byte.arg);
+            frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
         }
     } DISPATCH();
+    TARGET(FOR_ITER_STORE_FAST){
+        PyObject* _0 = py_next(TOP());
+        if(_0 != StopIteration){
+            frame->_locals[byte.arg] = _0;
+        }else{
+            frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
+        }
+    } DISPATCH()
+    TARGET(FOR_ITER_STORE_GLOBAL){
+        PyObject* _0 = py_next(TOP());
+        if(_0 != StopIteration){
+            frame->f_globals().set(StrName(byte.arg), _0);
+        }else{
+            frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
+        }
+    } DISPATCH()
+    TARGET(FOR_ITER_YIELD_VALUE){
+        PyObject* _0 = py_next(TOP());
+        if(_0 != StopIteration){
+            PUSH(_0);
+            return PY_OP_YIELD;
+        }else{
+            frame->jump_abs_break(&s_data, co->_get_block_codei(frame->_ip).end);
+        }
+    } DISPATCH()
     /*****************************************/
     TARGET(IMPORT_PATH){
         PyObject* _0 = co->consts[byte.arg];
@@ -877,8 +902,16 @@ __NEXT_STEP:;
         *p = VAR(CAST(i64, *p) - 1);
     } DISPATCH();
     /*****************************************/
-        static_assert(OP_DEC_GLOBAL == 133);
-        case 134: case 135: case 136: case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149: case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164: case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179: case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194: case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209: case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224: case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239: case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254: case 255: PK_UNREACHABLE() break;
+        static_assert(OP_DEC_GLOBAL == 136);
+        case 137: case 138: case 139: case 140: case 141: case 142: case 143: case 144: case 145: case 146: case 147: case 148: case 149:
+        case 150: case 151: case 152: case 153: case 154: case 155: case 156: case 157: case 158: case 159: case 160: case 161: case 162: case 163: case 164:
+        case 165: case 166: case 167: case 168: case 169: case 170: case 171: case 172: case 173: case 174: case 175: case 176: case 177: case 178: case 179:
+        case 180: case 181: case 182: case 183: case 184: case 185: case 186: case 187: case 188: case 189: case 190: case 191: case 192: case 193: case 194:
+        case 195: case 196: case 197: case 198: case 199: case 200: case 201: case 202: case 203: case 204: case 205: case 206: case 207: case 208: case 209:
+        case 210: case 211: case 212: case 213: case 214: case 215: case 216: case 217: case 218: case 219: case 220: case 221: case 222: case 223: case 224:
+        case 225: case 226: case 227: case 228: case 229: case 230: case 231: case 232: case 233: case 234: case 235: case 236: case 237: case 238: case 239:
+        case 240: case 241: case 242: case 243: case 244: case 245: case 246: case 247: case 248: case 249: case 250: case 251: case 252: case 253: case 254:
+        case 255: break;
     }
 
 }

+ 4 - 6
src/compiler.cpp

@@ -50,15 +50,13 @@ namespace pkpy{
             // json mode does not contain jump instructions, so it is safe to ignore this check
             SyntaxError("maximum number of opcodes exceeded");
         }
-        // pre-compute LOOP_BREAK and LOOP_CONTINUE and FOR_ITER
+        // pre-compute LOOP_BREAK and LOOP_CONTINUE
         for(int i=0; i<codes.size(); i++){
             Bytecode& bc = codes[i];
             if(bc.op == OP_LOOP_CONTINUE){
                 bc.arg = ctx()->co->blocks[bc.arg].start;
             }else if(bc.op == OP_LOOP_BREAK){
                 bc.arg = ctx()->co->blocks[bc.arg].get_break_end();
-            }else if(bc.op == OP_FOR_ITER){
-                bc.arg = ctx()->co->_get_block_codei(i).end;
             }
         }
         // pre-compute func->is_simple
@@ -658,9 +656,10 @@ __EAT_DOTS_END:
         EXPR_TUPLE(); ctx()->emit_expr();
         ctx()->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE);
         CodeBlock* block = ctx()->enter_block(CodeBlockType::FOR_LOOP);
-        ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
+        int for_codei = ctx()->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
         bool ok = vars->emit_store(ctx());
         if(!ok) SyntaxError();  // this error occurs in `vars` instead of this line, but...nevermind
+        ctx()->try_merge_for_iter_store(for_codei);
         compile_block_body();
         ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), BC_KEEPLINE, true);
         ctx()->exit_block();
@@ -822,8 +821,7 @@ __EAT_DOTS_END:
                 ctx()->co->is_generator = true;
                 ctx()->emit_(OP_GET_ITER, BC_NOARG, kw_line);
                 ctx()->enter_block(CodeBlockType::FOR_LOOP);
-                ctx()->emit_(OP_FOR_ITER, BC_NOARG, kw_line);
-                ctx()->emit_(OP_YIELD_VALUE, BC_NOARG, kw_line);
+                ctx()->emit_(OP_FOR_ITER_YIELD_VALUE, BC_NOARG, kw_line);
                 ctx()->emit_(OP_LOOP_CONTINUE, ctx()->get_loop(), kw_line);
                 ctx()->exit_block();
                 consume_end_stmt();

+ 27 - 1
src/expr.cpp

@@ -60,6 +60,31 @@ namespace pkpy{
         return i;
     }
 
+    void CodeEmitContext::revert_last_emit_(){  // Undo the most recent emit by popping the last element of codes, iblocks and lines.
+        co->codes.pop_back();    // drop the last bytecode
+        co->iblocks.pop_back();  // drop the matching iblocks entry (presumably one entry per instruction -- confirm against emit_)
+        co->lines.pop_back();    // drop the matching lines entry (presumably one entry per instruction -- confirm against emit_)
+    }  // NOTE(review): no emptiness check here; callers must only call this after at least one emit.
+
+    void CodeEmitContext::try_merge_for_iter_store(int i){  // `i` is the index of a previously emitted instruction, expected to be a FOR_ITER.
+        // Pattern [FOR_ITER, STORE_?]: when codes[i] is FOR_ITER and exactly one STORE_FAST/STORE_GLOBAL follows it, fuse both into one FOR_ITER_STORE_* that carries the store's arg. Any other shape is left untouched.
+        if(co->codes[i].op != OP_FOR_ITER) return;  // nothing to merge unless codes[i] is a plain FOR_ITER
+        if(co->codes.size() - i != 2) return;       // codes[i] and codes[i+1] must be the last two instructions
+        uint16_t arg = co->codes[i+1].arg;          // save the store's operand before the store is popped
+        if(co->codes[i+1].op == OP_STORE_FAST){
+            revert_last_emit_();                    // removes codes[i+1], the STORE_FAST
+            co->codes[i].op = OP_FOR_ITER_STORE_FAST;
+            co->codes[i].arg = arg;                 // the fused opcode now carries the STORE_FAST operand
+            return;
+        }
+        if(co->codes[i+1].op == OP_STORE_GLOBAL){
+            revert_last_emit_();                    // removes codes[i+1], the STORE_GLOBAL
+            co->codes[i].op = OP_FOR_ITER_STORE_GLOBAL;
+            co->codes[i].arg = arg;                 // the fused opcode now carries the STORE_GLOBAL operand
+            return;
+        }
+    }  // falls through with no change when the following instruction is any other opcode
+
     int CodeEmitContext::emit_int(i64 value, int line){
         bool allow_neg_int = is_negative_shift_well_defined() || value >= 0;
         if(allow_neg_int && value >= -5 && value <= 16){
@@ -370,10 +395,11 @@ namespace pkpy{
         iter->emit_(ctx);
         ctx->emit_(OP_GET_ITER, BC_NOARG, BC_KEEPLINE);
         ctx->enter_block(CodeBlockType::FOR_LOOP);
-        ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
+        int for_codei = ctx->emit_(OP_FOR_ITER, BC_NOARG, BC_KEEPLINE);
         bool ok = vars->emit_store(ctx);
         // this error occurs in `vars` instead of this line, but...nevermind
         PK_ASSERT(ok);  // TODO: raise a SyntaxError instead
+        ctx->try_merge_for_iter_store(for_codei);
         if(cond){
             cond->emit_(ctx);
             int patch = ctx->emit_(OP_POP_JUMP_IF_FALSE, BC_NOARG, BC_KEEPLINE);

+ 3 - 3
src/vm.cpp

@@ -573,10 +573,10 @@ static std::string _opcode_argstr(VM* vm, Bytecode byte, const CodeObject* co){
         case OP_LOAD_NAME: case OP_LOAD_GLOBAL: case OP_LOAD_NONLOCAL: case OP_STORE_GLOBAL:
         case OP_LOAD_ATTR: case OP_LOAD_METHOD: case OP_STORE_ATTR: case OP_DELETE_ATTR:
         case OP_BEGIN_CLASS: case OP_GOTO:
-        case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR:
+        case OP_DELETE_GLOBAL: case OP_INC_GLOBAL: case OP_DEC_GLOBAL: case OP_STORE_CLASS_ATTR: case OP_FOR_ITER_STORE_GLOBAL:
             argStr += _S(" (", StrName(byte.arg).sv(), ")").sv();
             break;
-        case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST:
+        case OP_LOAD_FAST: case OP_STORE_FAST: case OP_DELETE_FAST: case OP_INC_FAST: case OP_DEC_FAST: case OP_FOR_ITER_STORE_FAST:
             argStr += _S(" (", co->varnames[byte.arg].sv(), ")").sv();
             break;
         case OP_LOAD_FUNCTION:
@@ -594,7 +594,7 @@ Str VM::disassemble(CodeObject_ co){
 
     pod_vector<int> jumpTargets;
     for(auto byte : co->codes){
-        if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP || byte.op == OP_FOR_ITER){
+        if(byte.op == OP_JUMP_ABSOLUTE || byte.op == OP_POP_JUMP_IF_FALSE || byte.op == OP_SHORTCUT_IF_FALSE_OR_POP){
             jumpTargets.push_back(byte.arg);
         }
         if(byte.op == OP_GOTO){