blueloveTH 2 years ago
parent
commit
cd1280d350
7 changed files with 369 additions and 37 deletions
  1. 320 0
      3rd/lua_bridge.hpp
  2. 0 1
      CMakeLists.txt
  3. 1 1
      compile_flags.txt
  4. 5 2
      include/pocketpy/vm.h
  5. 1 1
      src/ceval.cpp
  6. 6 9
      src/pocketpy.cpp
  7. 36 23
      src/vm.cpp

+ 320 - 0
3rd/lua_bridge.hpp

@@ -0,0 +1,320 @@
+#pragma once
+
+#include "pocketpy.h"
+
+extern "C"{
+    #include "lua.h"
+    #include "lauxlib.h"
+}
+
+namespace pkpy{
+
+/******************************************************************/
+
+void initialize_lua_bridge(VM* vm, lua_State* newL);
+
+/******************************************************************/
+
+lua_State* _L;
+void lua_push_from_python(VM*, PyObject*);
+PyObject* lua_popx_to_python(VM*);
+
+template<typename T>
+static void table_apply(VM* vm, T f){
+    PK_ASSERT(lua_istable(_L, -1));
+    lua_pushnil(_L);                 // [key]
+    while(lua_next(_L, -2) != 0){    // [key, val]
+        lua_pushvalue(_L, -2);       // [key, val, key]
+        PyObject* key = lua_popx_to_python(vm);
+        PyObject* val = lua_popx_to_python(vm);
+        f(key, val);                // [key]
+    }
+    lua_pop(_L, 1);                  // []
+}
+
+struct LuaExceptionGuard{
+    int base_size;
+    LuaExceptionGuard(){ base_size = lua_gettop(_L); }
+    ~LuaExceptionGuard(){
+        int delta = lua_gettop(_L) - base_size;
+        if(delta > 0) lua_pop(_L, delta);
+    }
+};
+
+#define LUA_PROTECTED(__B) { LuaExceptionGuard __guard; __B; }
+
+struct PyLuaObject{
+    PK_ALWAYS_PASS_BY_POINTER(PyLuaObject)
+    int r;
+    PyLuaObject(){ r = luaL_ref(_L, LUA_REGISTRYINDEX); }
+    ~PyLuaObject(){ luaL_unref(_L, LUA_REGISTRYINDEX, r); }
+};
+
+struct PyLuaTable: PyLuaObject{
+    PY_CLASS(PyLuaTable, lua, Table)
+
+    static void _register(VM* vm, PyObject* mod, PyObject* type){
+        Type t = PK_OBJ_GET(Type, type);
+        PyTypeInfo* ti = &vm->_all_types[t];
+        ti->m__getattr__ = [](VM* vm, PyObject* obj, StrName name){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_pushstring(_L, std::string(name.sv()).c_str());
+                lua_gettable(_L, -2);
+                PyObject* ret = lua_popx_to_python(vm);
+                lua_pop(_L, 1);
+                return ret;
+            )
+        };
+
+        ti->m__setattr__ = [](VM* vm, PyObject* obj, StrName name, PyObject* val){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_pushstring(_L, std::string(name.sv()).c_str());
+                lua_push_from_python(vm, val);
+                lua_settable(_L, -3);
+                lua_pop(_L, 1);
+            )
+        };
+
+        vm->bind_constructor<1>(type, [](VM* vm, ArgsView args){
+            lua_newtable(_L);    // push an empty table onto the stack
+            PyObject* obj = vm->heap.gcnew<PyLuaTable>(PyLuaTable::_type(vm));
+            return obj;
+        });
+
+        vm->bind__len__(t, [](VM* vm, PyObject* obj){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+            i64 len = 0;
+            lua_pushnil(_L);
+            while(lua_next(_L, -2) != 0){ len += 1; lua_pop(_L, 1); }
+            lua_pop(_L, 1);
+            return len;
+        });
+
+        vm->bind__getitem__(t, [](VM* vm, PyObject* obj, PyObject* key){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_push_from_python(vm, key);
+                lua_gettable(_L, -2);
+                PyObject* ret = lua_popx_to_python(vm);
+                lua_pop(_L, 1);
+                return ret;
+            )
+        });
+
+        vm->bind__setitem__(t, [](VM* vm, PyObject* obj, PyObject* key, PyObject* val){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_push_from_python(vm, key);
+                lua_push_from_python(vm, val);
+                lua_settable(_L, -3);
+                lua_pop(_L, 1);
+            )
+        });
+
+        vm->bind__delitem__(t, [](VM* vm, PyObject* obj, PyObject* key){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_push_from_python(vm, key);
+                lua_pushnil(_L);
+                lua_settable(_L, -3);
+                lua_pop(_L, 1);
+            )
+        });
+
+        vm->bind__contains__(t, [](VM* vm, PyObject* obj, PyObject* key){
+            const PyLuaTable& self = _CAST(PyLuaTable&, obj);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                lua_push_from_python(vm, key);
+                lua_gettable(_L, -2);
+                bool ret = lua_isnil(_L, -1) == 0;
+                lua_pop(_L, 2);
+                return ret ? vm->True : vm->False;
+            )
+        });
+
+        vm->bind(type, "keys(self) -> list", [](VM* vm, ArgsView args){
+            const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                List ret;
+                table_apply(vm, [&](PyObject* key, PyObject* val){ ret.push_back(key); });
+                lua_pop(_L, 1);
+                return VAR(std::move(ret));
+            )
+        });
+
+        vm->bind(type, "values(self) -> list", [](VM* vm, ArgsView args){
+            const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                List ret;
+                table_apply(vm, [&](PyObject* key, PyObject* val){ ret.push_back(val); });
+                lua_pop(_L, 1);
+                return VAR(std::move(ret));
+            )
+        });
+
+        vm->bind(type, "items(self) -> list[tuple]", [](VM* vm, ArgsView args){
+            const PyLuaTable& self = _CAST(PyLuaTable&, args[0]);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                List ret;
+                table_apply(vm, [&](PyObject* key, PyObject* val){
+                    PyObject* item = VAR(Tuple({key, val}));
+                    ret.push_back(item);
+                });
+                lua_pop(_L, 1);
+                return VAR(std::move(ret));
+            )
+        });
+    }
+};
+
+static PyObject* lua_popx_multi_to_python(VM* vm, int count){
+    if(count == 0){
+        return vm->None;
+    }else if(count == 1){
+        return lua_popx_to_python(vm);
+    }else if(count > 1){
+        Tuple ret(count);
+        for(int i=0; i<count; i++){
+            ret[i] = lua_popx_to_python(vm);
+        }
+        return VAR(std::move(ret));
+    }
+    PK_FATAL_ERROR()
+}
+
+struct PyLuaFunction: PyLuaObject{
+    PY_CLASS(PyLuaFunction, lua, Function)
+
+    static void _register(VM* vm, PyObject* mod, PyObject* type){
+        vm->bind_notimplemented_constructor<PyLuaFunction>(type);
+
+        vm->bind_method<-1>(type, "__call__", [](VM* vm, ArgsView args){
+            if(args.size() < 1) vm->TypeError("__call__ takes at least 1 argument");
+            const PyLuaFunction& self = _CAST(PyLuaFunction&, args[0]);
+            int base_size = lua_gettop(_L);
+            LUA_PROTECTED(
+                lua_rawgeti(_L, LUA_REGISTRYINDEX, self.r);
+                for(int i=1; i<args.size(); i++){
+                    lua_push_from_python(vm, args[i]);
+                }
+                if(lua_pcall(_L, args.size()-1, LUA_MULTRET, 0)){
+                    const char* error = lua_tostring(_L, -1);
+                    lua_pop(_L, 1);
+                    vm->RuntimeError(error);
+                } 
+                return lua_popx_multi_to_python(vm, lua_gettop(_L) - base_size);
+            )
+        });
+    }
+};
+
+void lua_push_from_python(VM* vm, PyObject* val){
+    if(val == vm->None){
+        lua_pushnil(_L);
+        return;
+    }
+    Type t = vm->_tp(val);
+    switch(t.index){
+        case VM::tp_bool.index:
+            lua_pushboolean(_L, val == vm->True);
+            return;
+        case VM::tp_int.index:
+            lua_pushinteger(_L, _CAST(i64, val));
+            return;
+        case VM::tp_float.index:
+            lua_pushnumber(_L, _CAST(f64, val));
+            return;
+        case VM::tp_str.index:
+            lua_pushstring(_L, _CAST(CString, val));
+            return;
+    }
+
+    if(is_non_tagged_type(val, PyLuaTable::_type(vm))){
+        const PyLuaTable& table = _CAST(PyLuaTable&, val);
+        lua_rawgeti(_L, LUA_REGISTRYINDEX, table.r);
+        return;
+    }
+
+    if(is_non_tagged_type(val, PyLuaFunction::_type(vm))){
+        const PyLuaFunction& func = _CAST(PyLuaFunction&, val);
+        lua_rawgeti(_L, LUA_REGISTRYINDEX, func.r);
+        return;
+    }
+    vm->RuntimeError(fmt("unsupported python type: ", obj_type_name(vm, t).escape()));
+}
+
+PyObject* lua_popx_to_python(VM* vm) {
+    int type = lua_type(_L, -1);
+    switch (type) {
+        case LUA_TNIL: {
+            lua_pop(_L, 1);
+            return vm->None;
+        }
+        case LUA_TBOOLEAN: {
+            bool val = lua_toboolean(_L, -1);
+            lua_pop(_L, 1);
+            return val ? vm->True : vm->False;
+        }
+        case LUA_TNUMBER: {
+            double val = lua_tonumber(_L, -1);
+            lua_pop(_L, 1);
+            return VAR(val);
+        }
+        case LUA_TSTRING: {
+            const char* val = lua_tostring(_L, -1);
+            lua_pop(_L, 1);
+            return VAR(val);
+        }
+        case LUA_TTABLE: {
+            PyObject* obj = vm->heap.gcnew<PyLuaTable>(PyLuaTable::_type(vm));
+            return obj;
+        }
+        case LUA_TFUNCTION: {
+            PyObject* obj = vm->heap.gcnew<PyLuaFunction>(PyLuaFunction::_type(vm));
+            return obj;
+        }
+        default: {
+            const char* type_name = lua_typename(_L, type);
+            lua_pop(_L, 1);
+            vm->RuntimeError(fmt("unsupported lua type: '", type_name, "'"));
+        }
+    }
+    PK_UNREACHABLE()
+}
+
+void initialize_lua_bridge(VM* vm, lua_State* newL){
+    PyObject* mod = vm->new_module("lua");
+
+    if(_L != nullptr){
+        throw std::runtime_error("lua bridge already initialized");
+    }
+    _L = newL;
+
+    PyLuaTable::register_class(vm, mod);
+    PyLuaFunction::register_class(vm, mod);
+
+    vm->bind(mod, "dostring(__source: str)", [](VM* vm, ArgsView args){
+        const char* source = CAST(CString, args[0]);
+        int base_size = lua_gettop(_L);
+        if (luaL_dostring(_L, source)) {
+            const char* error = lua_tostring(_L, -1);
+            lua_pop(_L, 1);  // pop error message from the stack
+            vm->RuntimeError(error);
+        }
+        return lua_popx_multi_to_python(vm, lua_gettop(_L) - base_size);
+    });
+}
+
+}   // namespace pkpy

