ykiko 1 ano atrás
pai
commit
f162cd308a

+ 1 - 1
docs/pybind11.md

@@ -17,7 +17,7 @@ Or explicitly call `py::interpreter::initialize()` and `py::interpreter::finaliz
 #include <pybind11/pybind11.h>
 namespace py = pybind11;
 
-PYBIND11_MODULE(example, m) {
+PYBIND11_EMBEDDED_MODULE(example, m) {
     m.def("add", [](int a, int b) {
         return a + b;
     });

+ 27 - 6
include/pybind11/internal/class.h

@@ -47,11 +47,11 @@ public:
 
     /// bind constructor
     template <typename... Args, typename... Extra>
-    class_& def(init<Args...>, const Extra&... extra) {
+    class_& def(impl::constructor<Args...>, const Extra&... extra) {
         if constexpr(!std::is_constructible_v<T, Args...>) {
             static_assert(std::is_constructible_v<T, Args...>, "Invalid constructor arguments");
         } else {
-            impl::bind_function(
+            impl::bind_function<true>(
                 *this,
                 "__init__",
                 [](T* self, Args... args) {
@@ -63,17 +63,29 @@ public:
         }
     }
 
+    template <typename Fn, typename... Extra>
+    class_& def(impl::factory<Fn> factory, const Extra&... extra) {
+        using ret = callable_return_t<Fn>;
+
+        if constexpr(!std::is_same_v<T, ret>) {
+            static_assert(std::is_same_v<T, ret>, "Factory function must return the class type");
+        } else {
+            impl::bind_function<true>(*this, "__init__", factory.make(), pkpy::BindType::DEFAULT, extra...);
+            return *this;
+        }
+    }
+
     /// bind member function
     template <typename Fn, typename... Extra>
     class_& def(const char* name, Fn&& f, const Extra&... extra) {
-        using first = std::tuple_element_t<0, callable_args_t<remove_cvref_t<Fn>>>;
-        constexpr bool is_first_base_of_v = std::is_base_of_v<remove_cvref_t<first>, T>;
+        using first = remove_cvref_t<std::tuple_element_t<0, callable_args_t<remove_cvref_t<Fn>>>>;
+        constexpr bool is_first_base_of_v = std::is_base_of_v<first, T> || std::is_same_v<first, T>;
 
         if constexpr(!is_first_base_of_v) {
             static_assert(is_first_base_of_v,
                           "If you want to bind member function, the first argument must be the base class");
         } else {
-            impl::bind_function(*this, name, std::forward<Fn>(f), pkpy::BindType::DEFAULT, extra...);
+            impl::bind_function<true>(*this, name, std::forward<Fn>(f), pkpy::BindType::DEFAULT, extra...);
         }
 
         return *this;
@@ -91,7 +103,7 @@ public:
     /// bind static function
     template <typename Fn, typename... Extra>
     class_& def_static(const char* name, Fn&& f, const Extra&... extra) {
-        impl::bind_function(*this, name, std::forward<Fn>(f), pkpy::BindType::STATICMETHOD, extra...);
+        impl::bind_function<false>(*this, name, std::forward<Fn>(f), pkpy::BindType::STATICMETHOD, extra...);
         return *this;
     }
 
@@ -163,6 +175,15 @@ public:
     template <typename... Args>
     enum_(const handle& scope, const char* name, Args&&... args) :
         class_<T, Others...>(scope, name, std::forward<Args>(args)...) {
+
+        Base::def(init([](int value) {
+            return static_cast<T>(value);
+        }));
+
+        Base::def("__eq__", [](T& self, T& other) {
+            return self == other;
+        });
+
         Base::def_property_readonly("value", [](T& self) {
             return int_(static_cast<std::underlying_type_t<T>>(self));
         });

+ 46 - 8
include/pybind11/internal/cpp_function.h

@@ -7,8 +7,37 @@ namespace pybind11 {
 // append the overload to the beginning of the overload list
 struct prepend {};
 
+namespace impl {
+
 template <typename... Args>
-struct init {};
+struct constructor {};
+
+template <typename Fn, typename Args = callable_args_t<Fn>>
+struct factory;
+
+template <typename Fn, typename... Args>
+struct factory<Fn, std::tuple<Args...>> {
+    Fn fn;
+
+    auto make() {
+        using Self = callable_return_t<Fn>;
+        return [fn = std::move(fn)](Self* self, Args... args) {
+            new (self) Self(fn(args...));
+        };
+    }
+};
+
+}  // namespace impl
+
+template <typename... Args>
+impl::constructor<Args...> init() {
+    return {};
+}
+
+template <typename Fn>
+impl::factory<Fn> init(Fn&& fn) {
+    return {std::forward<Fn>(fn)};
+}
 
 //  TODO: support more customized tags
 //
@@ -256,6 +285,7 @@ struct template_parser<Callable, std::tuple<Extras...>, std::tuple<Args...>, std
         constexpr auto named_argc = types_count_v<arg, Extras...>;
         constexpr auto normal_argc =
             sizeof...(Args) - (arguments_info.args_pos != -1) - (arguments_info.kwargs_pos != -1);
+
         static_assert(named_argc == 0 || named_argc == normal_argc,
                       "named arguments must be the same as the number of function arguments");
 
@@ -419,24 +449,32 @@ inline auto _wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
     return record(view).ptr();
 }
 
-template <typename Fn, typename... Extras>
+template <bool is_method, typename Fn, typename... Extras>
 handle bind_function(const handle& obj, const char* name, Fn&& fn, pkpy::BindType type, const Extras&... extras) {
     // do not use cpp_function directly to avoid unnecessary reference count change
     pkpy::PyVar var = obj.ptr();
     cpp_function callable = var->attr().try_get(name);
+    function_record* record = nullptr;
+
+    if constexpr(is_method && types_count_v<arg, Extras...> > 0) {
+        // if the function is a method and has named arguments
+        // prepend self to the arguments list
+        record = new function_record(std::forward<Fn>(fn), arg("self"), extras...);
+    } else {
+        record = new function_record(std::forward<Fn>(fn), extras...);
+    }
 
-    // if the function is not bound yet, bind it
     if(!callable) {
-        auto record = function_record(std::forward<Fn>(fn), extras...);
-        void* data = interpreter::take_ownership(std::move(record));
-        callable = interpreter::bind_func(var, name, -1, _wrapper, data);
+        // if the function is not bound yet, bind it
+        void* data = interpreter::take_ownership(std::move(*record));
+        callable = interpreter::bind_func(var, name, -1, _wrapper, data, type);
     } else {
-        function_record* record = new function_record(std::forward<Fn>(fn), extras...);
+        // if the function is already bound, append the new record to the function
         function_record* last = callable.get_userdata_as<function_record*>();
 
         if constexpr((types_count_v<prepend, Extras...> != 0)) {
             // if prepend is specified, append the new record to the beginning of the list
-            fn.set_userdata(record);
+            callable.set_userdata(record);
             record->append(last);
         } else {
             // otherwise, append the new record to the end of the list

+ 1 - 1
include/pybind11/internal/module.h

@@ -29,7 +29,7 @@ public:
 
     template <typename Fn, typename... Extras>
     module_& def(const char* name, Fn&& fn, const Extras... extras) {
-        impl::bind_function(*this, name, std::forward<Fn>(fn), pkpy::BindType::DEFAULT, extras...);
+        impl::bind_function<false>(*this, name, std::forward<Fn>(fn), pkpy::BindType::DEFAULT, extras...);
         return *this;
     }
 };