blueloveTH 2 years ago
parent
commit
ddd841f0c6
5 changed files with 19 additions and 14 deletions
  1. 6 6
      docs/quick-start/modules.md
  2. 3 2
      src/io.h
  3. 1 0
      src/pocketpy.h
  4. 4 6
      src/vm.h
  5. 5 0
      tests/30_import.py

+ 6 - 6
docs/quick-start/modules.md

@@ -53,16 +53,16 @@ When you do `import` a module, the VM will try to find it in the following order
 
 1. Search `vm->_modules`, if found, return it.
 2. Search `vm->_lazy_modules`, if found, compile and execute it, then return it.
-3. Search the working directory and try to load it from file system via `read_file_cwd`.
+3. Search the working directory and try to load it from file system via `vm->_import_handler`.
 
 
-### Filesystem hook
+### Customized import handler
 
-You can use `set_read_file_cwd` to provide a custom filesystem hook, which is used for `import` (3rd step).
-The default implementation is:
+You can use `vm->_import_handler` to provide a custom import handler for the 3rd step.
+if both `enable_os` and `PK_ENABLE_OS` are `true`, the default `import_handler` is as follows:
 
 ```cpp
-set_read_file_cwd([](const Str& name){
+inline Bytes _default_import_handler(const Str& name){
     std::filesystem::path path(name.sv());
     bool exists = std::filesystem::exists(path);
     if(!exists) return Bytes();
@@ -75,5 +75,5 @@ set_read_file_cwd([](const Str& name){
     fread(buffer.data(), 1, buffer.size(), fp);
     fclose(fp);
     return Bytes(std::move(buffer));
-});
+};
 ```

+ 3 - 2
src/io.h

@@ -11,7 +11,7 @@
 
 namespace pkpy{
 
-inline int _ = set_read_file_cwd([](const Str& name){
+inline Bytes _default_import_handler(const Str& name){
     std::filesystem::path path(name.sv());
     bool exists = std::filesystem::exists(path);
     if(!exists) return Bytes();
@@ -24,7 +24,7 @@ inline int _ = set_read_file_cwd([](const Str& name){
     fread(buffer.data(), 1, buffer.size(), fp);
     fclose(fp);
     return Bytes(std::move(buffer));
-});
+};
 
 struct FileIO {
     PY_CLASS(FileIO, io, FileIO)
@@ -183,6 +183,7 @@ inline void add_module_os(VM* vm){
 namespace pkpy{
 inline void add_module_io(void* vm){}
 inline void add_module_os(void* vm){}
+inline Bytes _default_import_handler(const Str& name) { return Bytes(); }
 } // namespace pkpy
 
 #endif

+ 1 - 0
src/pocketpy.h

@@ -1458,6 +1458,7 @@ inline void VM::post_init(){
         add_module_io(this);
         add_module_os(this);
         add_module_requests(this);
+        _import_handler = _default_import_handler;
     }
 
     add_module_linalg(this);

+ 4 - 6
src/vm.h

@@ -25,10 +25,6 @@ namespace pkpy{
 #define POPX()            (s_data.popx())
 #define STACK_VIEW(n)     (s_data.view(n))
 
-typedef Bytes (*ReadFileCwdFunc)(const Str& name);
-inline ReadFileCwdFunc _read_file_cwd = [](const Str& name) { return Bytes(); };
-inline int set_read_file_cwd(ReadFileCwdFunc func) { _read_file_cwd = func; return 0; }
-
 #define DEF_NATIVE_2(ctype, ptype)                                      \
     template<> inline ctype py_cast<ctype>(VM* vm, PyObject* obj) {     \
         vm->check_non_tagged_type(obj, vm->ptype);                      \
@@ -127,6 +123,7 @@ public:
 
     PrintFunc _stdout;
     PrintFunc _stderr;
+    Bytes (*_import_handler)(const Str& name);
 
     // for quick access
     Type tp_object, tp_type, tp_int, tp_float, tp_bool, tp_str;
@@ -145,6 +142,7 @@ public:
         callstack.reserve(8);
         _main = nullptr;
         _last_exception = nullptr;
+        _import_handler = [](const Str& name) { return Bytes(); };
         init_builtin_types();
     }
 
@@ -604,10 +602,10 @@ public:
             Str source;
             auto it = _lazy_modules.find(name);
             if(it == _lazy_modules.end()){
-                Bytes b = _read_file_cwd(filename);
+                Bytes b = _import_handler(filename);
                 if(!relative && !b){
                     filename = fmt(name, kPlatformSep, "__init__.py");
-                    b = _read_file_cwd(filename);
+                    b = _import_handler(filename);
                     if(b) type = 1;
                 }
                 if(!b) _error("ImportError", fmt("module ", name.escape(), " not found"));

+ 5 - 0
tests/30_import.py

@@ -1,3 +1,8 @@
+try:
+    import os
+except ImportError:
+    exit(0)
+
 import test1
 
 assert test1.add(1, 2) == 13