blueloveTH hai 1 ano
pai
achega
76075de70c

+ 1 - 1
include/pocketpy/interpreter/vm.h

@@ -30,7 +30,7 @@ typedef struct py_TypeInfo {
 typedef struct VM {
     Frame* top_frame;
 
-    NameDict modules;
+    ModuleDict modules;
     c11_vector /*T=py_TypeInfo*/ types;
 
     py_TValue builtins;  // builtins module

+ 13 - 0
include/pocketpy/objects/namedict.h

@@ -12,3 +12,16 @@
 #include "pocketpy/xmacros/smallmap.h"
 #undef SMALLMAP_T__HEADER
 
+/* A simple binary tree for storing modules. */
+typedef struct ModuleDict {
+    const char* path;
+    py_TValue module;
+    struct ModuleDict* left;
+    struct ModuleDict* right;
+} ModuleDict;
+
+void ModuleDict__ctor(ModuleDict* self, const char* path, py_TValue module);
+void ModuleDict__dtor(ModuleDict* self);
+void ModuleDict__set(ModuleDict* self, const char* key, py_TValue val);
+py_TValue* ModuleDict__try_get(ModuleDict* self, const char* path);
+bool ModuleDict__contains(ModuleDict* self, const char* path);

+ 2 - 2
include/pocketpy/pocketpy.h

@@ -377,9 +377,9 @@ bool py_vectorcall(uint16_t argc, uint16_t kwargc) PY_RAISE;
 /************* Modules *************/
 
 /// Create a new module.
-py_TmpRef py_newmodule(const char* path);
+py_GlobalRef py_newmodule(const char* path);
 /// Get a module by path.
-py_TmpRef py_getmodule(const char* path);
+py_GlobalRef py_getmodule(const char* path);
 
 /// Import a module.
 /// The result will be set to `py_retval()`.

+ 3 - 3
src/interpreter/vm.c

@@ -57,7 +57,7 @@ static void py_TypeInfo__dtor(py_TypeInfo* self) { c11_vector__dtor(&self->annot
 void VM__ctor(VM* self) {
     self->top_frame = NULL;
 
-    NameDict__ctor(&self->modules);
+    ModuleDict__ctor(&self->modules, NULL, *py_NIL);
     c11_vector__ctor(&self->types, sizeof(py_TypeInfo));
 
     self->builtins = *py_NIL;
@@ -221,7 +221,7 @@ void VM__dtor(VM* self) {
     // clear frames
     while(self->top_frame)
         VM__pop_frame(self);
-    NameDict__dtor(&self->modules);
+    ModuleDict__dtor(&self->modules);
     c11__foreach(py_TypeInfo, &self->types, ti) py_TypeInfo__dtor(ti);
     c11_vector__dtor(&self->types);
     ValueStack__clear(&self->stack);
@@ -315,7 +315,7 @@ py_Type pk_newtype(const char* name,
     py_Type index = types->count;
     py_TypeInfo* ti = c11_vector__emplace(types);
     py_TypeInfo* base_ti = base ? c11__at(py_TypeInfo, types, base) : NULL;
-    if(base_ti && base_ti->is_sealed){
+    if(base_ti && base_ti->is_sealed) {
         c11__abort("type '%s' is not an acceptable base type", py_name2str(base_ti->name));
     }
     py_TypeInfo__ctor(ti, py_name(name), index, base, module ? *module : *py_NIL);

+ 1 - 1
src/modules/json.c

@@ -10,7 +10,7 @@ static bool json_loads(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     PY_CHECK_ARG_TYPE(0, tp_str);
     const char* source = py_tostr(argv);
-    py_TmpRef mod = py_getmodule("json");
+    py_GlobalRef mod = py_getmodule("json");
     return py_exec(source, "<json>", EVAL_MODE, mod);
 }
 

+ 67 - 0
src/objects/namedict.c

@@ -6,3 +6,70 @@
 #define NAME NameDict
 #include "pocketpy/xmacros/smallmap.h"
 #undef SMALLMAP_T__SOURCE
+
+void ModuleDict__ctor(ModuleDict* self, const char* path, py_TValue module) {
+    self->path = path;
+    self->module = module;
+    self->left = NULL;
+    self->right = NULL;
+}
+
+void ModuleDict__dtor(ModuleDict* self) {
+    if(self->left) {
+        ModuleDict__dtor(self->left);
+        free(self->left);
+    }
+    if(self->right) {
+        ModuleDict__dtor(self->right);
+        free(self->right);
+    }
+}
+
+void ModuleDict__set(ModuleDict* self, const char* key, py_TValue val) {
+    if(self->path == NULL) {
+        self->path = key;
+        self->module = val;
+    }
+    int cmp = strcmp(key, self->path);
+    if(cmp < 0) {
+        if(self->left) {
+            ModuleDict__set(self->left, key, val);
+        } else {
+            self->left = malloc(sizeof(ModuleDict));
+            ModuleDict__ctor(self->left, key, val);
+        }
+    } else if(cmp > 0) {
+        if(self->right) {
+            ModuleDict__set(self->right, key, val);
+        } else {
+            self->right = malloc(sizeof(ModuleDict));
+            ModuleDict__ctor(self->right, key, val);
+        }
+    } else {
+        self->module = val;
+    }
+}
+
+py_TValue* ModuleDict__try_get(ModuleDict* self, const char* path) {
+    if(self->path == NULL) return NULL;
+    int cmp = strcmp(path, self->path);
+    if(cmp < 0) {
+        if(self->left) {
+            return ModuleDict__try_get(self->left, path);
+        } else {
+            return NULL;
+        }
+    } else if(cmp > 0) {
+        if(self->right) {
+            return ModuleDict__try_get(self->right, path);
+        } else {
+            return NULL;
+        }
+    } else {
+        return &self->module;
+    }
+}
+
+bool ModuleDict__contains(ModuleDict* self, const char* path) {
+    return ModuleDict__try_get(self, path) != NULL;
+}

+ 7 - 5
src/public/modules.c

@@ -10,7 +10,7 @@
 
 py_Ref py_getmodule(const char* path) {
     VM* vm = pk_current_vm;
-    return NameDict__try_get(&vm->modules, py_name(path));
+    return ModuleDict__try_get(&vm->modules, path);
 }
 
 py_Ref py_getbuiltin(py_Name name) { return py_getdict(&pk_current_vm->builtins, name); }
@@ -51,10 +51,12 @@ py_Ref py_newmodule(const char* path) {
 
     // we do not allow override in order to avoid memory leak
     // it is because Module objects are not garbage collected
-    py_Name path_name = py_name(path);
-    bool exists = NameDict__contains(&pk_current_vm->modules, path_name);
+    bool exists = ModuleDict__contains(&pk_current_vm->modules, path);
     if(exists) c11__abort("module '%s' already exists", path);
-    NameDict__set(&pk_current_vm->modules, path_name, *r0);
+
+    // convert to a weak (const char*)
+    path = py_tostr(py_getdict(r0, __path__));
+    ModuleDict__set(&pk_current_vm->modules, path, *r0);
 
     py_shrink(2);
     return py_getmodule(path);
@@ -112,7 +114,7 @@ int py_import(const char* path_cstr) {
     assert(path.data[0] != '.' && path.data[path.size - 1] != '.');
 
     // check existing module
-    py_TmpRef ext_mod = py_getmodule(path.data);
+    py_GlobalRef ext_mod = py_getmodule(path.data);
     if(ext_mod) {
         py_assign(py_retval(), ext_mod);
         return true;