BLUELOVETH 2 years ago
parent
commit
966c58d808
3 changed files with 13 additions and 18 deletions
  1. 9 12
      src/vm.cpp
  2. 2 0
      tests/test2/a/__init__.py
  3. 2 6
      tests/test2/utils/r.py

+ 9 - 12
src/vm.cpp

@@ -217,9 +217,6 @@ namespace pkpy{
 
     PyObject* VM::py_import(Str path, bool throw_err){
         if(path.empty()) vm->ValueError("empty module name");
-
-        // std::cout << ">> py_import(" << path.escape() << ")" << std::endl;
-        
         auto f_join = [](const std::vector<std::string_view>& cpnts){
             std::stringstream ss;
             for(int i=0; i<cpnts.size(); i++){
@@ -249,18 +246,19 @@ namespace pkpy{
             path = f_join(cpnts);
         }
 
-        // std::cout << ".. py_import(" << path.escape() << ")" << std::endl;
+        // std::cout << "py_import(" << path.escape() << ")" << std::endl;
 
         PK_ASSERT(path.begin()[0] != '.');
         PK_ASSERT(path.end()[-1] != '.');
 
-        StrName name(path);     // path to StrName
+        auto path_cpnts = path.split(".", true);
 
         // check circular import
-        // for(Str pending_name: _import_context.pending){
-        //     if(pending_name == path) ImportError(fmt("circular import ", name.escape()));
-        // }
+        if(_import_context.pending.size() > 128){
+            ImportError("maximum recursion depth exceeded while importing");
+        }
 
+        StrName name(path);     // path to StrName
         PyObject* ext_mod = _modules.try_get(name);
         if(ext_mod != nullptr) return ext_mod;
 
@@ -288,10 +286,9 @@ namespace pkpy{
         auto _ = _import_context.scope(path, is_init);
         CodeObject_ code = compile(source, filename, EXEC_MODE);
 
-        auto all_cpnts = path.split(".", true);
-        Str name_cpnt = all_cpnts.back();
-        all_cpnts.pop_back();
-        PyObject* new_mod = new_module(name_cpnt, f_join(all_cpnts));
+        Str name_cpnt = path_cpnts.back();
+        path_cpnts.pop_back();
+        PyObject* new_mod = new_module(name_cpnt, f_join(path_cpnts));
         _exec(code, new_mod);
         new_mod->attr()._try_perfect_rehash();
         return new_mod;

+ 2 - 0
tests/test2/a/__init__.py

@@ -1 +1,3 @@
+ok = True
+
 from ..b import D

+ 2 - 6
tests/test2/utils/r.py

@@ -1,8 +1,4 @@
 value = '123'
 
-# try:
-#     from test2.a import g
-#     exit(1)
-# except ImportError:
-#     # circular import
-#     pass
+from test2.a import g
+assert g.ok == True