cpp_function.h 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  #pragma once
  #include "cast.h"

  #include <algorithm>
  #include <cstddef>
  #include <cstring>
  #include <map>
  #include <memory>
  #include <string>
  #include <tuple>
  #include <type_traits>
  #include <utility>
  #include <vector>
  4. namespace pybind11 {
  /// Tag type passed as a binding extra: the new overload is inserted at the front of the
  /// existing overload chain (so it is tried first) instead of being appended at the end.
  struct prepend {};
  7. namespace impl {
  8. template <typename... Args>
  9. struct constructor {};
  10. template <typename Fn, typename Args = callable_args_t<Fn>>
  11. struct factory;
  12. template <typename Fn, typename... Args>
  13. struct factory<Fn, std::tuple<Args...>> {
  14. Fn fn;
  15. auto make() {
  16. using Self = callable_return_t<Fn>;
  17. return [fn = std::move(fn)](Self* self, Args... args) {
  18. new (self) Self(fn(args...));
  19. };
  20. }
  21. };
  22. } // namespace impl
  23. template <typename... Args>
  24. impl::constructor<Args...> init() {
  25. return {};
  26. }
  27. template <typename Fn>
  28. impl::factory<Fn> init(Fn&& fn) {
  29. return {std::forward<Fn>(fn)};
  30. }
  31. // TODO: support more customized tags
  32. //
  33. // template <std::size_t Nurse, std::size_t... Patients>
  34. // struct keep_alive {};
  35. //
  36. // template <typename T>
  37. // struct call_guard {
  38. // static_assert(std::is_default_constructible_v<T>, "call_guard must be default constructible");
  39. // };
  40. //
  41. // struct kw_only {};
  42. //
  43. // struct pos_only {};
  44. class cpp_function : public function {
  45. PYBIND11_TYPE_IMPLEMENT(function, pkpy::NativeFunc, vm->tp_native_func);
  46. public:
  47. template <typename Fn, typename... Extras>
  48. cpp_function(Fn&& f, const Extras&... extras) {}
  49. template <typename T>
  50. decltype(auto) get_userdata_as() {
  51. #if PK_VERSION_MAJOR == 2
  52. return self()._userdata.as<T>();
  53. #else
  54. return self()._userdata._cast<T>();
  55. #endif
  56. }
  57. template <typename T>
  58. void set_userdata(T&& value) {
  59. self()._userdata = std::forward<T>(value);
  60. }
  61. };
  62. } // namespace pybind11
  63. namespace pybind11::impl {
  64. template <typename Callable,
  65. typename Extra,
  66. typename Args = callable_args_t<Callable>,
  67. typename IndexSequence = std::make_index_sequence<std::tuple_size_v<Args>>>
  68. struct template_parser;
  69. class function_record {
  70. private:
  71. template <typename C, typename E, typename A, typename I>
  72. friend struct template_parser;
  73. struct arguments_t {
  74. std::vector<pkpy::StrName> names;
  75. std::vector<handle> defaults;
  76. };
  77. using destructor_t = void (*)(function_record*);
  78. using wrapper_t = handle (*)(function_record&, pkpy::ArgsView, bool convert, handle parent);
  79. static_assert(std::is_trivially_copyable_v<pkpy::StrName>);
  80. private:
  81. union {
  82. void* data;
  83. char buffer[16];
  84. };
  85. wrapper_t wrapper = nullptr;
  86. function_record* next = nullptr;
  87. arguments_t* arguments = nullptr;
  88. destructor_t destructor = nullptr;
  89. const char* signature = nullptr;
  90. return_value_policy policy = return_value_policy::automatic;
  91. public:
  92. template <typename Fn, typename... Extras>
  93. function_record(Fn&& f, const Extras&... extras) {
  94. using Callable = std::decay_t<Fn>;
  95. if constexpr(std::is_trivially_copyable_v<Callable> && sizeof(Callable) <= sizeof(buffer)) {
  96. // if the callable object is trivially copyable and the size is less than 16 bytes, store it in the
  97. // buffer
  98. new (buffer) auto(std::forward<Fn>(f));
  99. destructor = [](function_record* self) {
  100. reinterpret_cast<Callable*>(self->buffer)->~Callable();
  101. };
  102. } else {
  103. // otherwise, store it in the heap
  104. data = new auto(std::forward<Fn>(f));
  105. destructor = [](function_record* self) {
  106. delete static_cast<Callable*>(self->data);
  107. };
  108. }
  109. using Parser = template_parser<Callable, std::tuple<Extras...>>;
  110. Parser::initialize(*this, extras...);
  111. wrapper = Parser::wrapper;
  112. }
  113. function_record(const function_record&) = delete;
  114. function_record& operator= (const function_record&) = delete;
  115. function_record& operator= (function_record&&) = delete;
  116. function_record(function_record&& other) noexcept {
  117. std::memcpy(this, &other, sizeof(function_record));
  118. std::memset(&other, 0, sizeof(function_record));
  119. }
  120. ~function_record() {
  121. if(destructor) { destructor(this); }
  122. if(arguments) { delete arguments; }
  123. if(next) { delete next; }
  124. if(signature) { delete[] signature; }
  125. }
  126. void append(function_record* record) {
  127. function_record* p = this;
  128. while(p->next) {
  129. p = p->next;
  130. }
  131. p->next = record;
  132. }
  133. template <typename T>
  134. T& _as() {
  135. if constexpr(std::is_trivially_copyable_v<T> && sizeof(T) <= sizeof(buffer)) {
  136. return *reinterpret_cast<T*>(buffer);
  137. } else {
  138. return *static_cast<T*>(data);
  139. }
  140. }
  141. handle operator() (pkpy::ArgsView view) {
  142. function_record* p = this;
  143. // foreach function record and call the function with not convert
  144. while(p != nullptr) {
  145. handle result = p->wrapper(*p, view, false, {});
  146. if(result) { return result; }
  147. p = p->next;
  148. }
  149. p = this;
  150. // foreach function record and call the function with convert
  151. while(p != nullptr) {
  152. handle result = p->wrapper(*p, view, true, {});
  153. if(result) { return result; }
  154. p = p->next;
  155. }
  156. std::string msg = "no matching function found, function signature:\n";
  157. std::size_t index = 0;
  158. p = this;
  159. while(p != nullptr) {
  160. msg += " ";
  161. msg += p->signature;
  162. msg += "\n";
  163. p = p->next;
  164. }
  165. vm->TypeError(msg);
  166. PK_UNREACHABLE();
  167. }
  168. };
  169. template <typename Fn, std::size_t... Is, typename... Args>
  170. handle invoke(Fn&& fn,
  171. std::index_sequence<Is...>,
  172. std::tuple<impl::type_caster<Args>...>& casters,
  173. return_value_policy policy,
  174. handle parent) {
  175. using underlying_type = std::decay_t<Fn>;
  176. using return_type = callable_return_t<underlying_type>;
  177. constexpr bool is_void = std::is_void_v<return_type>;
  178. constexpr bool is_member_function_pointer = std::is_member_function_pointer_v<underlying_type>;
  179. if constexpr(is_member_function_pointer) {
  180. // helper function to unpack the arguments to call the member pointer
  181. auto unpack = [&](class_type_t<underlying_type>& self, auto&... args) {
  182. return (self.*fn)(args...);
  183. };
  184. if constexpr(!is_void) {
  185. return pybind11::cast(unpack(std::get<Is>(casters).value...), policy, parent);
  186. } else {
  187. unpack(std::get<Is>(casters).value...);
  188. return vm->None;
  189. }
  190. } else {
  191. if constexpr(!is_void) {
  192. return pybind11::cast(fn(std::get<Is>(casters).value...), policy, parent);
  193. } else {
  194. fn(std::get<Is>(casters).value...);
  195. return vm->None;
  196. }
  197. }
  198. }
  /// Compile-time summary of a bound callable's parameter list.
  struct arguments_info_t {
      int argc{0};         // total parameter count, py::args / py::kwargs included
      int args_pos{-1};    // index of the py::args parameter, -1 when absent
      int kwargs_pos{-1};  // index of the py::kwargs parameter, -1 when absent
  };

  /// Compile-time summary of the extras passed alongside a bound callable.
  struct extras_info_t {
      int doc_pos{-1};     // index of the docstring (const char*) extra, -1 when absent
      int named_argc{0};   // number of `arg` extras
      int policy_pos{-1};  // index of the return_value_policy extra, -1 when absent
  };
  209. template <typename Callable, typename... Extras, typename... Args, std::size_t... Is>
  210. struct template_parser<Callable, std::tuple<Extras...>, std::tuple<Args...>, std::index_sequence<Is...>> {
  211. constexpr static arguments_info_t parse_arguments() {
  212. constexpr auto args_count = types_count_v<args, Args...>;
  213. constexpr auto kwargs_count = types_count_v<kwargs, Args...>;
  214. static_assert(args_count <= 1, "py::args can occur at most once");
  215. static_assert(kwargs_count <= 1, "py::kwargs can occur at most once");
  216. constexpr auto args_pos = type_index_v<args, Args...>;
  217. constexpr auto kwargs_pos = type_index_v<kwargs, Args...>;
  218. if constexpr(kwargs_count == 1) {
  219. static_assert(kwargs_pos == sizeof...(Args) - 1, "py::kwargs must be the last argument");
  220. // FIXME: temporarily, args and kwargs must be at the end of the arguments list
  221. if constexpr(args_count == 1) {
  222. static_assert(args_pos == kwargs_pos - 1, "py::args must be before py::kwargs");
  223. }
  224. }
  225. return {sizeof...(Args), args_pos, kwargs_pos};
  226. }
  227. constexpr static extras_info_t parse_extras() {
  228. constexpr auto doc_count = types_count_v<const char*, Extras...>;
  229. constexpr auto policy_count = types_count_v<return_value_policy, Extras...>;
  230. static_assert(doc_count <= 1, "doc can occur at most once");
  231. static_assert(policy_count <= 1, "return_value_policy can occur at most once");
  232. constexpr auto doc_pos = type_index_v<const char*, Extras...>;
  233. constexpr auto policy_pos = type_index_v<return_value_policy, Extras...>;
  234. constexpr auto named_argc = types_count_v<arg, Extras...>;
  235. constexpr auto normal_argc =
  236. sizeof...(Args) - (arguments_info.args_pos != -1) - (arguments_info.kwargs_pos != -1);
  237. static_assert(named_argc == 0 || named_argc == normal_argc,
  238. "named arguments must be the same as the number of function arguments");
  239. return {doc_pos, named_argc, policy_pos};
  240. }
  241. constexpr inline static auto arguments_info = parse_arguments();
  242. constexpr inline static auto extras_info = parse_extras();
  243. static void initialize(function_record& record, const Extras&... extras) {
  244. auto extras_tuple = std::make_tuple(extras...);
  245. constexpr static bool has_named_args = (extras_info.named_argc > 0);
  246. // set return value policy
  247. if constexpr(extras_info.policy_pos != -1) { record.policy = std::get<extras_info.policy_pos>(extras_tuple); }
  248. // TODO: set others
  249. // set default arguments
  250. if constexpr(has_named_args) {
  251. record.arguments = new function_record::arguments_t();
  252. auto add_arguments = [&](const auto& arg) {
  253. if constexpr(std::is_same_v<pybind11::arg, remove_cvref_t<decltype(arg)>>) {
  254. auto& arguments = *record.arguments;
  255. arguments.names.emplace_back(arg.name);
  256. arguments.defaults.emplace_back(arg.default_);
  257. }
  258. };
  259. (add_arguments(extras), ...);
  260. }
  261. // set signature
  262. {
  263. std::string sig = "(";
  264. std::size_t index = 0;
  265. auto append = [&](auto _t) {
  266. using T = pybind11_decay_t<typename decltype(_t)::type>;
  267. if constexpr(std::is_same_v<T, args>) {
  268. sig += "*args";
  269. } else if constexpr(std::is_same_v<T, kwargs>) {
  270. sig += "**kwargs";
  271. } else if constexpr(has_named_args) {
  272. sig += record.arguments->names[index].c_str();
  273. sig += ": ";
  274. sig += type_info::of<T>().name;
  275. if(record.arguments->defaults[index]) {
  276. sig += " = ";
  277. sig += record.arguments->defaults[index].repr();
  278. }
  279. } else {
  280. sig += "_: ";
  281. sig += type_info::of<T>().name;
  282. }
  283. if(index + 1 < arguments_info.argc) { sig += ", "; }
  284. index++;
  285. };
  286. (append(type_identity<Args>{}), ...);
  287. sig += ")";
  288. char* buffer = new char[sig.size() + 1];
  289. std::memcpy(buffer, sig.data(), sig.size());
  290. buffer[sig.size()] = '\0';
  291. record.signature = buffer;
  292. }
  293. }
  294. static handle wrapper(function_record& record, pkpy::ArgsView view, bool convert, handle parent) {
  295. constexpr auto argc = arguments_info.argc;
  296. constexpr auto named_argc = extras_info.named_argc;
  297. constexpr auto args_pos = arguments_info.args_pos;
  298. constexpr auto kwargs_pos = arguments_info.kwargs_pos;
  299. constexpr auto normal_argc = argc - (args_pos != -1) - (kwargs_pos != -1);
  300. // avoid gc call in bound function
  301. vm->heap.gc_scope_lock();
  302. // add 1 to avoid zero-size array when argc is 0
  303. handle stack[argc + 1] = {};
  304. // ensure the number of passed arguments is no greater than the number of parameters
  305. if(args_pos == -1 && view.size() > normal_argc) { return handle(); }
  306. // if have default arguments, load them
  307. if constexpr(named_argc > 0) {
  308. auto& defaults = record.arguments->defaults;
  309. std::memcpy(stack, defaults.data(), defaults.size() * sizeof(handle));
  310. }
  311. // load arguments from call arguments
  312. const auto size = std::min(view.size(), normal_argc);
  313. std::memcpy(stack, view.begin(), size * sizeof(handle));
  314. // pack the args
  315. if constexpr(args_pos != -1) {
  316. const auto n = std::max(view.size() - normal_argc, 0);
  317. tuple args = tuple(n);
  318. for(std::size_t i = 0; i < n; ++i) {
  319. args[i] = view[normal_argc + i];
  320. }
  321. stack[args_pos] = args;
  322. }
  323. // resolve keyword arguments
  324. const auto n = vm->s_data._sp - view.end();
  325. int index = 0;
  326. if constexpr(named_argc > 0) {
  327. int arg_index = 0;
  328. auto& arguments = *record.arguments;
  329. while(arg_index < named_argc && index < n) {
  330. const auto key = pkpy::_py_cast<pkpy::i64>(vm, view.end()[index]);
  331. const auto value = view.end()[index + 1];
  332. const auto name = pkpy::StrName(key);
  333. auto& arg_name = record.arguments->names[arg_index];
  334. if(name == arg_name) {
  335. stack[arg_index] = value;
  336. index += 2;
  337. }
  338. arg_index += 1;
  339. }
  340. }
  341. // pack the kwargs
  342. if constexpr(kwargs_pos != -1) {
  343. dict kwargs;
  344. while(index < n) {
  345. const auto key = pkpy::_py_cast<pkpy::i64>(vm, view.end()[index]);
  346. const str name = str(pkpy::StrName(key).sv());
  347. kwargs[name] = view.end()[index + 1];
  348. index += 2;
  349. }
  350. stack[kwargs_pos] = kwargs;
  351. }
  352. // if have rest keyword arguments, call fails
  353. if(index != n) { return handle(); }
  354. // check if all the arguments are valid
  355. for(std::size_t i = 0; i < argc; ++i) {
  356. if(!stack[i]) { return handle(); }
  357. }
  358. // ok, all the arguments are valid, call the function
  359. std::tuple<impl::type_caster<Args>...> casters;
  360. // check type compatibility
  361. if(((std::get<Is>(casters).load(stack[Is], convert)) && ...)) {
  362. return invoke(record._as<Callable>(), std::index_sequence<Is...>{}, casters, record.policy, parent);
  363. }
  364. return handle();
  365. }
  366. };
  367. inline auto _wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  368. auto&& record = unpack<function_record>(view);
  369. return record(view).ptr();
  370. }
  371. template <bool is_method, typename Fn, typename... Extras>
  372. handle bind_function(const handle& obj, const char* name, Fn&& fn, pkpy::BindType type, const Extras&... extras) {
  373. // do not use cpp_function directly to avoid unnecessary reference count change
  374. pkpy::PyVar var = obj.ptr();
  375. cpp_function callable = var->attr().try_get(name);
  376. function_record* record = nullptr;
  377. if constexpr(is_method && types_count_v<arg, Extras...> > 0) {
  378. // if the function is a method and has named arguments
  379. // prepend self to the arguments list
  380. record = new function_record(std::forward<Fn>(fn), arg("self"), extras...);
  381. } else {
  382. record = new function_record(std::forward<Fn>(fn), extras...);
  383. }
  384. if(!callable) {
  385. // if the function is not bound yet, bind it
  386. void* data = interpreter::take_ownership(std::move(*record));
  387. callable = interpreter::bind_func(var, name, -1, _wrapper, data, type);
  388. } else {
  389. // if the function is already bound, append the new record to the function
  390. function_record* last = callable.get_userdata_as<function_record*>();
  391. if constexpr((types_count_v<prepend, Extras...> != 0)) {
  392. // if prepend is specified, append the new record to the beginning of the list
  393. callable.set_userdata(record);
  394. record->append(last);
  395. } else {
  396. // otherwise, append the new record to the end of the list
  397. last->append(record);
  398. }
  399. }
  400. return callable;
  401. }
  402. } // namespace pybind11::impl
  403. namespace pybind11::impl {
  404. template <typename Getter>
  405. pkpy::PyVar getter_wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  406. handle result = vm->None;
  407. auto&& getter = unpack<Getter>(view);
  408. constexpr auto policy = return_value_policy::reference_internal;
  409. if constexpr(std::is_member_pointer_v<Getter>) {
  410. using Self = class_type_t<Getter>;
  411. auto& self = handle(view[0])._as<instance>()._as<Self>();
  412. if constexpr(std::is_member_object_pointer_v<Getter>) {
  413. // specialize for pointer to data member
  414. result = cast(self.*getter, policy, view[0]);
  415. } else {
  416. // specialize for pointer to member function
  417. result = cast((self.*getter)(), policy, view[0]);
  418. }
  419. } else {
  420. // specialize for function pointer and lambda
  421. using Self = remove_cvref_t<std::tuple_element_t<0, callable_args_t<Getter>>>;
  422. auto& self = handle(view[0])._as<instance>()._as<Self>();
  423. result = cast(getter(self), policy, view[0]);
  424. }
  425. return result.ptr();
  426. }
  427. template <typename Setter>
  428. pkpy::PyVar setter_wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  429. auto&& setter = unpack<Setter>(view);
  430. if constexpr(std::is_member_pointer_v<Setter>) {
  431. using Self = class_type_t<Setter>;
  432. auto& self = handle(view[0])._as<instance>()._as<Self>();
  433. if constexpr(std::is_member_object_pointer_v<Setter>) {
  434. // specialize for pointer to data member
  435. impl::type_caster<member_type_t<Setter>> caster;
  436. if(caster.load(view[1], true)) {
  437. self.*setter = caster.value;
  438. return vm->None;
  439. }
  440. } else {
  441. // specialize for pointer to member function
  442. impl::type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  443. if(caster.load(view[1], true)) {
  444. (self.*setter)(caster.value);
  445. return vm->None;
  446. }
  447. }
  448. } else {
  449. // specialize for function pointer and lambda
  450. using Self = remove_cvref_t<std::tuple_element_t<0, callable_args_t<Setter>>>;
  451. auto& self = handle(view[0])._as<instance>()._as<Self>();
  452. impl::type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  453. if(caster.load(view[1], true)) {
  454. setter(self, caster.value);
  455. return vm->None;
  456. }
  457. }
  458. vm->TypeError("Unexpected argument type");
  459. PK_UNREACHABLE();
  460. }
  461. template <typename Getter, typename Setter, typename... Extras>
  462. handle bind_property(const handle& obj, const char* name, Getter&& getter_, Setter&& setter_, const Extras&... extras) {
  463. handle getter = none();
  464. handle setter = none();
  465. using Wrapper = pkpy::PyVar (*)(pkpy::VM*, pkpy::ArgsView);
  466. constexpr auto create = [](Wrapper wrapper, int argc, auto&& f) {
  467. if constexpr(need_host<remove_cvref_t<decltype(f)>>) {
  468. // otherwise, store it in the type_info
  469. void* data = interpreter::take_ownership(std::forward<decltype(f)>(f));
  470. // store the index in the object
  471. return vm->heap.gcnew<pkpy::NativeFunc>(vm->tp_native_func, wrapper, argc, data);
  472. } else {
  473. // if the function is trivially copyable and the size is less than 16 bytes, store it in the object
  474. // directly
  475. return vm->heap.gcnew<pkpy::NativeFunc>(vm->tp_native_func, wrapper, argc, f);
  476. }
  477. };
  478. getter = create(impl::getter_wrapper<std::decay_t<Getter>>, 1, std::forward<Getter>(getter_));
  479. if constexpr(!std::is_same_v<Setter, std::nullptr_t>) {
  480. setter = create(impl::setter_wrapper<std::decay_t<Setter>>, 2, std::forward<Setter>(setter_));
  481. }
  482. handle property = pybind11::property(getter, setter);
  483. setattr(obj, name, property);
  484. return property;
  485. }
  486. } // namespace pybind11::impl