+ 0 - 1
CMakeLists.txt

@@ -41,7 +41,6 @@ aux_source_directory(${CMAKE_CURRENT_LIST_DIR}/src POCKETPY_SRC)
 option(PK_USE_CJSON "Use cJSON" OFF)
 if(PK_USE_CJSON)
     add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/3rd/cjson)
-    include_directories(${CMAKE_CURRENT_LIST_DIR}/3rd/cjson/include)
     add_definitions(-DPK_USE_CJSON)
 endif()
 

+ 1 - 1
compile_flags.txt

@@ -4,5 +4,5 @@
 -std=c++17
 -stdlib=libc++
 -Iinclude/
--I3rd/box2d/include/
 -I3rd/cjson/include/
+

+ 5 - 2
include/pocketpy/vm.h

@@ -96,6 +96,10 @@ struct PyTypeInfo{
     PyObject* (*m__getitem__)(VM* vm, PyObject*, PyObject*) = nullptr;
     void (*m__setitem__)(VM* vm, PyObject*, PyObject*, PyObject*) = nullptr;
     void (*m__delitem__)(VM* vm, PyObject*, PyObject*) = nullptr;
+
+    // attributes
+    void (*m__setattr__)(VM* vm, PyObject*, StrName, PyObject*) = nullptr;
+    PyObject* (*m__getattr__)(VM* vm, PyObject*, StrName) = nullptr;
 };
 
 struct FrameId{
@@ -174,7 +178,7 @@ public:
     PyObject* py_json(PyObject* obj);
     PyObject* py_iter(PyObject* obj);
 
-    PyObject* find_name_in_mro(PyObject* cls, StrName name);
+    PyObject* find_name_in_mro(Type cls, StrName name);
     bool isinstance(PyObject* obj, Type base);
     bool issubclass(Type cls, Type base);
     PyObject* exec(Str source, Str filename, CompileMode mode, PyObject* _module=nullptr);
@@ -221,7 +225,6 @@ public:
 
     PyObject* new_type_object(PyObject* mod, StrName name, Type base, bool subclass_enabled=true);
     Type _new_type_object(StrName name, Type base=0, bool subclass_enabled=false);
-    PyTypeInfo* _type_info(Type type);
     const PyTypeInfo* _inst_type_info(PyObject* obj);
 
 #define BIND_UNARY_SPECIAL(name)                                                        \

+ 1 - 1
src/ceval.cpp

@@ -778,7 +778,7 @@ __NEXT_STEP:;
         PK_ASSERT(_curr_class != nullptr);
         StrName _name(byte.arg);
         Type type = PK_OBJ_GET(Type, _curr_class);
-        _type_info(type)->annotated_fields.push_back(_name);
+        _all_types[type].annotated_fields.push_back(_name);
     } DISPATCH();
     /*****************************************/
     TARGET(WITH_ENTER)

