BLUELOVETH 2 năm trước cách đây
mục cha
commit
08d6a9a1ea
8 tập tin đã thay đổi với 77 bổ sung18 xóa
  1. 1 1
      docs/quick-start/modules.md
  2. 4 0
      src/ceval.h
  3. 6 0
      src/common.h
  4. 3 1
      src/compiler.h
  5. 1 0
      src/opcodes.h
  6. 1 1
      src/pocketpy.h
  7. 58 15
      src/vm.h
  8. 3 0
      tests/30_import.py

+ 1 - 1
docs/quick-start/modules.md

@@ -53,7 +53,7 @@ When you do `import` a module, the VM will try to find it in the following order
 
 1. Search `vm->_modules`, if found, return it.
 2. Search `vm->_lazy_modules`, if found, compile and execute it, then return it.
-3. Search `vm->_path` and try to load it from file system.
+3. Search the working directory and try to load it from file system via `read_file_cwd`.
 
 
 ### Filesystem hook

+ 4 - 0
src/ceval.h

@@ -483,6 +483,10 @@ __NEXT_STEP:;
         _name = StrName(byte.arg);
         PUSH(py_import(_name));
         DISPATCH();
+    TARGET(IMPORT_NAME_REL)
+        _name = StrName(byte.arg);
+        PUSH(py_import(_name, true));
+        DISPATCH();
     TARGET(IMPORT_STAR)
         _0 = POPX();
         for(auto& [name, value]: _0->attr().items()){

+ 6 - 0
src/common.h

@@ -169,4 +169,10 @@ inline PyObject* const PY_BEGIN_CALL = (PyObject*)0b010011;
 inline PyObject* const PY_OP_CALL = (PyObject*)0b100011;
 inline PyObject* const PY_OP_YIELD = (PyObject*)0b110011;
 
+#ifdef _WIN32
+    char kPlatformSep = '\\';
+#else
+    char kPlatformSep = '/';
+#endif
+
 } // namespace pkpy

+ 3 - 1
src/compiler.h

@@ -505,9 +505,11 @@ __SUBSCR_END:
 
     Str _compile_import() {
         if(name_scope() != NAME_GLOBAL) SyntaxError("import statement should be used in global scope");
+        Opcode op = OP_IMPORT_NAME;
+        if(match(TK("."))) op = OP_IMPORT_NAME_REL;
         consume(TK("@id"));
         Str name = prev().str();
-        ctx()->emit(OP_IMPORT_NAME, StrName(name).index, prev().line);
+        ctx()->emit(op, StrName(name).index, prev().line);
         return name;
     }
 

+ 1 - 0
src/opcodes.h

@@ -95,6 +95,7 @@ OPCODE(GET_ITER)
 OPCODE(FOR_ITER)
 /**************************/
 OPCODE(IMPORT_NAME)
+OPCODE(IMPORT_NAME_REL)
 OPCODE(IMPORT_STAR)
 /**************************/
 OPCODE(UNPACK_SEQUENCE)

+ 1 - 1
src/pocketpy.h

@@ -205,7 +205,7 @@ inline void init_builtins(VM* _vm) {
     _vm->bind__repr__(_vm->tp_object, [](VM* vm, PyObject* obj) {
         if(is_tagged(obj)) FATAL_ERROR();
         std::stringstream ss;
-        ss << "<" << OBJ_NAME(vm->_t(obj)) << " object at " << std::hex << obj << ">";
+        ss << "<" << OBJ_NAME(vm->_t(obj)) << " object at 0x" << std::hex << obj << ">";
         return VAR(ss.str());
     });
 

+ 58 - 15
src/vm.h

@@ -114,7 +114,6 @@ public:
     
     NameDict _modules;                                 // loaded modules
     std::map<StrName, Str> _lazy_modules;              // lazy loaded modules
-    std::vector<Str> _path;                            // search path
 
     PyObject* None;
     PyObject* True;
@@ -547,31 +546,75 @@ public:
         return _all_types[obj->type].obj;
     }
 
-    PyObject* py_import(StrName name){
+    struct ImportContext{
+        // 0: normal; 1: __init__.py; 2: relative
+        std::vector<std::pair<StrName, int>> pending;
+
+        struct Temp{
+            VM* vm;
+            StrName name;
+
+            Temp(VM* vm, StrName name, int type): vm(vm), name(name){
+                ImportContext* ctx = &vm->_import_context;
+                for(auto& [k,v]: ctx->pending){
+                    if(k == name){
+                        vm->_error("ImportError", fmt("circular import ", name.escape()));
+                    }
+                }
+                ctx->pending.emplace_back(name, type);
+            }
+
+            ~Temp(){
+                ImportContext* ctx = &vm->_import_context;
+                ctx->pending.pop_back();
+            }
+        };
+
+        Temp temp(VM* vm, StrName name, int type){
+            return Temp(vm, name, type);
+        }
+    };
+
+    ImportContext _import_context;
+
+    PyObject* py_import(StrName name, bool relative=false){
+        Str filename;
+        int type;
+        if(relative){
+            ImportContext* ctx = &_import_context;
+            type = 2;
+            for(auto it=ctx->pending.rbegin(); it!=ctx->pending.rend(); ++it){
+                if(it->second == 2) continue;
+                if(it->second == 1){
+                    filename = fmt(it->first, kPlatformSep, name, ".py");
+                    name = fmt(it->first, '.', name).c_str();
+                    break;
+                }
+            }
+            if(filename.length() == 0) _error("ImportError", "relative import outside of package");
+        }else{
+            type = 0;
+            filename = fmt(name, ".py");
+        }
         PyObject* ext_mod = _modules.try_get(name);
         if(ext_mod == nullptr){
             Str source;
             auto it = _lazy_modules.find(name);
             if(it == _lazy_modules.end()){
-                Bytes b = _read_file_cwd(fmt(name, ".py"));
-                if(!b) {
-                    for(Str path: _path){
-#ifdef _WIN32
-                        const char* sep = "\\";
-#else
-                        const char* sep = "/";
-#endif
-                        b = _read_file_cwd(fmt(path, sep, name, ".py"));
-                        if(b) break;
-                    }
-                    if(!b) _error("ImportError", fmt("module ", name.escape(), " not found"));
+                Bytes b = _read_file_cwd(filename);
+                if(!relative && !b){
+                    filename = fmt(name, kPlatformSep, "__init__.py");
+                    b = _read_file_cwd(filename);
+                    if(b) type = 1;
                 }
+                if(!b) _error("ImportError", fmt("module ", name.escape(), " not found"));
                 source = Str(b.str());
             }else{
                 source = it->second;
                 _lazy_modules.erase(it);
             }
-            CodeObject_ code = compile(source, Str(name.sv())+".py", EXEC_MODE);
+            auto _ = _import_context.temp(this, name, type);
+            CodeObject_ code = compile(source, filename, EXEC_MODE);
             PyObject* new_mod = new_module(name);
             _exec(code, new_mod);
             new_mod->attr()._try_perfect_rehash();

+ 3 - 0
tests/30_import.py

@@ -0,0 +1,3 @@
+import test
+
+assert test.add(1, 2) == 13