BLUELOVETH 2 years ago
parent
commit
0ed2d8f3b1
3 changed files with 26 additions and 17 deletions
  1. 11 5
      include/pocketpy/vm.h
  2. 1 3
      src/pocketpy.cpp
  3. 14 9
      src/vm.cpp

+ 11 - 5
include/pocketpy/vm.h

@@ -410,17 +410,23 @@ public:
     }
 
     struct ImportContext{
-        std::vector<StrName> pending;
+        std::vector<Str> pending;
+        std::vector<bool> pending_is_init;   // a.k.a __init__.py
         struct Temp{
             ImportContext* ctx;
-            StrName name;
-            Temp(ImportContext* ctx, StrName name) : ctx(ctx), name(name){
+            Temp(ImportContext* ctx, Str name, bool is_init) : ctx(ctx){
                 ctx->pending.push_back(name);
+                ctx->pending_is_init.push_back(is_init);
+            }
+            ~Temp(){
+                ctx->pending.pop_back();
+                ctx->pending_is_init.pop_back();
             }
-            ~Temp(){ ctx->pending.pop_back(); }
         };
 
-        Temp scope(StrName name){ return {this, name}; }
+        Temp scope(Str name, bool is_init){
+            return {this, name, is_init};
+        }
     };
 
     ImportContext _import_context;

+ 1 - 3
src/pocketpy.cpp

@@ -1229,9 +1229,7 @@ void init_builtins(VM* _vm) {
     _vm->bind__repr__(_vm->tp_module, [](VM* vm, PyObject* obj) {
         const Str& package = CAST(Str&, obj->attr(__package__));
         Str name = CAST(Str&, obj->attr(__name__));
-        if(!package.empty()){
-            name = package + "." + name;
-        }
+        if(!package.empty()) name = package + "." + name;
         return VAR(fmt("<module ", name.escape(), ">"));
     });
 

+ 14 - 9
src/vm.cpp

@@ -228,12 +228,13 @@ namespace pkpy{
         };
 
         if(path[0] == '.'){
-            Str _mod_name = CAST(Str&, _module->attr(__name__));
-            Str _mod_package = CAST(Str&, _module->attr(__package__));
-            // get _module's fullname
-            if(!_mod_package.empty()) _mod_name = _mod_package + "." + _mod_name;
+            if(_import_context.pending.empty()){
+                ImportError("relative import outside of package");
+            }
+            Str curr_path = _import_context.pending.back();
+            bool curr_is_init = _import_context.pending_is_init.back();
             // convert relative path to absolute path
-            std::vector<std::string_view> cpnts = _mod_name.split(".", true);
+            std::vector<std::string_view> cpnts = curr_path.split(".", true);
             int prefix = 0;     // how many dots in the prefix
             for(int i=0; i<path.length(); i++){
                 if(path[i] == '.') prefix++;
@@ -241,16 +242,18 @@ namespace pkpy{
             }
             if(prefix > cpnts.size()) ImportError("attempted relative import beyond top-level package");
             path = path.substr(prefix);     // remove prefix
-            for(int i=1; i<prefix; i++) cpnts.pop_back();
+            for(int i=(int)curr_is_init; i<prefix; i++) cpnts.pop_back();
             cpnts.push_back(path.sv());
             path = f_join(cpnts);
         }
 
+        std::cout << "py_import(" << path.escape() << ")" << std::endl;
+
         StrName name(path);     // path to StrName
 
         // check circular import
-        for(StrName pending_name: _import_context.pending){
-            if(pending_name == name) ImportError(fmt("circular import ", name.escape()));
+        for(Str pending_name: _import_context.pending){
+            if(pending_name == path) ImportError(fmt("circular import ", name.escape()));
         }
 
         PyObject* ext_mod = _modules.try_get(name);
@@ -259,11 +262,13 @@ namespace pkpy{
         // try import
         Str filename = path.replace('.', kPlatformSep) + ".py";
         Str source;
+        bool is_init = false;
         auto it = _lazy_modules.find(name);
         if(it == _lazy_modules.end()){
             Bytes b = _import_handler(filename);
             if(!b){
                 filename = path.replace('.', kPlatformSep).str() + kPlatformSep + "__init__.py";
+                is_init = true;
                 b = _import_handler(filename);
             }
             if(!b) ImportError(fmt("module ", path.escape(), " not found"));
@@ -272,7 +277,7 @@ namespace pkpy{
             source = it->second;
             _lazy_modules.erase(it);
         }
-        auto _ = _import_context.scope(name);
+        auto _ = _import_context.scope(path, is_init);
         CodeObject_ code = compile(source, filename, EXEC_MODE);
 
         auto all_cpnts = path.split(".", true);