blueloveTH 1 рік тому
батько
коміт
cef0a4a254

+ 8 - 0
docs/modules/sys.md

@@ -20,3 +20,11 @@ May be one of:
 ### `sys.argv`
 
 The command line arguments. Set by `py_sys_setargv`.
+
+### `sys.setrecursionlimit(limit: int)`
+
+Set the maximum depth of the Python interpreter stack to `limit`. This limit prevents infinite recursion from causing an overflow of the C stack and crashing the interpreter.
+
+### `sys.getrecursionlimit() -> int`
+
+Return the current value of the recursion limit.

+ 4 - 0
include/pocketpy/interpreter/vm.h

@@ -38,6 +38,10 @@ typedef struct VM {
 
     py_TValue last_retval;
     py_TValue curr_exception;
+
+    int recursion_depth;
+    int max_recursion_depth;
+    
     bool is_curr_exc_handled;  // handled by try-except block but not cleared yet
 
     py_TValue reg[8];  // users' registers

+ 1 - 1
include/pocketpy/pocketpy.h

@@ -741,7 +741,7 @@ enum py_PredefinedType {
     tp_KeyboardInterrupt,
     tp_StopIteration,
     tp_SyntaxError,
-    tp_StackOverflowError,
+    tp_RecursionError,
     tp_OSError,
     tp_NotImplementedError,
     tp_TypeError,

+ 3 - 3
src/interpreter/ceval.c

@@ -86,8 +86,8 @@ FrameResult VM__run_top_frame(VM* self) {
     while(true) {
         Bytecode byte;
     __NEXT_FRAME:
-        if(self->stack.sp > self->stack.end) {
-            py_exception(tp_StackOverflowError, "");
+        if(self->recursion_depth >= self->max_recursion_depth) {
+            py_exception(tp_RecursionError, "maximum recursion depth exceeded");
             goto __ERROR;
         }
 
@@ -403,7 +403,7 @@ FrameResult VM__run_top_frame(VM* self) {
                         if(!py_callcfunc(magic->_cfunc, 3, THIRD())) goto __ERROR;
                         STACK_SHRINK(4);
                     } else {
-                        *FOURTH() = *magic;  // [__selitem__, a, b, val]
+                        *FOURTH() = *magic;  // [__setitem__, a, b, val]
                         if(!py_vectorcall(2, 0)) goto __ERROR;
                     }
                     DISPATCH();

+ 1 - 0
src/interpreter/generator.c

@@ -64,6 +64,7 @@ static bool generator__next__(int argc, py_Ref argv) {
         }
         vm->stack.sp = ud->frame->p0;
         vm->top_frame = vm->top_frame->f_back;
+        vm->recursion_depth--;
         ud->state = 1;
         return true;
     } else {

+ 7 - 7
src/interpreter/vm.c

@@ -71,6 +71,10 @@ void VM__ctor(VM* self) {
 
     self->last_retval = *py_NIL();
     self->curr_exception = *py_NIL();
+
+    self->recursion_depth = 0;
+    self->max_recursion_depth = 1000;
+    
     self->is_curr_exc_handled = false;
 
     self->ctx = NULL;
@@ -162,7 +166,7 @@ void VM__ctor(VM* self) {
     py_setdict(&self->builtins, py_name("StopIteration"), py_tpobject(tp_StopIteration));
 
     INJECT_BUILTIN_EXC(SyntaxError, tp_Exception);
-    INJECT_BUILTIN_EXC(StackOverflowError, tp_Exception);
+    INJECT_BUILTIN_EXC(RecursionError, tp_Exception);
     INJECT_BUILTIN_EXC(OSError, tp_Exception);
     INJECT_BUILTIN_EXC(NotImplementedError, tp_Exception);
     INJECT_BUILTIN_EXC(TypeError, tp_Exception);
@@ -265,6 +269,7 @@ void VM__dtor(VM* self) {
 void VM__push_frame(VM* self, py_Frame* frame) {
     frame->f_back = self->top_frame;
     self->top_frame = frame;
+    self->recursion_depth++;
     if(self->trace_info.func) self->trace_info.func(frame, TRACE_EVENT_PUSH);
 }
 
@@ -277,6 +282,7 @@ void VM__pop_frame(VM* self) {
     // pop frame and delete
     self->top_frame = frame->f_back;
     Frame__delete(frame);
+    self->recursion_depth--;
 }
 
 static void _clip_int(int* value, int min, int max) {
@@ -469,12 +475,6 @@ FrameResult VM__vectorcall(VM* self, uint16_t argc, uint16_t kwargc, bool opcall
     py_Ref argv = p0 + 1 + (int)py_isnil(p0 + 1);
 
     if(p0->type == tp_function) {
-        // check stack overflow
-        if(self->stack.sp > self->stack.end) {
-            py_exception(tp_StackOverflowError, "");
-            return RES_ERROR;
-        }
-
         Function* fn = py_touserdata(p0);
         const CodeObject* co = &fn->decl->code;
 

+ 19 - 0
src/modules/os.c

@@ -240,9 +240,28 @@ void pk__add_module_io() {}
 
 #endif
 
+static bool sys_setrecursionlimit(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    PY_CHECK_ARG_TYPE(0, tp_int);
+    int limit = py_toint(py_arg(0));
+    if(limit <= pk_current_vm->recursion_depth) return ValueError("the limit is too low");
+    pk_current_vm->max_recursion_depth = limit;
+    py_newnone(py_retval());
+    return true;
+}
+
+static bool sys_getrecursionlimit(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(0);
+    py_newint(py_retval(), pk_current_vm->max_recursion_depth);
+    return true;
+}
+
 void pk__add_module_sys() {
     py_Ref mod = py_newmodule("sys");
     py_newstr(py_emplacedict(mod, py_name("platform")), PY_SYS_PLATFORM_STRING);
     py_newstr(py_emplacedict(mod, py_name("version")), PK_VERSION);
     py_newlist(py_emplacedict(mod, py_name("argv")));
+
+    py_bindfunc(mod, "setrecursionlimit", sys_setrecursionlimit);
+    py_bindfunc(mod, "getrecursionlimit", sys_getrecursionlimit);
 }