1
0

cpp_function.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
#pragma once
#include "cast.h"

#include <bitset>
#include <cstring>
#include <new>
  4. namespace pybind11 {
  5. template <std::size_t Nurse, std::size_t... Patients>
  6. struct keep_alive {};
  7. template <typename T>
  8. struct call_guard {
  9. static_assert(std::is_default_constructible_v<T>,
  10. "call_guard must be default constructible");
  11. };
  12. // append the overload to the beginning of the overload list
  13. struct prepend {};
  14. template <typename... Args>
  15. struct init {};
// TODO: support more customized tags, for example:
// struct kw_only {};
//
// struct pos_only {};
//
// struct default_arg {};
//
// struct arg {
//     const char* name;
//     const char* description;
// };
//
// Alternative sketch of default_arg that also carries the default value:
// struct default_arg {
//     const char* name;
//     const char* description;
//     const char* value;
// };
  33. template <typename Fn,
  34. typename Extra,
  35. typename Args = callable_args_t<std::decay_t<Fn>>,
  36. typename IndexSequence = std::make_index_sequence<std::tuple_size_v<Args>>>
  37. struct generator;
  38. class function_record {
  39. union {
  40. void* data;
  41. char buffer[16];
  42. };
  43. // TODO: optimize the function_record size to reduce memory usage
  44. const char* name;
  45. function_record* next;
  46. void (*destructor)(function_record*);
  47. return_value_policy policy = return_value_policy::automatic;
  48. handle (*wrapper)(function_record&, pkpy::ArgsView, bool convert, handle parent);
  49. template <typename Fn, typename Extra, typename Args, typename IndexSequence>
  50. friend struct generator;
  51. public:
  52. template <typename Fn, typename... Extras>
  53. function_record(Fn&& f, const char* name, const Extras&... extras) :
  54. name(name), next(nullptr) {
  55. if constexpr(sizeof(f) <= sizeof(buffer)) {
  56. new (buffer) auto(std::forward<Fn>(f));
  57. destructor = [](function_record* self) {
  58. reinterpret_cast<Fn*>(self->buffer)->~Fn();
  59. };
  60. } else {
  61. data = new auto(std::forward<Fn>(f));
  62. destructor = [](function_record* self) { delete static_cast<Fn*>(self->data); };
  63. }
  64. using Generator = generator<std::decay_t<Fn>, std::tuple<Extras...>>;
  65. Generator::initialize(*this, extras...);
  66. wrapper = Generator::generate();
  67. }
  68. ~function_record() { destructor(this); }
  69. template <typename Fn>
  70. auto& cast() {
  71. if constexpr(sizeof(Fn) <= sizeof(buffer)) {
  72. return *reinterpret_cast<Fn*>(buffer);
  73. } else {
  74. return *static_cast<Fn*>(data);
  75. }
  76. }
  77. void append(function_record* record) {
  78. function_record* p = this;
  79. while(p->next != nullptr) {
  80. p = p->next;
  81. }
  82. p->next = record;
  83. }
  84. handle operator() (pkpy::ArgsView view) {
  85. function_record* p = this;
  86. // foreach function record and call the function with not convert
  87. while(p != nullptr) {
  88. handle result = p->wrapper(*this, view, false, {});
  89. if(result) {
  90. return result;
  91. }
  92. p = p->next;
  93. }
  94. p = this;
  95. // foreach function record and call the function with convert
  96. while(p != nullptr) {
  97. handle result = p->wrapper(*this, view, true, {});
  98. if(result) {
  99. return result;
  100. }
  101. p = p->next;
  102. }
  103. vm->TypeError("no matching function found");
  104. }
  105. };
  106. template <typename Fn, std::size_t... Is, typename... Args>
  107. handle invoke(Fn&& fn,
  108. std::index_sequence<Is...>,
  109. std::tuple<type_caster<Args>...>& casters,
  110. return_value_policy policy,
  111. handle parent) {
  112. using underlying_type = std::decay_t<Fn>;
  113. using ret = callable_return_t<underlying_type>;
  114. // if the return type is void, return None
  115. if constexpr(std::is_void_v<ret>) {
  116. // resolve the member function pointer
  117. if constexpr(std::is_member_function_pointer_v<underlying_type>) {
  118. [&](class_type_t<underlying_type>& self, auto&... args) {
  119. (self.*fn)(args...);
  120. }(std::get<Is>(casters).value...);
  121. } else {
  122. fn(std::get<Is>(casters).value...);
  123. }
  124. return vm->None;
  125. } else {
  126. // resolve the member function pointer
  127. if constexpr(std::is_member_function_pointer_v<remove_cvref_t<Fn>>) {
  128. return type_caster<ret>::cast(
  129. [&](class_type_t<underlying_type>& self, auto&... args) {
  130. return (self.*fn)(args...);
  131. }(std::get<Is>(casters).value...),
  132. policy,
  133. parent);
  134. } else {
  135. return type_caster<ret>::cast(fn(std::get<Is>(casters).value...), policy, parent);
  136. }
  137. }
  138. }
  139. template <typename Fn, typename... Args, std::size_t... Is, typename... Extras>
  140. struct generator<Fn, std::tuple<Extras...>, std::tuple<Args...>, std::index_sequence<Is...>> {
  141. static void initialize(function_record& record, const Extras&... extras) {}
  142. static auto generate() {
  143. return +[](function_record& self, pkpy::ArgsView view, bool convert, handle parent) {
  144. // FIXME:
  145. // Temporarily, args and kwargs must be at the end of the arguments list
  146. // Named arguments are not supported yet
  147. constexpr bool has_args = types_count_v<args, remove_cvref_t<Args>...> != 0;
  148. constexpr bool has_kwargs = types_count_v<kwargs, remove_cvref_t<Args>...> != 0;
  149. constexpr std::size_t count = sizeof...(Args) - has_args - has_kwargs;
  150. handle stack[sizeof...(Args)] = {};
  151. // initialize the stack
  152. if(!has_args && (view.size() != count)) {
  153. return handle();
  154. }
  155. if(has_args && (view.size() < count)) {
  156. return handle();
  157. }
  158. for(std::size_t i = 0; i < count; ++i) {
  159. stack[i] = view[i];
  160. }
  161. // pack the args and kwargs
  162. if constexpr(has_args) {
  163. const auto n = view.size() - count;
  164. pkpy::PyVar var = vm->new_object<pkpy::Tuple>(vm->tp_tuple, n);
  165. auto& tuple = var.obj_get<pkpy::Tuple>();
  166. for(std::size_t i = 0; i < n; ++i) {
  167. tuple[i] = view[count + i];
  168. }
  169. stack[count] = var;
  170. }
  171. if constexpr(has_kwargs) {
  172. const auto n = vm->s_data._sp - view.end();
  173. pkpy::PyVar var = vm->new_object<pkpy::Dict>(vm->tp_dict);
  174. auto& dict = var.obj_get<pkpy::Dict>();
  175. for(std::size_t i = 0; i < n; i += 2) {
  176. pkpy::i64 index = pkpy::_py_cast<pkpy::i64>(vm, view[count + i]);
  177. pkpy::PyVar str =
  178. vm->new_object<pkpy::Str>(vm->tp_str, pkpy::StrName(index).sv());
  179. dict.set(vm, str, view[count + i + 1]);
  180. }
  181. stack[count + 1] = var;
  182. }
  183. // check if all the arguments are not valid
  184. for(std::size_t i = 0; i < sizeof...(Args); ++i) {
  185. if(!stack[i]) {
  186. return handle();
  187. }
  188. }
  189. // ok, all the arguments are valid, call the function
  190. std::tuple<type_caster<Args>...> casters;
  191. // check type compatibility
  192. if(((std::get<Is>(casters).load(stack[Is], convert)) && ...)) {
  193. return invoke(self.cast<Fn>(),
  194. std::index_sequence<Is...>{},
  195. casters,
  196. self.policy,
  197. parent);
  198. }
  199. return handle();
  200. };
  201. }
  202. };
  203. constexpr inline static auto _wrapper = +[](pkpy::VM*, pkpy::ArgsView view) {
  204. auto& record = pkpy::lambda_get_userdata<function_record>(view.begin());
  205. return record(view).ptr();
  206. };
  207. class cpp_function : public function {
  208. public:
  209. template <typename Fn, typename... Extras>
  210. cpp_function(Fn&& f, const Extras&... extras) {
  211. pkpy::any userdata = function_record(std::forward<Fn>(f), "anonymous", extras...);
  212. m_ptr = vm->bind_func(nullptr, "", -1, _wrapper, std::move(userdata));
  213. inc_ref();
  214. }
  215. };
  216. template <typename Fn, typename... Extras>
  217. handle bind_function(const handle& obj,
  218. const char* name,
  219. Fn&& fn,
  220. pkpy::BindType type,
  221. const Extras&... extras) {
  222. // do not use cpp_function directly to avoid unnecessary reference count change
  223. pkpy::PyVar var = obj.ptr();
  224. pkpy::PyVar callable = var->attr().try_get(name);
  225. // if the function is not bound yet, bind it
  226. if(!callable) {
  227. pkpy::any userdata = function_record(std::forward<Fn>(fn), name, extras...);
  228. callable = vm->bind_func(var, name, -1, _wrapper, std::move(userdata));
  229. } else {
  230. auto& userdata = callable.obj_get<pkpy::NativeFunc>()._userdata;
  231. function_record* record = new function_record(std::forward<Fn>(fn), name, extras...);
  232. constexpr bool is_prepend = (types_count_v<prepend, Extras...> != 0);
  233. if constexpr(is_prepend) {
  234. // if prepend is specified, append the new record to the beginning of the list
  235. function_record* last = (function_record*)userdata.data;
  236. userdata.data = record;
  237. record->append(last);
  238. } else {
  239. // otherwise, append the new record to the end of the list
  240. function_record* last = (function_record*)userdata.data;
  241. last->append(record);
  242. }
  243. }
  244. return callable;
  245. }
  246. template <typename Getter_, typename Setter_, typename... Extras>
  247. handle bind_property(const handle& obj,
  248. const char* name,
  249. Getter_&& getter_,
  250. Setter_&& setter_,
  251. const Extras&... extras) {
  252. pkpy::PyVar var = obj.ptr();
  253. pkpy::PyVar getter = vm->None;
  254. pkpy::PyVar setter = vm->None;
  255. using Getter = std::decay_t<Getter_>;
  256. using Setter = std::decay_t<Setter_>;
  257. getter = vm->new_object<pkpy::NativeFunc>(
  258. vm->tp_native_func,
  259. [](pkpy::VM* vm, pkpy::ArgsView view) -> pkpy::PyVar {
  260. auto& getter = pkpy::lambda_get_userdata<Getter>(view.begin());
  261. if constexpr(std::is_member_pointer_v<Getter>) {
  262. using Self = class_type_t<Getter>;
  263. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  264. if constexpr(std::is_member_object_pointer_v<Getter>) {
  265. return type_caster<member_type_t<Getter>>::cast(
  266. self.*getter,
  267. return_value_policy::reference_internal,
  268. view[0])
  269. .ptr();
  270. } else {
  271. return type_caster<callable_return_t<Getter>>::cast(
  272. (self.*getter)(),
  273. return_value_policy::reference_internal,
  274. view[0])
  275. .ptr();
  276. }
  277. } else {
  278. using Self = std::tuple_element_t<0, callable_args_t<Getter>>;
  279. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  280. return type_caster<callable_return_t<Getter>>::cast(
  281. getter(self),
  282. return_value_policy::reference_internal,
  283. view[0])
  284. .ptr();
  285. }
  286. },
  287. 1,
  288. std::forward<Getter_>(getter_));
  289. if constexpr(!std::is_same_v<Setter, std::nullptr_t>) {
  290. setter = vm->new_object<pkpy::NativeFunc>(
  291. vm->tp_native_func,
  292. [](pkpy::VM* vm, pkpy::ArgsView view) -> pkpy::PyVar {
  293. auto& setter = pkpy::lambda_get_userdata<Setter>(view.begin());
  294. if constexpr(std::is_member_pointer_v<Setter>) {
  295. using Self = class_type_t<Setter>;
  296. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  297. if constexpr(std::is_member_object_pointer_v<Setter>) {
  298. type_caster<member_type_t<Setter>> caster;
  299. if(caster.load(view[1], true)) {
  300. self.*setter = caster.value;
  301. return vm->None;
  302. }
  303. } else {
  304. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  305. if(caster.load(view[1], true)) {
  306. (self.*setter)(caster.value);
  307. return vm->None;
  308. }
  309. }
  310. } else {
  311. using Self = std::tuple_element_t<0, callable_args_t<Setter>>;
  312. auto& self = _builtin_cast<instance>(view[0]).cast<Self>();
  313. type_caster<std::tuple_element_t<1, callable_args_t<Setter>>> caster;
  314. if(caster.load(view[1], true)) {
  315. setter(self, caster.value);
  316. return vm->None;
  317. }
  318. }
  319. vm->TypeError("invalid argument");
  320. },
  321. 2,
  322. std::forward<Setter_>(setter_));
  323. }
  324. pkpy::PyVar property = vm->new_object<pkpy::Property>(vm->tp_property, getter, setter);
  325. var->attr().set(name, property);
  326. return property;
  327. }
  328. } // namespace pybind11