cpp_function.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  #pragma once
  #include "cast.h"
  #include <bitset>
  #include <new>
  #include <type_traits>
  4. namespace pybind11 {
  /// Binding-option tag that names a "nurse" index and the "patient" indices
  /// tied to it. It is a pure marker type. NOTE(review): no handling for it is
  /// visible in this file (generator::initialize ignores its extras), so confirm
  /// where the indices are interpreted.
  template <std::size_t NurseIndex, std::size_t... PatientIndices>
  struct keep_alive {
  };
  /// Binding-option tag naming a guard type. The only requirement enforced here
  /// is that the guard type can be default-constructed.
  template <typename Guard>
  struct call_guard {
      static_assert(std::is_default_constructible<Guard>::value, "call_guard must be default constructible");
  };
  /// Binding-option tag for bind_function. When it is present, the new overload
  /// is placed at the front of the overload list instead of the back, so it is
  /// tried before the overloads registered earlier.
  struct prepend {
  };
  /// Marker type whose template arguments spell out a constructor signature.
  /// NOTE(review): nothing in this file consumes it; presumably it is used by
  /// class bindings elsewhere — confirm.
  template <typename... CtorArgs>
  struct init {
  };
  15. // TODO: support more customized tags
  16. // struct kw_only {};
  17. //
  18. // struct pos_only {};
  19. //
  20. // struct default_arg {};
  21. //
  22. // struct arg {
  23. // const char* name;
  24. // const char* description;
  25. // };
  26. //
  27. // struct default_arg {
  28. // const char* name;
  29. // const char* description;
  30. // const char* value;
  31. // };
  32. template <typename Fn,
  33. typename Extra,
  34. typename Args = callable_args_t<std::decay_t<Fn>>,
  35. typename IndexSequence = std::make_index_sequence<std::tuple_size_v<Args>>>
  36. struct generator;
  37. class function_record {
  38. union {
  39. void* data;
  40. char buffer[16];
  41. };
  42. // TODO: optimize the function_record size to reduce memory usage
  43. const char* name;
  44. function_record* next;
  45. void (*destructor)(function_record*);
  46. return_value_policy policy = return_value_policy::automatic;
  47. handle (*wrapper)(function_record&, pkpy::ArgsView, bool convert, handle parent);
  48. template <typename Fn, typename Extra, typename Args, typename IndexSequence>
  49. friend struct generator;
  50. public:
  51. template <typename Fn, typename... Extras>
  52. function_record(Fn&& f, const char* name, const Extras&... extras) : name(name), next(nullptr) {
  53. if constexpr(sizeof(f) <= sizeof(buffer)) {
  54. new (buffer) auto(std::forward<Fn>(f));
  55. destructor = [](function_record* self) {
  56. reinterpret_cast<Fn*>(self->buffer)->~Fn();
  57. };
  58. } else {
  59. data = new auto(std::forward<Fn>(f));
  60. destructor = [](function_record* self) {
  61. delete static_cast<Fn*>(self->data);
  62. };
  63. }
  64. using Generator = generator<std::decay_t<Fn>, std::tuple<Extras...>>;
  65. Generator::initialize(*this, extras...);
  66. wrapper = Generator::generate();
  67. }
  68. ~function_record() { destructor(this); }
  69. template <typename Fn>
  70. auto& cast() {
  71. if constexpr(sizeof(Fn) <= sizeof(buffer)) {
  72. return *reinterpret_cast<Fn*>(buffer);
  73. } else {
  74. return *static_cast<Fn*>(data);
  75. }
  76. }
  77. void append(function_record* record) {
  78. function_record* p = this;
  79. while(p->next != nullptr) {
  80. p = p->next;
  81. }
  82. p->next = record;
  83. }
  84. handle operator() (pkpy::ArgsView view) {
  85. function_record* p = this;
  86. // foreach function record and call the function with not convert
  87. while(p != nullptr) {
  88. handle result = p->wrapper(*this, view, false, {});
  89. if(result) { return result; }
  90. p = p->next;
  91. }
  92. p = this;
  93. // foreach function record and call the function with convert
  94. while(p != nullptr) {
  95. handle result = p->wrapper(*this, view, true, {});
  96. if(result) { return result; }
  97. p = p->next;
  98. }
  99. vm->TypeError("no matching function found");
  100. }
  101. };
  102. template <typename Fn, std::size_t... Is, typename... Args>
  103. handle invoke(Fn&& fn,
  104. std::index_sequence<Is...>,
  105. std::tuple<type_caster<Args>...>& casters,
  106. return_value_policy policy,
  107. handle parent) {
  108. using underlying_type = std::decay_t<Fn>;
  109. using ret = callable_return_t<underlying_type>;
  110. // if the return type is void, return None
  111. if constexpr(std::is_void_v<ret>) {
  112. // resolve the member function pointer
  113. if constexpr(std::is_member_function_pointer_v<underlying_type>) {
  114. [&](class_type_t<underlying_type>& self, auto&... args) {
  115. (self.*fn)(args...);
  116. }(std::get<Is>(casters).value...);
  117. } else {
  118. fn(std::get<Is>(casters).value...);
  119. }
  120. return vm->None;
  121. } else {
  122. // resolve the member function pointer
  123. if constexpr(std::is_member_function_pointer_v<remove_cvref_t<Fn>>) {
  124. return type_caster<ret>::cast(
  125. [&](class_type_t<underlying_type>& self, auto&... args) {
  126. return (self.*fn)(args...);
  127. }(std::get<Is>(casters).value...),
  128. policy,
  129. parent);
  130. } else {
  131. return type_caster<ret>::cast(fn(std::get<Is>(casters).value...), policy, parent);
  132. }
  133. }
  134. }
  135. template <typename Fn, typename... Args, std::size_t... Is, typename... Extras>
  136. struct generator<Fn, std::tuple<Extras...>, std::tuple<Args...>, std::index_sequence<Is...>> {
  137. static void initialize(function_record& record, const Extras&... extras) {}
  138. static auto generate() {
  139. return +[](function_record& self, pkpy::ArgsView view, bool convert, handle parent) {
  140. // FIXME:
  141. // Temporarily, args and kwargs must be at the end of the arguments list
  142. // Named arguments are not supported yet
  143. constexpr bool has_args = types_count_v<args, remove_cvref_t<Args>...> != 0;
  144. constexpr bool has_kwargs = types_count_v<kwargs, remove_cvref_t<Args>...> != 0;
  145. constexpr std::size_t count = sizeof...(Args) - has_args - has_kwargs;
  146. handle stack[sizeof...(Args)] = {};
  147. // initialize the stack
  148. if(!has_args && (view.size() != count)) { return handle(); }
  149. if(has_args && (view.size() < count)) { return handle(); }
  150. for(std::size_t i = 0; i < count; ++i) {
  151. stack[i] = view[i];
  152. }
  153. // pack the args and kwargs
  154. if constexpr(has_args) {
  155. const auto n = view.size() - count;
  156. pkpy::PyVar var = vm->new_object<pkpy::Tuple>(vm->tp_tuple, n);
  157. auto& tuple = var.obj_get<pkpy::Tuple>();
  158. for(std::size_t i = 0; i < n; ++i) {
  159. tuple[i] = view[count + i];
  160. }
  161. stack[count] = var;
  162. }
  163. if constexpr(has_kwargs) {
  164. const auto n = vm->s_data._sp - view.end();
  165. pkpy::PyVar var = vm->new_object<pkpy::Dict>(vm->tp_dict);
  166. auto& dict = var.obj_get<pkpy::Dict>();
  167. for(std::size_t i = 0; i < n; i += 2) {
  168. pkpy::i64 index = pkpy::_py_cast<pkpy::i64>(vm, view[count + i]);
  169. pkpy::PyVar str = vm->new_object<pkpy::Str>(vm->tp_str, pkpy::StrName(index).sv());
  170. dict.set(vm, str, view[count + i + 1]);
  171. }
  172. stack[count + 1] = var;
  173. }
  174. // check if all the arguments are not valid
  175. for(std::size_t i = 0; i < sizeof...(Args); ++i) {
  176. if(!stack[i]) { return handle(); }
  177. }
  178. // ok, all the arguments are valid, call the function
  179. std::tuple<type_caster<Args>...> casters;
  180. // check type compatibility
  181. if(((std::get<Is>(casters).load(stack[Is], convert)) && ...)) {
  182. return invoke(self.cast<Fn>(), std::index_sequence<Is...>{}, casters, self.policy, parent);
  183. }
  184. return handle();
  185. };
  186. }
  187. };
  188. constexpr inline static auto _wrapper = +[](pkpy::VM*, pkpy::ArgsView view) {
  189. auto& record = pkpy::lambda_get_userdata<function_record>(view.begin());
  190. return record(view).ptr();
  191. };
  192. class cpp_function : public function {
  193. public:
  194. template <typename Fn, typename... Extras>
  195. cpp_function(Fn&& f, const Extras&... extras) {
  196. pkpy::any userdata = function_record(std::forward<Fn>(f), "anonymous", extras...);
  197. m_ptr = vm->bind_func(nullptr, "", -1, _wrapper, std::move(userdata));
  198. inc_ref();
  199. }
  200. };
  201. template <typename Fn, typename... Extras>
  202. handle bind_function(const handle& obj, const char* name, Fn&& fn, pkpy::BindType type, const Extras&... extras) {
  203. // do not use cpp_function directly to avoid unnecessary reference count change
  204. pkpy::PyVar var = obj.ptr();
  205. pkpy::PyVar callable = var->attr().try_get(name);
  206. // if the function is not bound yet, bind it
  207. if(!callable) {
  208. pkpy::any userdata = function_record(std::forward<Fn>(fn), name, extras...);
  209. callable = vm->bind_func(var, name, -1, _wrapper, std::move(userdata));
  210. } else {
  211. auto& userdata = callable.obj_get<pkpy::NativeFunc>()._userdata;
  212. function_record* record = new function_record(std::forward<Fn>(fn), name, extras...);
  213. constexpr bool is_prepend = (types_count_v<prepend, Extras...> != 0);
  214. if constexpr(is_prepend) {
  215. // if prepend is specified, append the new record to the beginning of the list
  216. function_record* last = (function_record*)userdata.data;
  217. userdata.data = record;
  218. record->append(last);
  219. } else {
  220. // otherwise, append the new record to the end of the list
  221. function_record* last = (function_record*)userdata.data;
  222. last->append(record);
  223. }
  224. }
  225. return callable;
  226. }
  227. template <typename Getter_, typename Setter_, typename... Extras>
  228. handle
  229. bind_property(const handle& obj, const char* name, Getter_&& getter_, Setter_&& setter_, const Extras&... extras) {
  230. pkpy::PyVar var = obj.ptr();
  231. pkpy::PyVar getter = vm->None;
  232. pkpy::PyVar setter = vm->None;
  233. using Getter = std::decay_t<Getter_>;
  234. using Setter = std::decay_t<Setter_>;
  235. getter = vm->new_object<pkpy::NativeFunc>(
  236. vm->tp_native_func,
  237. [](pkpy::VM* vm, pkpy::ArgsView view) -> pkpy::PyVar {
  238. auto& getter = pkpy::lambda_get_userdata<Getter>(view.begin());
  239. if constexpr(std::is_member_pointer_v<Getter>) {
  240. using Self = class_type_t<Getter>;
  241. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  242. if constexpr(std::is_member_object_pointer_v<Getter>) {
  243. return type_caster<member_type_t<Getter>>::cast(self.*getter,
  244. return_value_policy::reference_internal,
  245. view[0])
  246. .ptr();
  247. } else {
  248. return type_caster<callable_return_t<Getter>>::cast((self.*getter)(),
  249. return_value_policy::reference_internal,
  250. view[0])
  251. .ptr();
  252. }
  253. } else {
  254. using Self = std::tuple_element_t<0, callable_args_t<Getter>>;
  255. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  256. return type_caster<callable_return_t<Getter>>::cast(getter(self),
  257. return_value_policy::reference_internal,
  258. view[0])
  259. .ptr();
  260. }
  261. },
  262. 1,
  263. std::forward<Getter_>(getter_));
  264. if constexpr(!std::is_same_v<Setter, std::nullptr_t>) {
  265. setter = vm->new_object<pkpy::NativeFunc>(
  266. vm->tp_native_func,
  267. [](pkpy::VM* vm, pkpy::ArgsView view) -> pkpy::PyVar {
  268. auto& setter = pkpy::lambda_get_userdata<Setter>(view.begin());
  269. if constexpr(std::is_member_pointer_v<Setter>) {
  270. using Self = class_type_t<Setter>;
  271. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  272. if constexpr(std::is_member_object_pointer_v<Setter>) {
  273. type_caster<member_type_t<Setter>> caster;
  274. if(caster.load(view[1], true)) {
  275. self.*setter = caster.value;
  276. return vm->None;
  277. }
  278. } else {
  279. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  280. if(caster.load(view[1], true)) {
  281. (self.*setter)(caster.value);
  282. return vm->None;
  283. }
  284. }
  285. } else {
  286. using Self = std::tuple_element_t<0, callable_args_t<Setter>>;
  287. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  288. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  289. if(caster.load(view[1], true)) {
  290. setter(self, caster.value);
  291. return vm->None;
  292. }
  293. }
  294. vm->TypeError("invalid argument");
  295. },
  296. 2,
  297. std::forward<Setter_>(setter_));
  298. }
  299. pkpy::PyVar property = vm->new_object<pkpy::Property>(vm->tp_property, getter, setter);
  300. var->attr().set(name, property);
  301. return property;
  302. }
  303. } // namespace pybind11