blueloveTH %!s(int64=3) %!d(string=hai) anos
pai
achega
377a08e278
Modificáronse 4 ficheiros con 85 adicións e 39 borrados
  1. 60 22
      src/codeobject.h
  2. 8 1
      src/compiler.h
  3. 10 12
      src/vm.h
  4. 7 4
      tests/_goto.py

+ 60 - 22
src/codeobject.h

@@ -20,6 +20,7 @@ struct ByteCode{
     uint8_t op;
     int arg;
     uint16_t line;
+    uint16_t block;     // the block id of this bytecode
 };
 
 _Str pad(const _Str& s, const int n){
@@ -44,6 +45,46 @@ struct CodeObject {
     std::vector<std::pair<_Str, NameScope>> co_names;
     std::vector<_Str> co_global_names;
 
+    std::vector<std::vector<int>> co_loops = {{}};
+    int _currLoopIndex = 0;
+
+    std::string getBlockStr(int block){
+        std::vector<int> loopId = co_loops[block];
+        std::string s = "";
+        for(int i=0; i<loopId.size(); i++){
+            s += std::to_string(loopId[i]);
+            if(i != loopId.size()-1) s += "-";
+        }
+        return s;
+    }
+
+    void __enterLoop(int depth){
+        const std::vector<int>& prevLoopId = co_loops[_currLoopIndex];
+        if(depth - prevLoopId.size() == 1){
+            std::vector<int> copy = prevLoopId;
+            copy.push_back(0);
+            int t = 0;
+            while(true){
+                copy[copy.size()-1] = t;
+                auto it = std::find(co_loops.begin(), co_loops.end(), copy);
+                if(it == co_loops.end()) break;
+                t++;
+            }
+            co_loops.push_back(copy);
+        }else{
+            UNREACHABLE();
+        }
+        _currLoopIndex = co_loops.size()-1;
+    }
+
+    void __exitLoop(){
+        std::vector<int> copy = co_loops[_currLoopIndex];
+        copy.pop_back();
+        auto it = std::find(co_loops.begin(), co_loops.end(), copy);
+        if(it == co_loops.end()) UNREACHABLE();
+        _currLoopIndex = it - co_loops.begin();
+    }
+
     // for goto use
     // note: some opcodes moves the bytecode, such as listcomp
     // goto/label should be put at toplevel statements
@@ -79,7 +120,6 @@ class Frame {
 private:
     std::vector<PyVar> s_data;
     int ip = 0;
-    std::stack<int> forLoops;       // record the FOR_ITER bytecode index
 public:
     const _Code code;
     PyVar _module;
@@ -146,17 +186,6 @@ public:
         s_data.push_back(std::forward<T>(obj));
     }
 
-
-    void __reportForIter(){
-        int lastIp = ip - 1;
-        if(forLoops.empty()) forLoops.push(lastIp);
-        else{
-            if(forLoops.top() == lastIp) return;
-            if(forLoops.top() < lastIp) forLoops.push(lastIp);
-            else UNREACHABLE();
-        }
-    }
-
     inline void jumpAbsolute(int i){
         this->ip = i;
     }
@@ -165,18 +194,27 @@ public:
         this->ip += i;
     }
 
-    void __safeJumpClean(){
-        while(!forLoops.empty()){
-            int start = forLoops.top();
-            int end = code->co_code[start].arg;
-            if(ip < start || ip >= end){
-                //printf("%d <- [%d, %d)\n", i, start, end);
-                __pop();    // pop the iterator
-                forLoops.pop();
-            }else{
-                break;
+    void jumpAbsoluteSafe(int i){
+        const ByteCode& prev = code->co_code[this->ip];
+        const std::vector<int> prevLoopId = code->co_loops[prev.block];
+        this->ip = i;
+        if(isCodeEnd()){
+            for(int i=0; i<prevLoopId.size(); i++) __pop();
+            return;
+        }
+        const ByteCode& next = code->co_code[i];
+        const std::vector<int> nextLoopId = code->co_loops[next.block];
+        int sizeDelta = prevLoopId.size() - nextLoopId.size();
+        if(sizeDelta < 0){
+            throw std::runtime_error("invalid jump from " + code->getBlockStr(prev.block) + " to " + code->getBlockStr(next.block));
+        }else{
+            for(int i=0; i<nextLoopId.size(); i++){
+                if(nextLoopId[i] != prevLoopId[i]){
+                    throw std::runtime_error("invalid jump from " + code->getBlockStr(prev.block) + " to " + code->getBlockStr(next.block));
+                }
             }
         }
+        for(int i=0; i<sizeDelta; i++) __pop();
     }
 
     pkpy::ArgList popNValuesReversed(VM* vm, int n){

+ 8 - 1
src/compiler.h

@@ -710,7 +710,7 @@ __LISTCOMP:
     int emitCode(Opcode opcode, int arg=-1) {
         int line = parser->previous.line;
         getCode()->co_code.push_back(
-            ByteCode{(uint8_t)opcode, arg, (uint16_t)line}
+            ByteCode{(uint8_t)opcode, arg, (uint16_t)line, (uint16_t)getCode()->_currLoopIndex}
         );
         return getCode()->co_code.size() - 1;
     }
@@ -818,12 +818,14 @@ __LISTCOMP:
     }
 
     Loop& enterLoop(){
+        getCode()->__enterLoop(loops.size()+1);
         Loop lp((int)getCode()->co_code.size());
         loops.push(lp);
         return loops.top();
     }
 
     void exitLoop(){
+        getCode()->__exitLoop();
         Loop& lp = loops.top();
         for(int addr : lp.breaks) patchJump(addr);
         loops.pop();
@@ -1062,7 +1064,12 @@ __LISTCOMP:
         }
     }
 
+    bool _used = false;
     _Code __fillCode(){
+        // can only be called once
+        if(_used) UNREACHABLE();
+        _used = true;
+
         _Code code = pkpy::make_shared<CodeObject>(parser->src, _Str("<module>"));
         codes.push(code);
 

+ 10 - 12
src/vm.h

@@ -248,7 +248,7 @@ protected:
                 } break;
             case OP_JUMP_ABSOLUTE: frame->jumpAbsolute(byte.arg); break;
             case OP_JUMP_RELATIVE: frame->jumpRelative(byte.arg); break;
-            case OP_SAFE_JUMP_ABSOLUTE: frame->jumpAbsolute(byte.arg); frame->__safeJumpClean(); break;
+            case OP_SAFE_JUMP_ABSOLUTE: frame->jumpAbsoluteSafe(byte.arg); break;
             case OP_GOTO: {
                 PyVar obj = frame->popValue(this);
                 const _Str& label = PyStr_AS_C(obj);
@@ -256,8 +256,7 @@ protected:
                 if(target == nullptr){
                     _error("KeyError", "label '" + label + "' not found");
                 }
-                frame->jumpAbsolute(*target);
-                frame->__safeJumpClean();
+                frame->jumpAbsoluteSafe(*target);
             } break;
             case OP_GET_ITER:
                 {
@@ -275,15 +274,13 @@ protected:
                 } break;
             case OP_FOR_ITER:
                 {
-                    frame->__reportForIter();
                     // __top() must be PyIter, so no need to __deref()
                     auto& it = PyIter_AS_C(frame->__top());
                     if(it->hasNext()){
                         PyRef_AS_C(it->var)->set(this, frame, it->next());
                     }
                     else{
-                        frame->jumpAbsolute(byte.arg);
-                        frame->__safeJumpClean();
+                        frame->jumpAbsoluteSafe(byte.arg);
                     }
                 } break;
             case OP_JUMP_IF_FALSE_OR_POP:
@@ -327,7 +324,7 @@ protected:
                         frame->push(it->second);
                     }
                 } break;
-            // TODO: goto inside with block is unsafe
+            // TODO: using "goto" inside with block may cause __exit__ not called
             case OP_WITH_ENTER: call(frame->popValue(this), __enter__); break;
             case OP_WITH_EXIT: call(frame->popValue(this), __exit__); break;
             default:
@@ -551,9 +548,9 @@ public:
         try {
             _Code code = compile(source, filename, mode);
 
-            if(filename == "<stdin>"){
-                std::cout << disassemble(code) << std::endl;
-            }
+            // if(filename != "<builtins>"){
+            //     std::cout << disassemble(code) << std::endl;
+            // }
             
             return _exec(code, _module, {});
         }catch (const _Error& e){
@@ -777,7 +774,7 @@ public:
         int prev_line = -1;
         for(int i=0; i<code->co_code.size(); i++){
             const ByteCode& byte = code->co_code[i];
-            if(byte.op == OP_NO_OP || byte.op == OP_DELETED_OP) continue;
+            //if(byte.op == OP_NO_OP || byte.op == OP_DELETED_OP) continue;
             _Str line = std::to_string(byte.line);
             if(byte.line == prev_line) line = "";
             else{
@@ -786,7 +783,8 @@ public:
             }
             ss << pad(line, 12) << " " << pad(std::to_string(i), 3);
             ss << " " << pad(OP_NAMES[byte.op], 20) << " ";
-            ss << (byte.arg == -1 ? "" : std::to_string(byte.arg));
+            ss << pad(byte.arg == -1 ? "" : std::to_string(byte.arg), 5);
+            ss << '[' << code->getBlockStr(byte.block) << ']';
             if(i != code->co_code.size() - 1) ss << '\n';
         }
         _StrStream consts;

+ 7 - 4
tests/_goto.py

@@ -1,18 +1,21 @@
 a = []
 
-for i in range(10):
-    for j in range(10):
+for i in range(10):         # [0]
+    for j in range(10):     # [0-0]
         goto .test 
         print(2)
     label .test
     a.append(i)
+    for k in range(5):      # [0-1]
+        for t in range(7):  # [0-1-0]
+            pass
 
 assert a == list(range(10))
 
 b = False
 
-for i in range(10):
-    for j in range(10):
+for i in range(10):         # [1]
+    for j in range(10):     # [1-0]
         goto .out
         b = True
 label .out