Explorar o código

fix leak and improve `pickle`

blueloveTH hai 4 meses
pai
achega
0251f924c3

+ 1 - 1
docs/modules/pickle.md

@@ -20,7 +20,7 @@ The following types can be pickled:
 - [x] integers, floating-point numbers;
 - [x] strings, bytes;
 - [x] tuples, lists, sets, and dictionaries containing only picklable objects;
-- [ ] functions (built-in and user-defined) accessible from the top level of a module (using def, not lambda);
+- [x] functions (user-defined) accessible from the top level of a module (using def, not lambda);
 - [x] classes accessible from the top level of a module;
 - [x] instances of such classes
 

+ 1 - 1
docs/retype.yml

@@ -3,7 +3,7 @@ output: .retype
 url: https://pocketpy.dev
 branding:
   title: pocketpy
-  label: v2.1.2
+  label: v2.1.3
   logo: "./static/logo.png"
 favicon: "./static/logo.png"
 meta:

+ 2 - 2
include/pocketpy/config.h

@@ -1,10 +1,10 @@
 #pragma once
 // clang-format off
 
-#define PK_VERSION				"2.1.2"
+#define PK_VERSION				"2.1.3"
 #define PK_VERSION_MAJOR            2
 #define PK_VERSION_MINOR            1
-#define PK_VERSION_PATCH            2
+#define PK_VERSION_PATCH            3
 
 /*************** feature settings ***************/
 #ifndef PK_ENABLE_OS                // can be overridden by cmake

+ 7 - 0
src/bindings/py_str.c

