Procházet zdrojové kódy

add pybind11 implementation for module reload

Kevin Eady před 1 rokem
rodič
revize
5bebf3e2df

+ 5 - 0
include/pybind11/internal/module.h

@@ -19,6 +19,11 @@ class module_ : public object {
         return steal<module_>(m);
     }
 
+    void reload() {
+        bool ok = py_importlib_reload(ptr());
+        if(!ok) { throw error_already_set(); }
+    }
+
     module_ def_submodule(const char* name, const char* doc = nullptr) {
         // auto package = (attr("__package__").cast<std::string>() += ".") +=
         // attr("__name__").cast<std::string_view>();

+ 58 - 0
include/pybind11/tests/module.cpp

@@ -79,4 +79,62 @@ TEST_F(PYBIND11_TEST, dynamic_module) {
     EXPECT_EQ(math.attr("sub")(4, 3).cast<int>(), 1);
 }
 
+struct import_callback {
+    using cb_type = decltype(py_callbacks()->importfile);
+
+    import_callback() {
+        assert(_importfile == nullptr);
+        _importfile = py_callbacks()->importfile;
+        py_callbacks()->importfile = importfile;
+    };
+
+    ~import_callback() {
+        assert(_importfile != nullptr);
+        py_callbacks()->importfile = _importfile;
+        _importfile = nullptr;
+    };
+
+    static char* importfile(const char* path) {
+        if(value.empty()) return _importfile(path);
+        // +1 for the null terminator
+        char* cstr = new char[value.size() + 1];
+
+        std::strcpy(cstr, value.c_str());
+        return cstr;
+    }
+
+    static std::string value;
+
+private:
+    static cb_type _importfile;
+};
+
+import_callback::cb_type import_callback::_importfile = nullptr;
+std::string import_callback::value = "";
+
+TEST_F(PYBIND11_TEST, reload_module) {
+    import_callback cb;
+
+    import_callback::value = "value = 1\n";
+    auto mod = py::module::import("reload_module");
+    EXPECT_EQ(mod.attr("value").cast<int>(), 1);
+
+    import_callback::value = "value = 2\n";
+    mod.reload();
+    EXPECT_EQ(mod.attr("value").cast<int>(), 2);
+
+    import_callback::value = "raise ValueError()";
+    // Reload in Python raises a ValueError
+    py::exec(
+        "import importlib\nimport reload_module\ntry:\n    importlib.reload(reload_module)\nexcept ValueError:\n    pass");
+
+    // Reload in C++ raises a ValueError
+    try {
+        mod.reload();
+    } catch(py::error_already_set& e) {
+        if(e.match(tp_ValueError)) { return; }
+        std::rethrow_exception(std::current_exception());
+    }
+}
+
 }  // namespace