1
0

cpp_function.h 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. #pragma once
  2. #include "cast.h"
  3. #include <map>
  4. namespace pybind11 {
  /// Tag type passed among the extras of a binding call to request that the new
  /// overload be placed at the beginning of the overload list instead of the end.
  struct prepend {};
  /// Empty tag type parameterized on a pack of argument types.
  /// NOTE(review): presumably describes a constructor signature for class bindings;
  /// no use is visible in this file - confirm against the class binding code.
  template <typename... Args>
  struct init {};
  9. // TODO: support more customized tags
  10. //
  11. // template <std::size_t Nurse, std::size_t... Patients>
  12. // struct keep_alive {};
  13. //
  14. // template <typename T>
  15. // struct call_guard {
  16. // static_assert(std::is_default_constructible_v<T>, "call_guard must be default constructible");
  17. // };
  18. //
  19. // struct kw_only {};
  20. //
  21. // struct pos_only {};
  /// Handle type for a native function object (`pkpy::NativeFunc`), with helpers to
  /// read and replace the userdata attached to that object.
  class cpp_function : public function {
      PYBIND11_TYPE_IMPLEMENT(function, pkpy::NativeFunc, vm->tp_native_func);

  public:
      /// NOTE(review): the body is empty, so `f` and `extras` are ignored here and no
      /// function object is created by this constructor - confirm this is intended.
      template <typename Fn, typename... Extras>
      cpp_function(Fn&& f, const Extras&... extras) {}

      /// Returns the userdata of the underlying native function viewed as `T`.
      /// The accessor is `as<T>()` when PK_VERSION_MAJOR == 2 and `_cast<T>()` otherwise.
      template <typename T>
      decltype(auto) get_userdata_as() {
  #if PK_VERSION_MAJOR == 2
          return self()._userdata.as<T>();
  #else
          return self()._userdata._cast<T>();
  #endif
      }

      /// Replaces the userdata of the underlying native function with `value`.
      template <typename T>
      void set_userdata(T&& value) {
          self()._userdata = std::forward<T>(value);
      }
  };
  40. } // namespace pybind11
  41. namespace pybind11::impl {
  /// Forward declaration of the compile-time binding parser.
  ///
  /// @tparam Callable      decayed type of the bound callable
  /// @tparam Extra         tuple type holding the extra options passed at bind time
  /// @tparam Args          tuple of the callable's argument types (defaults to `callable_args_t<Callable>`)
  /// @tparam IndexSequence index sequence with one index per element of `Args`
  template <typename Callable,
            typename Extra,
            typename Args = callable_args_t<Callable>,
            typename IndexSequence = std::make_index_sequence<std::tuple_size_v<Args>>>
  struct template_parser;
  47. class function_record {
  48. private:
  49. template <typename C, typename E, typename A, typename I>
  50. friend struct template_parser;
  51. struct arguments_t {
  52. std::vector<pkpy::StrName> names;
  53. std::vector<handle> defaults;
  54. };
  55. using destructor_t = void (*)(function_record*);
  56. using wrapper_t = handle (*)(function_record&, pkpy::ArgsView, bool convert, handle parent);
  57. static_assert(std::is_trivially_copyable_v<pkpy::StrName>);
  58. private:
  59. union {
  60. void* data;
  61. char buffer[16];
  62. };
  63. wrapper_t wrapper = nullptr;
  64. function_record* next = nullptr;
  65. arguments_t* arguments = nullptr;
  66. destructor_t destructor = nullptr;
  67. const char* signature = nullptr;
  68. return_value_policy policy = return_value_policy::automatic;
  69. public:
  70. template <typename Fn, typename... Extras>
  71. function_record(Fn&& f, const Extras&... extras) {
  72. using Callable = std::decay_t<Fn>;
  73. if constexpr(std::is_trivially_copyable_v<Callable> && sizeof(Callable) <= sizeof(buffer)) {
  74. // if the callable object is trivially copyable and the size is less than 16 bytes, store it in the
  75. // buffer
  76. new (buffer) auto(std::forward<Fn>(f));
  77. destructor = [](function_record* self) {
  78. reinterpret_cast<Callable*>(self->buffer)->~Callable();
  79. };
  80. } else {
  81. // otherwise, store it in the heap
  82. data = new auto(std::forward<Fn>(f));
  83. destructor = [](function_record* self) {
  84. delete static_cast<Callable*>(self->data);
  85. };
  86. }
  87. using Parser = template_parser<Callable, std::tuple<Extras...>>;
  88. Parser::initialize(*this, extras...);
  89. wrapper = Parser::wrapper;
  90. }
  91. function_record(const function_record&) = delete;
  92. function_record& operator= (const function_record&) = delete;
  93. function_record& operator= (function_record&&) = delete;
  94. function_record(function_record&& other) noexcept {
  95. std::memcpy(this, &other, sizeof(function_record));
  96. std::memset(&other, 0, sizeof(function_record));
  97. }
  98. ~function_record() {
  99. if(destructor) { destructor(this); }
  100. if(arguments) { delete arguments; }
  101. if(next) { delete next; }
  102. if(signature) { delete[] signature; }
  103. }
  104. void append(function_record* record) {
  105. function_record* p = this;
  106. while(p->next) {
  107. p = p->next;
  108. }
  109. p->next = record;
  110. }
  111. template <typename T>
  112. T& _as() {
  113. if constexpr(std::is_trivially_copyable_v<T> && sizeof(T) <= sizeof(buffer)) {
  114. return *reinterpret_cast<T*>(buffer);
  115. } else {
  116. return *static_cast<T*>(data);
  117. }
  118. }
  119. handle operator() (pkpy::ArgsView view) {
  120. function_record* p = this;
  121. // foreach function record and call the function with not convert
  122. while(p != nullptr) {
  123. handle result = p->wrapper(*p, view, false, {});
  124. if(result) { return result; }
  125. p = p->next;
  126. }
  127. p = this;
  128. // foreach function record and call the function with convert
  129. while(p != nullptr) {
  130. handle result = p->wrapper(*p, view, true, {});
  131. if(result) { return result; }
  132. p = p->next;
  133. }
  134. std::string msg = "no matching function found, function signature:\n";
  135. std::size_t index = 0;
  136. p = this;
  137. while(p != nullptr) {
  138. msg += " ";
  139. msg += p->signature;
  140. msg += "\n";
  141. p = p->next;
  142. }
  143. vm->TypeError(msg);
  144. PK_UNREACHABLE();
  145. }
  146. };
  147. template <typename Fn, std::size_t... Is, typename... Args>
  148. handle invoke(Fn&& fn,
  149. std::index_sequence<Is...>,
  150. std::tuple<type_caster<Args>...>& casters,
  151. return_value_policy policy,
  152. handle parent) {
  153. using underlying_type = std::decay_t<Fn>;
  154. using return_type = callable_return_t<underlying_type>;
  155. constexpr bool is_void = std::is_void_v<return_type>;
  156. constexpr bool is_member_function_pointer = std::is_member_function_pointer_v<underlying_type>;
  157. if constexpr(is_member_function_pointer) {
  158. // helper function to unpack the arguments to call the member pointer
  159. auto unpack = [&](class_type_t<underlying_type>& self, auto&... args) {
  160. return (self.*fn)(args...);
  161. };
  162. if constexpr(!is_void) {
  163. return pybind11::cast(unpack(std::get<Is>(casters).value...), policy, parent);
  164. } else {
  165. unpack(std::get<Is>(casters).value...);
  166. return vm->None;
  167. }
  168. } else {
  169. if constexpr(!is_void) {
  170. return pybind11::cast(fn(std::get<Is>(casters).value...), policy, parent);
  171. } else {
  172. fn(std::get<Is>(casters).value...);
  173. return vm->None;
  174. }
  175. }
  176. }
  /// Compile-time summary of a callable's parameter list.
  struct arguments_info_t {
      int argc = 0;         // total number of parameters
      int args_pos = -1;    // index of the `args` parameter, -1 when there is none
      int kwargs_pos = -1;  // index of the `kwargs` parameter, -1 when there is none
  };
  /// Compile-time summary of the extra options passed at bind time.
  struct extras_info_t {
      int doc_pos = -1;     // index of the `const char*` extra, -1 when there is none
      int named_argc = 0;   // number of `arg` extras
      int policy_pos = -1;  // index of the `return_value_policy` extra, -1 when there is none
  };
  187. template <typename Callable, typename... Extras, typename... Args, std::size_t... Is>
  188. struct template_parser<Callable, std::tuple<Extras...>, std::tuple<Args...>, std::index_sequence<Is...>> {
  189. constexpr static arguments_info_t parse_arguments() {
  190. constexpr auto args_count = types_count_v<args, Args...>;
  191. constexpr auto kwargs_count = types_count_v<kwargs, Args...>;
  192. static_assert(args_count <= 1, "py::args can occur at most once");
  193. static_assert(kwargs_count <= 1, "py::kwargs can occur at most once");
  194. constexpr auto args_pos = type_index_v<args, Args...>;
  195. constexpr auto kwargs_pos = type_index_v<kwargs, Args...>;
  196. if constexpr(kwargs_count == 1) {
  197. static_assert(kwargs_pos == sizeof...(Args) - 1, "py::kwargs must be the last argument");
  198. // FIXME: temporarily, args and kwargs must be at the end of the arguments list
  199. if constexpr(args_count == 1) {
  200. static_assert(args_pos == kwargs_pos - 1, "py::args must be before py::kwargs");
  201. }
  202. }
  203. return {sizeof...(Args), args_pos, kwargs_pos};
  204. }
  205. constexpr static extras_info_t parse_extras() {
  206. constexpr auto doc_count = types_count_v<const char*, Extras...>;
  207. constexpr auto policy_count = types_count_v<return_value_policy, Extras...>;
  208. static_assert(doc_count <= 1, "doc can occur at most once");
  209. static_assert(policy_count <= 1, "return_value_policy can occur at most once");
  210. constexpr auto doc_pos = type_index_v<const char*, Extras...>;
  211. constexpr auto policy_pos = type_index_v<return_value_policy, Extras...>;
  212. constexpr auto named_argc = types_count_v<arg, Extras...>;
  213. constexpr auto normal_argc =
  214. sizeof...(Args) - (arguments_info.args_pos != -1) - (arguments_info.kwargs_pos != -1);
  215. static_assert(named_argc == 0 || named_argc == normal_argc,
  216. "named arguments must be the same as the number of function arguments");
  217. return {doc_pos, named_argc, policy_pos};
  218. }
  219. constexpr inline static auto arguments_info = parse_arguments();
  220. constexpr inline static auto extras_info = parse_extras();
  221. static void initialize(function_record& record, const Extras&... extras) {
  222. auto extras_tuple = std::make_tuple(extras...);
  223. constexpr static bool has_named_args = (extras_info.named_argc > 0);
  224. // set return value policy
  225. if constexpr(extras_info.policy_pos != -1) { record.policy = std::get<extras_info.policy_pos>(extras_tuple); }
  226. // TODO: set others
  227. // set default arguments
  228. if constexpr(has_named_args) {
  229. record.arguments = new function_record::arguments_t();
  230. auto add_arguments = [&](const auto& arg) {
  231. if constexpr(std::is_same_v<pybind11::arg, remove_cvref_t<decltype(arg)>>) {
  232. auto& arguments = *record.arguments;
  233. arguments.names.emplace_back(arg.name);
  234. arguments.defaults.emplace_back(arg.default_);
  235. }
  236. };
  237. (add_arguments(extras), ...);
  238. }
  239. // set signature
  240. {
  241. std::string sig = "(";
  242. std::size_t index = 0;
  243. auto append = [&](auto _t) {
  244. using T = pybind11_decay_t<typename decltype(_t)::type>;
  245. if constexpr(std::is_same_v<T, args>) {
  246. sig += "*args";
  247. } else if constexpr(std::is_same_v<T, kwargs>) {
  248. sig += "**kwargs";
  249. } else if constexpr(has_named_args) {
  250. sig += record.arguments->names[index].c_str();
  251. sig += ": ";
  252. sig += type_info::of<T>().name;
  253. if(record.arguments->defaults[index]) {
  254. sig += " = ";
  255. sig += record.arguments->defaults[index].repr();
  256. }
  257. } else {
  258. sig += "_: ";
  259. sig += type_info::of<T>().name;
  260. }
  261. if(index + 1 < arguments_info.argc) { sig += ", "; }
  262. index++;
  263. };
  264. (append(type_identity<Args>{}), ...);
  265. sig += ")";
  266. char* buffer = new char[sig.size() + 1];
  267. std::memcpy(buffer, sig.data(), sig.size());
  268. buffer[sig.size()] = '\0';
  269. record.signature = buffer;
  270. }
  271. }
  272. static handle wrapper(function_record& record, pkpy::ArgsView view, bool convert, handle parent) {
  273. constexpr auto argc = arguments_info.argc;
  274. constexpr auto named_argc = extras_info.named_argc;
  275. constexpr auto args_pos = arguments_info.args_pos;
  276. constexpr auto kwargs_pos = arguments_info.kwargs_pos;
  277. constexpr auto normal_argc = argc - (args_pos != -1) - (kwargs_pos != -1);
  278. // avoid gc call in bound function
  279. vm->heap.gc_scope_lock();
  280. // add 1 to avoid zero-size array when argc is 0
  281. handle stack[argc + 1] = {};
  282. // ensure the number of passed arguments is no greater than the number of parameters
  283. if(args_pos == -1 && view.size() > normal_argc) { return handle(); }
  284. // if have default arguments, load them
  285. if constexpr(named_argc > 0) {
  286. auto& defaults = record.arguments->defaults;
  287. std::memcpy(stack, defaults.data(), defaults.size() * sizeof(handle));
  288. }
  289. // load arguments from call arguments
  290. const auto size = std::min(view.size(), normal_argc);
  291. std::memcpy(stack, view.begin(), size * sizeof(handle));
  292. // pack the args
  293. if constexpr(args_pos != -1) {
  294. const auto n = std::max(view.size() - normal_argc, 0);
  295. tuple args = tuple(n);
  296. for(std::size_t i = 0; i < n; ++i) {
  297. args[i] = view[normal_argc + i];
  298. }
  299. stack[args_pos] = args;
  300. }
  301. // resolve keyword arguments
  302. const auto n = vm->s_data._sp - view.end();
  303. std::size_t index = 0;
  304. if constexpr(named_argc > 0) {
  305. std::size_t arg_index = 0;
  306. auto& arguments = *record.arguments;
  307. while(arg_index < named_argc && index < n) {
  308. const auto key = pkpy::_py_cast<pkpy::i64>(vm, view.end()[index]);
  309. const auto value = view.end()[index + 1];
  310. const auto name = pkpy::StrName(key);
  311. auto& arg_name = record.arguments->names[arg_index];
  312. if(name == arg_name) {
  313. stack[arg_index] = value;
  314. index += 2;
  315. }
  316. arg_index += 1;
  317. }
  318. }
  319. // pack the kwargs
  320. if constexpr(kwargs_pos != -1) {
  321. dict kwargs;
  322. while(index < n) {
  323. const auto key = pkpy::_py_cast<pkpy::i64>(vm, view.end()[index]);
  324. const str name = str(pkpy::StrName(key).sv());
  325. kwargs[name] = view.end()[index + 1];
  326. index += 2;
  327. }
  328. stack[kwargs_pos] = kwargs;
  329. }
  330. // if have rest keyword arguments, call fails
  331. if(index != n) { return handle(); }
  332. // check if all the arguments are valid
  333. for(std::size_t i = 0; i < argc; ++i) {
  334. if(!stack[i]) { return handle(); }
  335. }
  336. // ok, all the arguments are valid, call the function
  337. std::tuple<type_caster<Args>...> casters;
  338. // check type compatibility
  339. if(((std::get<Is>(casters).load(stack[Is], convert)) && ...)) {
  340. return invoke(record._as<Callable>(), std::index_sequence<Is...>{}, casters, record.policy, parent);
  341. }
  342. return handle();
  343. }
  344. };
  345. inline auto _wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  346. auto&& record = unpack<function_record>(view);
  347. return record(view).ptr();
  348. }
  349. template <typename Fn, typename... Extras>
  350. handle bind_function(const handle& obj, const char* name, Fn&& fn, pkpy::BindType type, const Extras&... extras) {
  351. // do not use cpp_function directly to avoid unnecessary reference count change
  352. pkpy::PyVar var = obj.ptr();
  353. cpp_function callable = var->attr().try_get(name);
  354. // if the function is not bound yet, bind it
  355. if(!callable) {
  356. auto record = function_record(std::forward<Fn>(fn), extras...);
  357. void* data = interpreter::take_ownership(std::move(record));
  358. callable = interpreter::bind_func(var, name, -1, _wrapper, data);
  359. } else {
  360. function_record* record = new function_record(std::forward<Fn>(fn), extras...);
  361. function_record* last = callable.get_userdata_as<function_record*>();
  362. if constexpr((types_count_v<prepend, Extras...> != 0)) {
  363. // if prepend is specified, append the new record to the beginning of the list
  364. fn.set_userdata(record);
  365. record->append(last);
  366. } else {
  367. // otherwise, append the new record to the end of the list
  368. last->append(record);
  369. }
  370. }
  371. return callable;
  372. }
  373. } // namespace pybind11::impl
  374. namespace pybind11::impl {
  375. template <typename Getter>
  376. pkpy::PyVar getter_wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  377. handle result = vm->None;
  378. auto&& getter = unpack<Getter>(view);
  379. constexpr auto policy = return_value_policy::reference_internal;
  380. if constexpr(std::is_member_pointer_v<Getter>) {
  381. using Self = class_type_t<Getter>;
  382. auto& self = handle(view[0])._as<instance>()._as<Self>();
  383. if constexpr(std::is_member_object_pointer_v<Getter>) {
  384. // specialize for pointer to data member
  385. result = cast(self.*getter, policy, view[0]);
  386. } else {
  387. // specialize for pointer to member function
  388. result = cast((self.*getter)(), policy, view[0]);
  389. }
  390. } else {
  391. // specialize for function pointer and lambda
  392. using Self = remove_cvref_t<std::tuple_element_t<0, callable_args_t<Getter>>>;
  393. auto& self = handle(view[0])._as<instance>()._as<Self>();
  394. result = cast(getter(self), policy, view[0]);
  395. }
  396. return result.ptr();
  397. }
  398. template <typename Setter>
  399. pkpy::PyVar setter_wrapper(pkpy::VM* vm, pkpy::ArgsView view) {
  400. auto&& setter = unpack<Setter>(view);
  401. if constexpr(std::is_member_pointer_v<Setter>) {
  402. using Self = class_type_t<Setter>;
  403. auto& self = handle(view[0])._as<instance>()._as<Self>();
  404. if constexpr(std::is_member_object_pointer_v<Setter>) {
  405. // specialize for pointer to data member
  406. type_caster<member_type_t<Setter>> caster;
  407. if(caster.load(view[1], true)) {
  408. self.*setter = caster.value;
  409. return vm->None;
  410. }
  411. } else {
  412. // specialize for pointer to member function
  413. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  414. if(caster.load(view[1], true)) {
  415. (self.*setter)(caster.value);
  416. return vm->None;
  417. }
  418. }
  419. } else {
  420. // specialize for function pointer and lambda
  421. using Self = remove_cvref_t<std::tuple_element_t<0, callable_args_t<Setter>>>;
  422. auto& self = handle(view[0])._as<instance>()._as<Self>();
  423. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  424. if(caster.load(view[1], true)) {
  425. setter(self, caster.value);
  426. return vm->None;
  427. }
  428. }
  429. vm->TypeError("Unexpected argument type");
  430. PK_UNREACHABLE();
  431. }
  432. template <typename Getter, typename Setter, typename... Extras>
  433. handle bind_property(const handle& obj, const char* name, Getter&& getter_, Setter&& setter_, const Extras&... extras) {
  434. handle getter = none();
  435. handle setter = none();
  436. using Wrapper = pkpy::PyVar (*)(pkpy::VM*, pkpy::ArgsView);
  437. constexpr auto create = [](Wrapper wrapper, int argc, auto&& f) {
  438. if constexpr(need_host<remove_cvref_t<decltype(f)>>) {
  439. // otherwise, store it in the type_info
  440. void* data = interpreter::take_ownership(std::forward<decltype(f)>(f));
  441. // store the index in the object
  442. return vm->heap.gcnew<pkpy::NativeFunc>(vm->tp_native_func, wrapper, argc, data);
  443. } else {
  444. // if the function is trivially copyable and the size is less than 16 bytes, store it in the object
  445. // directly
  446. return vm->heap.gcnew<pkpy::NativeFunc>(vm->tp_native_func, wrapper, argc, f);
  447. }
  448. };
  449. getter = create(impl::getter_wrapper<std::decay_t<Getter>>, 1, std::forward<Getter>(getter_));
  450. if constexpr(!std::is_same_v<Setter, std::nullptr_t>) {
  451. setter = create(impl::setter_wrapper<std::decay_t<Setter>>, 2, std::forward<Setter>(setter_));
  452. }
  453. handle property = pybind11::property(getter, setter);
  454. setattr(obj, name, property);
  455. return property;
  456. }
  457. } // namespace pybind11::impl