Procházet zdrojové kódy

fix module reload bug

blueloveTH před 8 měsíci
rodič
revize
e187a61624

+ 1 - 1
build_g.sh

@@ -6,7 +6,7 @@ SRC=$(find src/ -name "*.c")
 
 FLAGS="-std=c11 -lm -ldl -lpthread -Iinclude -O0 -Wfatal-errors -g -DDEBUG -DPK_ENABLE_OS=1"
 
-SANITIZE_FLAGS="-fsanitize=address,leak,undefined"
+SANITIZE_FLAGS="-fsanitize=address,leak,undefined -fno-sanitize=function"
 
 if [ "$(uname)" == "Darwin" ]; then
     SANITIZE_FLAGS="-fsanitize=address,undefined"

+ 1 - 3
include/pocketpy/interpreter/typeinfo.h

@@ -16,10 +16,8 @@ typedef struct py_TypeInfo {
     bool is_python;  // is it a python class? (not derived from c object)
     bool is_sealed;  // can it be subclassed?
 
-    py_Dtor dtor;  // destructor for this type, NULL if no dtor
-
     py_TValue annotations;
-
+    py_Dtor dtor;  // destructor for this type, NULL if no dtor
     void (*on_end_subclass)(struct py_TypeInfo*);  // backdoor for enum module
 } py_TypeInfo;
 

+ 1 - 1
include/pocketpy/interpreter/vm.h

@@ -38,9 +38,9 @@ typedef struct py_ModuleInfo {
     c11_string* name;
     c11_string* package;
     c11_string* path;
+    py_GlobalRef self;  // weakref to the original module object
 } py_ModuleInfo;
 
