Selaa lähdekoodia

fix recursion bug

blueloveTH 3 vuotta sitten
vanhempi
commit
9256186cb1
3 muutettua tiedostoa jossa 54 lisäystä ja 13 poistoa
  1. 1 1
      src/main.cpp
  2. 12 0
      src/pocketpy.h
  3. 41 12
      src/vm.h

+ 1 - 1
src/main.cpp

@@ -69,7 +69,7 @@ void setStackSize(_Float mb){
 
 int main(int argc, char** argv){
 #ifdef PK_DEBUG_STACK
-    setStackSize(1);
+    setStackSize(0.5);
 #endif
 
     if(argc == 1){

+ 12 - 0
src/pocketpy.h

@@ -542,6 +542,18 @@ void __addModuleSys(VM* vm){
         vm->__checkArgSize(args, 1);
         return vm->PyInt(args[0].use_count());
     });
+
+    vm->bindFunc(mod, "getrecursionlimit", [](VM* vm, PyVarList args) {
+        vm->__checkArgSize(args, 0);
+        return vm->PyInt(vm->maxRecursionDepth);
+    });
+
+    vm->bindFunc(mod, "setrecursionlimit", [](VM* vm, PyVarList args) {
+        vm->__checkArgSize(args, 1);
+        vm->maxRecursionDepth = vm->PyInt_AS_C(args[0]);
+        return vm->None;
+    });
+
     vm->setAttr(mod, "version", vm->PyStr(PK_VERSION));
 }
 

+ 41 - 12
src/vm.h

@@ -29,6 +29,8 @@ private:
     std::vector<PyObject*> numPool;
     PyVarDict _modules;       // 3rd modules
 
+    PyVar __py2py_call_signal;
+
     PyVar runFrame(Frame* frame){
         while(!frame->isCodeEnd()){
             const ByteCode& byte = frame->readCode();
@@ -213,7 +215,9 @@ private:
                 {
                     PyVarList args = frame->popNValuesReversed(this, byte.arg);
                     PyVar callable = frame->popValue(this);
-                    frame->push(call(callable, args));
+                    PyVar ret = call(callable, args, true);
+                    if(ret == __py2py_call_signal) return ret;
+                    frame->push(ret);
                 } break;
             case OP_JUMP_ABSOLUTE: frame->jumpTo(byte.arg); break;
             case OP_GET_ITER:
@@ -293,6 +297,8 @@ public:
     PyVar builtins;         // builtins module
     PyVar _main;            // __main__ module
 
+    int maxRecursionDepth = 1000;
+
     VM(){
         initializeBuiltinClasses();
     }
@@ -340,7 +346,7 @@ public:
         return nullptr;
     }
 
-    PyVar call(PyVar callable, PyVarList args){
+    PyVar call(PyVar callable, PyVarList args, bool opCall=false){
         if(callable->isType(_tp_type)){
             auto it = callable->attribs.find(__new__);
             PyVar obj;
@@ -394,11 +400,12 @@ public:
             if(i < args.size()) typeError("too many arguments");
 
             auto it_m = callable->attribs.find(__module__);
-            if(it_m != callable->attribs.end()){
-                return _exec(fn->code, it_m->second, locals);
-            }else{
-                return _exec(fn->code, topFrame()->_module, locals);
+            PyVar _module = it_m != callable->attribs.end() ? it_m->second : topFrame()->_module;
+            if(opCall){
+                __pushNewFrame(fn->code, _module, locals);
+                return __py2py_call_signal;
             }
+            return _exec(fn->code, _module, locals);
         }
         typeError("'" + callable->getTypeName() + "' object is not callable");
         return None;
@@ -424,16 +431,36 @@ public:
         }
     }
 
-    PyVar _exec(const _Code& code, PyVar _module, const PyVarDict& locals={}){
+    Frame* __pushNewFrame(const _Code& code, PyVar _module, const PyVarDict& locals){
         if(code == nullptr) UNREACHABLE();
-
-        if(callstack.size() > 1000){
+        if(callstack.size() > maxRecursionDepth){
             throw RuntimeError("RecursionError", "maximum recursion depth exceeded", _cleanErrorAndGetSnapshots());
         }
-
         Frame* frame = new Frame(code.get(), _module, locals);
         callstack.push(std::unique_ptr<Frame>(frame));
-        PyVar ret = runFrame(frame);
+        return frame;
+    }
+
+    PyVar _exec(const _Code& code, PyVar _module, const PyVarDict& locals={}){
+        Frame* frame = __pushNewFrame(code, _module, locals);
+        Frame* frameBase = frame;
+        PyVar ret = nullptr;
+
+        while(true){
+            ret = runFrame(frame);
+            if(ret != __py2py_call_signal){
+                if(frame == frameBase){         // [ frameBase<- ]
+                    break;
+                }else{
+                    callstack.pop();
+                    frame = callstack.top().get();
+                    frame->push(ret);
+                }
+            }else{
+                frame = callstack.top().get();  // [ frameBase, newFrame<- ]
+            }
+        }
+
         callstack.pop();
         return ret;
     }
@@ -621,7 +648,7 @@ public:
         this->False = newObject(_tp_bool, false);
         this->builtins = newModule("builtins");
         this->_main = newModule("__main__", false);
-        
+
         setAttr(_tp_type, __base__, _tp_object);
         setAttr(_tp_type, __class__, _tp_type);
         setAttr(_tp_object, __base__, None);
@@ -631,6 +658,8 @@ public:
             setAttr(type, "__name__", PyStr(name));
         }
 
+        this->__py2py_call_signal = newObject(_tp_object, (_Int)7);
+
         std::vector<_Str> publicTypes = {"type", "object", "bool", "int", "float", "str", "list", "tuple", "range"};
         for (auto& name : publicTypes) {
             setAttr(builtins, name, _types[name]);