+ 6 - 9
src/pocketpy.cpp

@@ -1,11 +1,11 @@
 #include "pocketpy/pocketpy.h"
 
+namespace pkpy{
+
 #ifdef PK_USE_CJSON
-#include "cJSONw.hpp"
+void add_module_cjson(VM* vm);
 #endif
 
-namespace pkpy{
-
 void init_builtins(VM* _vm) {
 #define BIND_NUM_ARITH_OPT(name, op)                                                                    \
     _vm->bind##name(VM::tp_int, [](VM* vm, PyObject* lhs, PyObject* rhs) {                              \
@@ -142,8 +142,8 @@ void init_builtins(VM* _vm) {
     });
 
     _vm->bind_func<1>(_vm->builtins, "callable", [](VM* vm, ArgsView args) {
-        PyObject* cls = vm->_t(args[0]);
-        switch(PK_OBJ_GET(Type, cls).index){
+        Type cls = vm->_tp(args[0]);
+        switch(cls.index){
             case VM::tp_function.index: return vm->True;
             case VM::tp_native_func.index: return vm->True;
             case VM::tp_bound_method.index: return vm->True;
@@ -1619,7 +1619,7 @@ void VM::post_init(){
     });
 
     bind_property(_t(tp_type), "__annotations__", [](VM* vm, ArgsView args){
-        PyTypeInfo* ti = vm->_type_info(PK_OBJ_GET(Type, args[0]));
+        const PyTypeInfo* ti = &vm->_all_types[(PK_OBJ_GET(Type, args[0]))];
         Tuple t(ti->annotated_fields.size());
         for(int i=0; i<ti->annotated_fields.size(); i++){
             t[i] = VAR(ti->annotated_fields[i].sv());
@@ -1677,7 +1677,6 @@ void VM::post_init(){
         return VAR(MappingProxy(args[0]));
     });
 
-#if !PK_DEBUG_NO_BUILTINS
     add_module_sys(this);
     add_module_traceback(this);
     add_module_time(this);
@@ -1721,8 +1720,6 @@ void VM::post_init(){
 #ifdef PK_USE_CJSON
     add_module_cjson(this);
 #endif
-
-#endif
 }
 
 CodeObject_ VM::compile(const Str& source, const Str& filename, CompileMode mode, bool unknown_global_scope) {

+ 36 - 23
src/vm.cpp

@@ -131,14 +131,13 @@ namespace pkpy{
         callstack.pop();
     }
 
-    PyObject* VM::find_name_in_mro(PyObject* cls, StrName name){
+    PyObject* VM::find_name_in_mro(Type cls, StrName name){
         PyObject* val;
         do{
-            val = cls->attr().try_get(name);
+            val = _t(cls)->attr().try_get(name);
             if(val != nullptr) return val;
-            Type base = _all_types[PK_OBJ_GET(Type, cls)].base;
-            if(base.index == -1) break;
-            cls = _all_types[base].obj;
+            cls = _all_types[cls].base;
+            if(cls.index == -1) break;
         }while(true);
         return nullptr;
     }
@@ -219,10 +218,6 @@ namespace pkpy{
         return PK_OBJ_GET(Type, obj);
     }
 
-    PyTypeInfo* VM::_type_info(Type type){
-        return &_all_types[type];
-    }
-
     const PyTypeInfo* VM::_inst_type_info(PyObject* obj){
         if(is_int(obj)) return &_all_types[tp_int];
         if(is_float(obj)) return &_all_types[tp_float];
@@ -937,7 +932,7 @@ __FAST_CALL:
     if(is_non_tagged_type(callable, tp_type)){
         if(method_call) PK_FATAL_ERROR();
         // [type, NULL, args..., kwargs...]
-        PyObject* new_f = find_name_in_mro(callable, __new__);
+        PyObject* new_f = find_name_in_mro(PK_OBJ_GET(Type, callable), __new__);
         PyObject* obj;
 #if PK_DEBUG_EXTRA_CHECK
         PK_ASSERT(new_f != nullptr);
@@ -994,14 +989,14 @@ void VM::delattr(PyObject *_0, StrName _name){
 
 // https://docs.python.org/3/howto/descriptor.html#invocation-from-an-instance
 PyObject* VM::getattr(PyObject* obj, StrName name, bool throw_err){
-    PyObject* objtype;
+    Type objtype(0);
     // handle super() proxy
     if(is_non_tagged_type(obj, tp_super)){
         const Super& super = PK_OBJ_GET(Super, obj);
         obj = super.first;
-        objtype = _t(super.second);
+        objtype = super.second;
     }else{
-        objtype = _t(obj);
+        objtype = _tp(obj);
     }
     PyObject* cls_var = find_name_in_mro(objtype, name);
     if(cls_var != nullptr){
@@ -1015,7 +1010,7 @@ PyObject* VM::getattr(PyObject* obj, StrName name, bool throw_err){
     if(!is_tagged(obj) && obj->is_attr_valid()){
         PyObject* val;
         if(obj->type == tp_type){
-            val = find_name_in_mro(obj, name);
+            val = find_name_in_mro(PK_OBJ_GET(Type, obj), name);
             if(val != nullptr){
                 if(is_tagged(val)) return val;
                 if(val->type == tp_staticmethod) return PK_OBJ_GET(StaticMethod, val).func;
@@ -1038,11 +1033,16 @@ PyObject* VM::getattr(PyObject* obj, StrName name, bool throw_err){
                 case tp_staticmethod.index:
                     return PK_OBJ_GET(StaticMethod, cls_var).func;
                 case tp_classmethod.index:
-                    return VAR(BoundMethod(objtype, PK_OBJ_GET(ClassMethod, cls_var).func));
+                    return VAR(BoundMethod(_t(objtype), PK_OBJ_GET(ClassMethod, cls_var).func));
             }
         }
         return cls_var;
     }
+
+    const PyTypeInfo* ti = &_all_types[objtype];
+    if(ti->m__getattr__){
+        return ti->m__getattr__(this, obj, name);
+    }
     
     if(is_non_tagged_type(obj, tp_module)){
         Str path = CAST(Str&, obj->attr(__path__));
@@ -1062,14 +1062,14 @@ PyObject* VM::getattr(PyObject* obj, StrName name, bool throw_err){
 // try to load a unbound method (fallback to `getattr` if not found)
 PyObject* VM::get_unbound_method(PyObject* obj, StrName name, PyObject** self, bool throw_err, bool fallback){
     *self = PY_NULL;
-    PyObject* objtype;
+    Type objtype(0);
     // handle super() proxy
     if(is_non_tagged_type(obj, tp_super)){
         const Super& super = PK_OBJ_GET(Super, obj);
         obj = super.first;
-        objtype = _t(super.second);
+        objtype = super.second;
     }else{
-        objtype = _t(obj);
+        objtype = _tp(obj);
     }
     PyObject* cls_var = find_name_in_mro(objtype, name);
 
@@ -1085,7 +1085,7 @@ PyObject* VM::get_unbound_method(PyObject* obj, StrName name, PyObject** self, b
         if(!is_tagged(obj) && obj->is_attr_valid()){
             PyObject* val;
             if(obj->type == tp_type){
-                val = find_name_in_mro(obj, name);
+                val = find_name_in_mro(PK_OBJ_GET(Type, obj), name);
                 if(val != nullptr){
                     if(is_tagged(val)) return val;
                     if(val->type == tp_staticmethod) return PK_OBJ_GET(StaticMethod, val).func;
@@ -1112,25 +1112,31 @@ PyObject* VM::get_unbound_method(PyObject* obj, StrName name, PyObject** self, b
                     *self = PY_NULL;
                     return PK_OBJ_GET(StaticMethod, cls_var).func;
                 case tp_classmethod.index:
-                    *self = objtype;
+                    *self = _t(objtype);
                     return PK_OBJ_GET(ClassMethod, cls_var).func;
             }
         }
         return cls_var;
     }
+
+    const PyTypeInfo* ti = &_all_types[objtype];
+    if(fallback && ti->m__getattr__){
+        return ti->m__getattr__(this, obj, name);
+    }
+
     if(throw_err) AttributeError(obj, name);
     return nullptr;
 }
 
 void VM::setattr(PyObject* obj, StrName name, PyObject* value){
-    PyObject* objtype;
+    Type objtype(0);
     // handle super() proxy
     if(is_non_tagged_type(obj, tp_super)){
         Super& super = PK_OBJ_GET(Super, obj);
         obj = super.first;
-        objtype = _t(super.second);
+        objtype = super.second;
     }else{
-        objtype = _t(obj);
+        objtype = _tp(obj);
     }
     PyObject* cls_var = find_name_in_mro(objtype, name);
     if(cls_var != nullptr){
@@ -1145,6 +1151,13 @@ void VM::setattr(PyObject* obj, StrName name, PyObject* value){
             return;
         }
     }
+
+    const PyTypeInfo* ti = &_all_types[objtype];
+    if(ti->m__setattr__){
+        ti->m__setattr__(this, obj, name, value);
+        return;
+    }
+
     // handle instance __dict__
     if(is_tagged(obj) || !obj->is_attr_valid()) TypeError("cannot set attribute");
     obj->attr().set(name, value);