-
 typedef struct VM {
     py_Frame* top_frame;
 

+ 1 - 1
include/pocketpy/pocketpy.h

@@ -558,7 +558,7 @@ PK_API py_GlobalRef py_newmodule(const char* path);
 /// Get a module by path.
 PK_API py_GlobalRef py_getmodule(const char* path);
 /// Reload an existing module.
-PK_API bool py_importlib_reload(py_GlobalRef module) PY_RAISE PY_RETURN;
+PK_API bool py_importlib_reload(py_Ref module) PY_RAISE PY_RETURN;
 
 /// Import a module.
 /// The result will be set to `py_retval()`.

+ 29 - 8
src/interpreter/typeinfo.c

@@ -55,26 +55,28 @@ static void py_TypeInfo__common_init(py_Name name,
                                      void (*dtor)(void*),
                                      bool is_python,
                                      bool is_sealed,
-                                     py_TypeInfo* self) {
+                                     py_TypeInfo* self,
+                                     py_TValue* typeobject) {
     py_TypeInfo* base_ti = base ? pk_typeinfo(base) : NULL;
     if(base_ti && base_ti->is_sealed) {
         c11__abort("type '%s' is not an acceptable base type", py_name2str(base_ti->name));
     }
 
-    memset(self, 0, sizeof(py_TypeInfo));
     self->name = name;
     self->index = index;
     self->base = base;
     self->base_ti = base_ti;
 
-    self->self = *py_retval();
+    py_assign(&self->self, typeobject);
     self->module = module ? module : py_NIL();
-    self->annotations = *py_NIL();
 
     if(!dtor && base) dtor = base_ti->dtor;
     self->is_python = is_python;
     self->is_sealed = is_sealed;
+
+    self->annotations = *py_NIL();
     self->dtor = dtor;
+    self->on_end_subclass = NULL;
 }
 
 py_Type pk_newtype(const char* name,
@@ -85,7 +87,15 @@ py_Type pk_newtype(const char* name,
                    bool is_sealed) {
     py_Type index = pk_current_vm->types.length;
     py_TypeInfo* self = py_newobject(py_retval(), tp_type, -1, sizeof(py_TypeInfo));
-    py_TypeInfo__common_init(py_name(name), base, index, module, dtor, is_python, is_sealed, self);
+    py_TypeInfo__common_init(py_name(name),
+                             base,
+                             index,
+                             module,
+                             dtor,
+                             is_python,
+                             is_sealed,
+                             self,
+                             py_retval());
     TypePointer* pointer = c11_vector__emplace(&pk_current_vm->types);
     pointer->ti = self;
     pointer->dtor = self->dtor;
@@ -102,11 +112,22 @@ py_Type pk_newtypewithmode(py_Name name,
     if(mode == RELOAD_MODE && module != NULL) {
         py_ItemRef old_class = py_getdict(module, name);
         if(old_class != NULL && py_istype(old_class, tp_type)) {
-            NameDict* old_dict = PyObject__dict(old_class->_obj);
-            NameDict__clear(old_dict);
+#ifndef NDEBUG
+            const char* name_cstr = py_name2str(name);
+            (void)name_cstr;  // avoid unused warning
+#endif
+            py_cleardict(old_class);
             py_TypeInfo* self = py_touserdata(old_class);
             py_Type index = self->index;
-            py_TypeInfo__common_init(name, base, index, module, dtor, is_python, is_sealed, self);
+            py_TypeInfo__common_init(name,
+                                     base,
+                                     index,
+                                     module,
+                                     dtor,
+                                     is_python,
+                                     is_sealed,
+                                     self,
+                                     &self->self);
             TypePointer* pointer = c11__at(TypePointer, &pk_current_vm->types, index);
             pointer->ti = self;
             pointer->dtor = self->dtor;

+ 1 - 1
src/interpreter/vm.c

@@ -52,7 +52,7 @@ void VM__ctor(VM* self) {
         .f_cmp = BinTree__cmp_cstr,
         .need_free_key = false,
     };
-    BinTree__ctor(&self->modules, c11_strdup(""), py_NIL(), &modules_config);
+    BinTree__ctor(&self->modules, "", py_NIL(), &modules_config);
     c11_vector__ctor(&self->types, sizeof(TypePointer));
 
     self->builtins = NULL;

+ 2 - 0
src/modules/os.c

@@ -97,6 +97,8 @@ void pk__add_module_os() {
     py_ItemRef path_object = py_emplacedict(mod, py_name("path"));
     py_newobject(path_object, tp_object, -1, 0);
     py_bindfunc(path_object, "exists", os_path_exists);
+
+    py_newdict(py_emplacedict(mod, py_name("environ")));
 }
 
 typedef struct {

+ 1 - 0
src/objects/codeobject.c

@@ -197,4 +197,5 @@ void Function__dtor(Function* self) {
     // self->decl->code.src->filename->data);
     PK_DECREF(self->decl);
     if(self->closure) NameDict__delete(self->closure);
+    memset(self, 0, sizeof(Function));
 }

+ 7 - 3
src/public/modules.c

@@ -83,7 +83,9 @@ py_Ref py_newmodule(const char* path) {
     if(exists) c11__abort("module '%s' already exists", path);
 
     BinTree__set(&pk_current_vm->modules, (void*)path, py_retval());
-    return py_getmodule(path);
+    py_GlobalRef retval = py_getmodule(path);
+    mi->self = retval;
+    return retval;
 }
 
 int load_module_from_dll_desktop_only(const char* path) PY_RAISE PY_RETURN;
@@ -181,9 +183,11 @@ __SUCCESS:
     return ok ? 1 : -1;
 }
 
-bool py_importlib_reload(py_GlobalRef module) {
+bool py_importlib_reload(py_Ref module) {
     VM* vm = pk_current_vm;
     py_ModuleInfo* mi = py_touserdata(module);
+    // We should ensure that the module is its original py_GlobalRef
+    module = mi->self;
     c11_sv path = c11_string__sv(mi->path);
     c11_string* slashed_path = c11_sv__replace(path, '.', PK_PLATFORM_SEP);
     c11_string* filename = c11_string__new3("%s.py", slashed_path->data);
@@ -195,7 +199,7 @@ bool py_importlib_reload(py_GlobalRef module) {
     }
     c11_string__delete(slashed_path);
     if(data == NULL) return ImportError("module '%v' not found", path);
-    py_cleardict(module);
+    // py_cleardict(module); BUG: removing old classes will cause RELOAD_MODE to fail
     bool ok = py_exec(data, filename->data, RELOAD_MODE, module);
     c11_string__delete(filename);
     PK_FREE(data);

+ 10 - 0
src/public/py_mappingproxy.c

@@ -20,6 +20,15 @@ static bool namedict__getitem__(int argc, py_Ref argv) {
     return true;
 }
 
+static bool namedict__get(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(3);
+    PY_CHECK_ARG_TYPE(1, tp_str);
+    py_Name name = py_namev(py_tosv(py_arg(1)));
+    py_Ref res = py_getdict(py_getslot(argv, 0), name);
+    py_assign(py_retval(), res ? res : py_arg(2));
+    return true;
+}
+
 static bool namedict__setitem__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(3);
     PY_CHECK_ARG_TYPE(1, tp_str);
@@ -82,5 +91,6 @@ py_Type pk_namedict__register() {
     py_setdict(py_tpobject(type), __hash__, py_None());
     py_bindmethod(type, "items", namedict_items);
     py_bindmethod(type, "clear", namedict_clear);
+    py_bindmethod(type, "get", namedict__get);
     return type;
 }

+ 43 - 0
tests/31_modulereload.py

@@ -0,0 +1,43 @@
+try:
+    import os
+except ImportError:
+    exit(0)
+
+import importlib
+
+os.chdir('tests')
+assert os.getcwd().endswith('tests')
+
+# test
+os.environ['TEST_RELOAD_VALUE'] = '123'
+os.environ['SET_X'] = '1'
+os.environ['SET_Y'] = '0'
+
+from testreload import MyClass, a
+
+objid = id(MyClass)
+funcid = id(MyClass.some_func)
+getxyid = id(MyClass.get_xy)
+
+assert MyClass.value == '123'
+assert MyClass.get_xy() == (1, 0)
+
+inst = MyClass()
+assert inst.some_func() == '123'
+
+# reload
+os.environ['TEST_RELOAD_VALUE'] = '456'
+os.environ['SET_X'] = '0'
+os.environ['SET_Y'] = '1'
+
+importlib.reload(a)
+
+assert id(MyClass) == objid
+assert id(MyClass.some_func) != funcid
+assert id(MyClass.get_xy) != getxyid
+
+assert MyClass.value == '456'
+assert inst.some_func() == '456'
+assert (MyClass.get_xy() == (1, 1)), MyClass.get_xy()
+
+

+ 2 - 0
tests/testreload/__init__.py

@@ -0,0 +1,2 @@
+from .a import MyClass
+from . import a

+ 18 - 0
tests/testreload/a.py

@@ -0,0 +1,18 @@
+import os
+
+class MyClass:
+    value = os.environ['TEST_RELOAD_VALUE']
+
+    def some_func(self):
+        return self.value
+    
+    @staticmethod
+    def get_xy():
+        g = globals()
+        return g.get('x', 0), g.get('y', 0)
+
+
+if os.environ['SET_X'] == '1':
+    x = 1
+elif os.environ['SET_Y'] == '1':
+    y = 1