@@ -414,6 +414,7 @@ static bool str_format(int argc, py_Ref argv) {
                 p += 2;
             } else {
                 if((p + 1) >= p_end) {
+                    c11_sbuf__dtor(&buf);
                     return ValueError("single '{' encountered in format string");
                 }
                 p++;
@@ -434,22 +435,26 @@ static bool str_format(int argc, py_Ref argv) {
                 if(p < p_end) {
                     c11__rtassert(*p == '}');
                 } else {
+                    c11_sbuf__dtor(&buf);
                     return ValueError("expected '}' before end of string");
                 }
                 // parse auto field
                 int64_t arg_index;
                 if(field.size > 0) {  // {0}
                     if(auto_field_index >= 0) {
+                        c11_sbuf__dtor(&buf);
                         return ValueError(
                             "cannot switch from automatic field numbering to manual field specification");
                     }
                     IntParsingResult res = c11__parse_uint(field, &arg_index, 10);
                     if(res != IntParsing_SUCCESS) {
+                        c11_sbuf__dtor(&buf);
                         return ValueError("only integer field name is supported");
                     }
                     manual_field_used = true;
                 } else {  // {}
                     if(manual_field_used) {
+                        c11_sbuf__dtor(&buf);
                         return ValueError(
                             "cannot switch from manual field specification to automatic field numbering");
                     }
@@ -458,6 +463,7 @@ static bool str_format(int argc, py_Ref argv) {
                 }
                 // do format
                 if(arg_index < 0 || arg_index >= (argc - 1)) {
+                    c11_sbuf__dtor(&buf);
                     return IndexError("replacement index %i out of range for positional args tuple",
                                       arg_index);
                 }
@@ -478,6 +484,7 @@ static bool str_format(int argc, py_Ref argv) {
                 c11_sbuf__write_char(&buf, '}');
                 p += 2;
             } else {
+                c11_sbuf__dtor(&buf);
                 return ValueError("single '}' encountered in format string");
             }
         } else {

+ 2 - 6
src/interpreter/ceval.c

@@ -201,6 +201,8 @@ __NEXT_STEP:
                 py_Name name = py_name(decl->code.name->data);
                 // capture itself to allow recursion
                 NameDict__set(ud->closure, name, SP());
+            } else {
+                if(self->curr_class) ud->clazz = self->curr_class->_obj;
             }
             SP()++;
             DISPATCH();
@@ -1082,12 +1084,6 @@ __NEXT_STEP:
             assert(self->curr_class);
             py_Name name = co_names[byte.arg];
             // TOP() can be a function, classmethod or custom decorator
-            py_Ref actual_func = TOP();
-            if(actual_func->type == tp_classmethod) { actual_func = py_getslot(actual_func, 0); }
-            if(actual_func->type == tp_function) {
-                Function* ud = py_touserdata(actual_func);
-                ud->clazz = self->curr_class->_obj;
-            }
             py_setdict(self->curr_class, name, TOP());
             POP();
             DISPATCH();

+ 125 - 55
src/modules/pickle.c

@@ -26,6 +26,8 @@ typedef enum {
     PKL_VEC2I, PKL_VEC3I,
     PKL_TYPE,
     PKL_ARRAY2D,
+    PKL_IMPORT_PATH,
+    PKL_GETATTR,
     PKL_TVALUE,
     PKL_CALL,
     PKL_OBJECT,
@@ -96,6 +98,16 @@ static void pkl__emit_int(PickleObject* buf, py_i64 val) {
     }
 }
 
+static void pkl__emit_cstr(PickleObject* buf, const char* s) {
+    PickleObject__write_bytes(buf, s, strlen(s) + 1);
+}
+
+const static char* pkl__read_cstr(const unsigned char** p) {
+    const char* s = (const char*)*p;
+    (*p) += strlen(s) + 1;
+    return s;
+}
+
 #define UNALIGNED_READ(p_val, p_buf)                                                               \
     do {                                                                                           \
         memcpy((p_val), (p_buf), sizeof(*(p_val)));                                                \
@@ -192,6 +204,16 @@ static void pkl__store_memo(PickleObject* buf, PyObject* memo_key) {
     pkl__emit_int(buf, index);
 }
 
+static bool _check_function(Function* f) {
+    if(!f->module) return ValueError("cannot pickle function (!f->module)");
+    if(f->closure) return ValueError("cannot pickle function with closure");
+    if(f->decl->nested) return ValueError("cannot pickle nested function");
+    c11_string* name = f->decl->code.name;
+    if(name->size == 0) return ValueError("cannot pickle function with empty name");
+    if(name->data[0] == '<') return ValueError("cannot pickle anonymous function");
+    return true;
+}
+
 static bool pkl__write_object(PickleObject* buf, py_TValue* obj) {
     switch(obj->type) {
         case tp_nil: {
@@ -228,61 +250,44 @@ static bool pkl__write_object(PickleObject* buf, py_TValue* obj) {
             return true;
         }
         case tp_str: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                pkl__emit_op(buf, PKL_STRING);
-                c11_sv sv = py_tosv(obj);
-                pkl__emit_int(buf, sv.size);
-                PickleObject__write_bytes(buf, sv.data, sv.size);
-            }
-            pkl__store_memo(buf, obj->_obj);
+            if(obj->is_ptr && pkl__try_memo(buf, obj->_obj)) return true;
+            pkl__emit_op(buf, PKL_STRING);
+            c11_sv sv = py_tosv(obj);
+            pkl__emit_int(buf, sv.size);
+            PickleObject__write_bytes(buf, sv.data, sv.size);
+            if(obj->is_ptr) pkl__store_memo(buf, obj->_obj);
             return true;
         }
         case tp_bytes: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                pkl__emit_op(buf, PKL_BYTES);
-                int size;
-                unsigned char* data = py_tobytes(obj, &size);
-                pkl__emit_int(buf, size);
-                PickleObject__write_bytes(buf, data, size);
-            }
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            pkl__emit_op(buf, PKL_BYTES);
+            int size;
+            unsigned char* data = py_tobytes(obj, &size);
+            pkl__emit_int(buf, size);
+            PickleObject__write_bytes(buf, data, size);
             pkl__store_memo(buf, obj->_obj);
             return true;
         }
         case tp_list: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                bool ok =
-                    pkl__write_array(buf, PKL_BUILD_LIST, py_list_data(obj), py_list_len(obj));
-                if(!ok) return false;
-            }
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            bool ok = pkl__write_array(buf, PKL_BUILD_LIST, py_list_data(obj), py_list_len(obj));
+            if(!ok) return false;
             pkl__store_memo(buf, obj->_obj);
             return true;
         }
         case tp_tuple: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                bool ok =
-                    pkl__write_array(buf, PKL_BUILD_TUPLE, py_tuple_data(obj), py_tuple_len(obj));
-                if(!ok) return false;
-            }
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            bool ok = pkl__write_array(buf, PKL_BUILD_TUPLE, py_tuple_data(obj), py_tuple_len(obj));
+            if(!ok) return false;
             pkl__store_memo(buf, obj->_obj);
             return true;
         }
         case tp_dict: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                bool ok = py_dict_apply(obj, pkl__write_dict_kv, (void*)buf);
-                if(!ok) return false;
-                pkl__emit_op(buf, PKL_BUILD_DICT);
-                pkl__emit_int(buf, py_dict_len(obj));
-            }
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            bool ok = py_dict_apply(obj, pkl__write_dict_kv, (void*)buf);
+            if(!ok) return false;
+            pkl__emit_op(buf, PKL_BUILD_DICT);
+            pkl__emit_int(buf, py_dict_len(obj));
             pkl__store_memo(buf, obj->_obj);
             return true;
         }
@@ -320,22 +325,71 @@ static bool pkl__write_object(PickleObject* buf, py_TValue* obj) {
             pkl__emit_int(buf, type);
             return true;
         }
