lua_bridge.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #pragma once
  2. #include "pocketpy.h"
  3. extern "C"{
  4. #include "lua.h"
  5. #include "lauxlib.h"
  6. }
  7. namespace pkpy{
  8. /******************************************************************/
  9. void initialize_lua_bridge(VM* vm, lua_State* newL);
  10. /******************************************************************/
  11. static lua_State* _L;
  12. static void lua_push_from_python(VM*, PyObject*);
  13. static PyObject* lua_popx_to_python(VM*);
  14. template<typename T>
  15. static void table_apply(VM* vm, T f){
  16. PK_ASSERT(lua_istable(_L, -1));
  17. lua_pushnil(_L); // [key]
  18. while(lua_next(_L, -2) != 0){ // [key, val]
  19. lua_pushvalue(_L, -2); // [key, val, key]
  20. PyObject* key = lua_popx_to_python(vm);
  21. PyObject* val = lua_popx_to_python(vm);
  22. f(key, val); // [key]
  23. }
  24. lua_pop(_L, 1); // []
  25. }
  26. struct LuaExceptionGuard{
  27. int base_size;
  28. LuaExceptionGuard(){ base_size = lua_gettop(_L); }
  29. ~LuaExceptionGuard(){
  30. int delta = lua_gettop(_L) - base_size;
  31. if(delta > 0) lua_pop(_L, delta);
  32. }
  33. };
  34. #define LUA_PROTECTED(__B) { LuaExceptionGuard __guard; __B; }
  35. struct PyLuaObject{
  36. PK_ALWAYS_PASS_BY_POINTER(PyLuaObject)
  37. int r;
  38. PyLuaObject(){ r = luaL_ref(_L, LUA_REGISTRYINDEX); }
  39. ~PyLuaObject(){ luaL_unref(_L, LUA_REGISTRYINDEX, r); }
  40. };
  41. struct PyLuaTable: PyLuaObject{
  42. PY_CLASS(PyLuaTable, lua, Table)
  43. static void _register(VM* vm, PyObject* mod, PyObject* type){
  44. Type t = PK_OBJ_GET(Type, type);
  45. PyTypeInfo* ti = &vm->_all_types[t];
  46. ti->m__getattr__ = [](VM* vm, PyObject* obj, StrName name){
  47. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  48. LUA_PROTECTED(
  49. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  50. lua_pushstring(_L, std::string(name.sv()).c_str());
  51. lua_gettable(_L, -2);
  52. PyObject* ret = lua_popx_to_python(vm);
  53. lua_pop(_L, 1);
  54. return ret;
  55. )
  56. };
  57. ti->m__setattr__ = [](VM* vm, PyObject* obj, StrName name, PyObject* val){
  58. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  59. LUA_PROTECTED(
  60. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  61. lua_pushstring(_L, std::string(name.sv()).c_str());
  62. lua_push_from_python(vm, val);
  63. lua_settable(_L, -3);
  64. lua_pop(_L, 1);
  65. )
  66. };
  67. vm->bind_constructor<1>(type, [](VM* vm, ArgsView args){
  68. lua_newtable(_L); // push an empty table onto the stack
  69. PyObject* obj = vm->heap.gcnew<PyLuaTable>(PyLuaTable::_type(vm));
  70. return obj;
  71. });
  72. vm->bind__len__(t, [](VM* vm, PyObject* obj){
  73. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  74. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  75. i64 len = 0;
  76. lua_pushnil(_L);
  77. while(lua_next(_L, -2) != 0){ len += 1; lua_pop(_L, 1); }
  78. lua_pop(_L, 1);
  79. return len;
  80. });
  81. vm->bind__getitem__(t, [](VM* vm, PyObject* obj, PyObject* key){
  82. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  83. LUA_PROTECTED(
  84. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  85. lua_push_from_python(vm, key);
  86. lua_gettable(_L, -2);
  87. PyObject* ret = lua_popx_to_python(vm);
  88. lua_pop(_L, 1);
  89. return ret;
  90. )
  91. });
  92. vm->bind__setitem__(t, [](VM* vm, PyObject* obj, PyObject* key, PyObject* val){
  93. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  94. LUA_PROTECTED(
  95. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  96. lua_push_from_python(vm, key);
  97. lua_push_from_python(vm, val);
  98. lua_settable(_L, -3);
  99. lua_pop(_L, 1);
  100. )
  101. });
  102. vm->bind__delitem__(t, [](VM* vm, PyObject* obj, PyObject* key){
  103. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  104. LUA_PROTECTED(
  105. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  106. lua_push_from_python(vm, key);
  107. lua_pushnil(_L);
  108. lua_settable(_L, -3);
  109. lua_pop(_L, 1);
  110. )
  111. });
  112. vm->bind__contains__(t, [](VM* vm, PyObject* obj, PyObject* key){
  113. const PyLuaTable& self = _CAST(PyLuaTable&, obj);
  114. LUA_PROTECTED(
  115. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  116. lua_push_from_python(vm, key);
  117. lua_gettable(_L, -2);
  118. bool ret = lua_isnil(_L, -1) == 0;
  119. lua_pop(_L, 2);
  120. return ret ? vm->True : vm->False;
  121. )
  122. });
  123. vm->bind(type, "keys(self) -> list", [](VM* vm, ArgsView args){
  124. const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
  125. LUA_PROTECTED(
  126. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  127. List ret;
  128. table_apply(vm, [&](PyObject* key, PyObject* val){ ret.push_back(key); });
  129. lua_pop(_L, 1);
  130. return VAR(std::move(ret));
  131. )
  132. });
  133. vm->bind(type, "values(self) -> list", [](VM* vm, ArgsView args){
  134. const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
  135. LUA_PROTECTED(
  136. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  137. List ret;
  138. table_apply(vm, [&](PyObject* key, PyObject* val){ ret.push_back(val); });
  139. lua_pop(_L, 1);
  140. return VAR(std::move(ret));
  141. )
  142. });
  143. vm->bind(type, "items(self) -> list[tuple]", [](VM* vm, ArgsView args){
  144. const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
  145. LUA_PROTECTED(
  146. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  147. List ret;
  148. table_apply(vm, [&](PyObject* key, PyObject* val){
  149. PyObject* item = VAR(Tuple({key, val}));
  150. ret.push_back(item);
  151. });
  152. lua_pop(_L, 1);
  153. return VAR(std::move(ret));
  154. )
  155. });
  156. }
  157. };
  158. static PyObject* lua_popx_multi_to_python(VM* vm, int count){
  159. if(count == 0){
  160. return vm->None;
  161. }else if(count == 1){
  162. return lua_popx_to_python(vm);
  163. }else if(count > 1){
  164. Tuple ret(count);
  165. for(int i=0; i<count; i++){
  166. ret[i] = lua_popx_to_python(vm);
  167. }
  168. return VAR(std::move(ret));
  169. }
  170. PK_FATAL_ERROR()
  171. }
  172. struct PyLuaFunction: PyLuaObject{
  173. PY_CLASS(PyLuaFunction, lua, Function)
  174. static void _register(VM* vm, PyObject* mod, PyObject* type){
  175. vm->bind_notimplemented_constructor<PyLuaFunction>(type);
  176. vm->bind_method<-1>(type, "__call__", [](VM* vm, ArgsView args){
  177. if(args.size() < 1) vm->TypeError("__call__ takes at least 1 argument");
  178. const PyLuaFunction& self = _CAST(PyLuaFunction&, args[0]);
  179. int base_size = lua_gettop(_L);
  180. LUA_PROTECTED(
  181. lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
  182. for(int i=1; i<args.size(); i++){
  183. lua_push_from_python(vm, args[i]);
  184. }
  185. if(lua_pcall(_L, args.size()-1, LUA_MULTRET, 0)){
  186. const char* error = lua_tostring(_L, -1);
  187. lua_pop(_L, 1);
  188. vm->RuntimeError(error);
  189. }
  190. return lua_popx_multi_to_python(vm, lua_gettop(_L) - base_size);
  191. )
  192. });
  193. }
  194. };
  195. void lua_push_from_python(VM* vm, PyObject* val){
  196. if(val == vm->None){
  197. lua_pushnil(_L);
  198. return;
  199. }
  200. Type t = vm->_tp(val);
  201. switch(t.index){
  202. case VM::tp_bool.index:
  203. lua_pushboolean(_L, val == vm->True);
  204. return;
  205. case VM::tp_int.index:
  206. lua_pushinteger(_L, _CAST(i64, val));
  207. return;
  208. case VM::tp_float.index:
  209. lua_pushnumber(_L, _CAST(f64, val));
  210. return;
  211. case VM::tp_str.index:
  212. lua_pushstring(_L, _CAST(CString, val));
  213. return;
  214. }
  215. if(is_non_tagged_type(val, PyLuaTable::_type(vm))){
  216. const PyLuaTable& table = _CAST(PyLuaTable&, val);
  217. lua_rawgeti(_L, LUA_REGISTRYINDEX, table.r);
  218. return;
  219. }
  220. if(is_non_tagged_type(val, PyLuaFunction::_type(vm))){
  221. const PyLuaFunction& func = _CAST(PyLuaFunction&, val);
  222. lua_rawgeti(_L, LUA_REGISTRYINDEX, func.r);
  223. return;
  224. }
  225. vm->RuntimeError(fmt("unsupported python type: ", obj_type_name(vm, t).escape()));
  226. }
  227. PyObject* lua_popx_to_python(VM* vm) {
  228. int type = lua_type(_L, -1);
  229. switch (type) {
  230. case LUA_TNIL: {
  231. lua_pop(_L, 1);
  232. return vm->None;
  233. }
  234. case LUA_TBOOLEAN: {
  235. bool val = lua_toboolean(_L, -1);
  236. lua_pop(_L, 1);
  237. return val ? vm->True : vm->False;
  238. }
  239. case LUA_TNUMBER: {
  240. double val = lua_tonumber(_L, -1);
  241. lua_pop(_L, 1);
  242. return VAR(val);
  243. }
  244. case LUA_TSTRING: {
  245. const char* val = lua_tostring(_L, -1);
  246. lua_pop(_L, 1);
  247. return VAR(val);
  248. }
  249. case LUA_TTABLE: {
  250. PyObject* obj = vm->heap.gcnew<PyLuaTable>(PyLuaTable::_type(vm));
  251. return obj;
  252. }
  253. case LUA_TFUNCTION: {
  254. PyObject* obj = vm->heap.gcnew<PyLuaFunction>(PyLuaFunction::_type(vm));
  255. return obj;
  256. }
  257. default: {
  258. const char* type_name = lua_typename(_L, type);
  259. lua_pop(_L, 1);
  260. vm->RuntimeError(fmt("unsupported lua type: '", type_name, "'"));
  261. }
  262. }
  263. PK_UNREACHABLE()
  264. }
  265. void initialize_lua_bridge(VM* vm, lua_State* newL){
  266. PyObject* mod = vm->new_module("lua");
  267. if(_L != nullptr){
  268. throw std::runtime_error("lua bridge already initialized");
  269. }
  270. _L = newL;
  271. PyLuaTable::register_class(vm, mod);
  272. PyLuaFunction::register_class(vm, mod);
  273. vm->bind(mod, "dostring(__source: str)", [](VM* vm, ArgsView args){
  274. const char* source = CAST(CString, args[0]);
  275. int base_size = lua_gettop(_L);
  276. if (luaL_dostring(_L, source)) {
  277. const char* error = lua_tostring(_L, -1);
  278. lua_pop(_L, 1); // pop error message from the stack
  279. vm->RuntimeError(error);
  280. }
  281. return lua_popx_multi_to_python(vm, lua_gettop(_L) - base_size);
  282. });
  283. }
  284. } // namespace pkpy