Procházet zdrojové kódy

add `py_importlib_reload`

blueloveTH před 1 rokem
rodič
revize
e25aba8d97

+ 1 - 0
include/pocketpy/interpreter/modules.h

@@ -15,6 +15,7 @@ void pk__add_module_traceback();
 void pk__add_module_enum();
 void pk__add_module_inspect();
 void pk__add_module_pickle();
+void pk__add_module_importlib();
 
 void pk__add_module_linalg();
 void pk__add_module_array2d();

+ 2 - 0
include/pocketpy/pocketpy.h

@@ -485,6 +485,8 @@ PK_API bool py_pusheval(const char* expr, py_GlobalRef module) PY_RAISE;
 PK_API py_GlobalRef py_newmodule(const char* path);
 /// Get a module by path.
 PK_API py_GlobalRef py_getmodule(const char* path);
+/// Reload an existing module.
+PK_API bool py_importlib_reload(py_GlobalRef module) PY_RAISE;
 
 /// Import a module.
 /// The result will be set to `py_retval()`.

+ 1 - 0
src/interpreter/vm.c

@@ -218,6 +218,7 @@ void VM__ctor(VM* self) {
     pk__add_module_enum();
     pk__add_module_inspect();
     pk__add_module_pickle();
+    pk__add_module_importlib();
 
     pk__add_module_conio();
     pk__add_module_lz4();

+ 15 - 0
src/modules/importlib.c

@@ -0,0 +1,15 @@
+#include "pocketpy/pocketpy.h"
+
+static bool importlib_reload(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    PY_CHECK_ARG_TYPE(0, tp_module);
+    bool ok = py_importlib_reload(argv);
+    py_newnone(py_retval());
+    return ok;
+}
+
+void pk__add_module_importlib() {
+    py_Ref mod = py_newmodule("importlib");
+
+    py_bindfunc(mod, "reload", importlib_reload);
+}

+ 20 - 2
src/public/modules.c

@@ -137,7 +137,7 @@ int py_import(const char* path_cstr) {
     if(data != NULL) goto __SUCCESS;
 
     c11_string__delete(filename);
-    filename = c11_string__new3("%s/__init__.py", slashed_path->data);
+    filename = c11_string__new3("%s%c__init__.py", slashed_path->data, PK_PLATFORM_SEP);
     data = vm->callbacks.importfile(filename->data);
     if(data != NULL) goto __SUCCESS;
 
@@ -159,6 +159,24 @@ __SUCCESS:
     return ok ? 1 : -1;
 }
 
+bool py_importlib_reload(py_GlobalRef module) {
+    VM* vm = pk_current_vm;
+    c11_sv path = py_tosv(py_getdict(module, __path__));
+    c11_string* slashed_path = c11_sv__replace(path, '.', PK_PLATFORM_SEP);
+    c11_string* filename = c11_string__new3("%s.py", slashed_path->data);
+    const char* data = vm->callbacks.importfile(filename->data);
+    if(data == NULL) {
+        c11_string__delete(filename);
+        filename = c11_string__new3("%s%c__init__.py", slashed_path->data, PK_PLATFORM_SEP);
+        data = vm->callbacks.importfile(filename->data);
+    }
+    c11_string__delete(slashed_path);
+    if(data == NULL) return ImportError("module '%v' not found", path);
+    bool ok = py_exec(data, filename->data, EXEC_MODE, module);
+    c11_string__delete(filename);
+    return ok;
+}
+
 //////////////////////////
 
 static bool builtins_exit(int argc, py_Ref argv) {
@@ -655,7 +673,7 @@ static bool builtins__import__(int argc, py_Ref argv) {
     int res = py_import(py_tostr(argv));
     if(res == -1) return false;
     if(res) return true;
-    return ImportError("No module named '%s'", py_tostr(argv));
+    return ImportError("module '%s' not found", py_tostr(argv));
 }
 
 static bool NoneType__repr__(int argc, py_Ref argv) {