-        case tp_array2d: {
-            if(pkl__try_memo(buf, obj->_obj))
-                return true;
-            else {
-                c11_array2d* arr = py_touserdata(obj);
-                for(int i = 0; i < arr->header.numel; i++) {
-                    if(arr->data[i].is_ptr)
-                        return TypeError(
-                            "'array2d' object is not picklable because it contains heap-allocated objects");
-                    buf->used_types[arr->data[i].type] = true;
-                }
-                pkl__emit_op(buf, PKL_ARRAY2D);
-                pkl__emit_int(buf, arr->header.n_cols);
-                pkl__emit_int(buf, arr->header.n_rows);
-                PickleObject__write_bytes(buf, arr->data, arr->header.numel * sizeof(py_TValue));
+        case tp_module: {
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            py_ModuleInfo* mi = py_touserdata(obj);
+            pkl__emit_op(buf, PKL_IMPORT_PATH);
+            pkl__emit_cstr(buf, mi->path->data);
+            pkl__store_memo(buf, obj->_obj);
+            return true;
+        }
+        case tp_function: {
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            Function* f = py_touserdata(obj);
+            if(!_check_function(f)) return false;
+            c11_string* name = f->decl->code.name;
+            if(f->clazz) {
+                // NOTE: copied from logic of `case tp_type:`
+                pkl__emit_op(buf, PKL_TYPE);
+                py_TypeInfo* ti = PyObject__userdata(f->clazz);
+                py_Type type = ti->index;
+                buf->used_types[type] = true;
+                pkl__emit_int(buf, type);
+            } else {
+                if(!pkl__write_object(buf, f->module)) return false;
             }
+            pkl__emit_op(buf, PKL_GETATTR);
+            pkl__emit_cstr(buf, name->data);
+            pkl__store_memo(buf, obj->_obj);
+            return true;
+        }
+        case tp_boundmethod: {
+            py_Ref self = py_getslot(obj, 0);
+            if(!py_istype(self, tp_type)) {
+                return ValueError("tp_boundmethod: !py_istype(self, tp_type)");
+            }
+            py_Ref func = py_getslot(obj, 1);
+            if(!py_istype(func, tp_function)) {
+                return ValueError("tp_boundmethod: !py_istype(func, tp_function)");
+            }
+
+            Function* f = py_touserdata(func);
+            if(!_check_function(f)) return false;
+
+            c11_string* name = f->decl->code.name;
+            // NOTE: copied from logic of `case tp_type:`
+            pkl__emit_op(buf, PKL_TYPE);
+            py_Type type = py_totype(self);
+            buf->used_types[type] = true;
+            pkl__emit_int(buf, type);
+
+            pkl__emit_op(buf, PKL_GETATTR);
+            pkl__emit_cstr(buf, name->data);
+            return true;
+        }
+        case tp_array2d: {
+            if(pkl__try_memo(buf, obj->_obj)) return true;
+            c11_array2d* arr = py_touserdata(obj);
+            for(int i = 0; i < arr->header.numel; i++) {
+                if(arr->data[i].is_ptr)
+                    return TypeError(
+                        "'array2d' object is not picklable because it contains heap-allocated objects");
+                buf->used_types[arr->data[i].type] = true;
+            }
+            pkl__emit_op(buf, PKL_ARRAY2D);
+            pkl__emit_int(buf, arr->header.n_cols);
+            pkl__emit_int(buf, arr->header.n_rows);
+            PickleObject__write_bytes(buf, arr->data, arr->header.numel * sizeof(py_TValue));
             pkl__store_memo(buf, obj->_obj);
             return true;
         }
@@ -665,6 +719,22 @@ bool py_pickle_loads_body(const unsigned char* p, int memo_length, c11_smallmap_
                 p += total_size;
                 break;
             }
+            case PKL_IMPORT_PATH: {
+                const char* path = pkl__read_cstr(&p);
+                int res = py_import(path);
+                if(res == -1) return false;
+                if(res == 0) return ImportError("No module named '%s'", path);
+                py_push(py_retval());
+                break;
+            }
+            case PKL_GETATTR: {
+                const char* name = pkl__read_cstr(&p);
+                py_Ref obj = py_peek(-1);
+                if(!py_getattr(obj, py_name(name))) return false;
+                py_pop();
+                py_push(py_retval());
+                break;
+            }
             case PKL_TVALUE: {
                 py_TValue* tmp = py_pushtmp();
                 memcpy(tmp, p, sizeof(py_TValue));

+ 23 - 1
tests/90_pickle.py

@@ -210,4 +210,26 @@ a = deque([1, 2, 3])
 test(a)
 
 a = [int, float, Foo]
-test(a)
+test(a)
+
+# test function
+def f(x, y):
+    return x + y
+test(f)
+
+# test @staticmethod
+class B:
+    @staticmethod
+    def f(x, y):
+        return x * y
+    @classmethod
+    def g(cls):
+        return cls
+
+class C(B):
+    pass
+
+test(B.f)
+test(C.f)
+test(B.g)
+test(C.g)