blueloveTH il y a 3 ans
Parent
commit
a4d9f8dc82
3 fichiers modifiés avec 139 ajouts et 55 suppressions
  1. 123 49
      src/cffi.h
  2. 5 2
      src/obj.h
  3. 11 4
      src/test.cpp

+ 123 - 49
src/cffi.h

@@ -9,8 +9,8 @@ namespace pkpy {
 
 template<typename Ret, typename... Params>
 struct NativeProxyFunc {
-    //using T = Ret(*)(Params...);
-    using T = std::function<Ret(Params...)>;
+    using T = Ret(*)(Params...);
+    // using T = std::function<Ret(Params...)>;
     static constexpr int N = sizeof...(Params);
     T func;
     NativeProxyFunc(T func) : func(func) {}
@@ -35,51 +35,86 @@ struct NativeProxyFunc {
     }
 };
 
+template<typename T>
+constexpr int type_index() { return 0; }
+template<> constexpr int type_index<void>() { return 1; }
+template<> constexpr int type_index<char>() { return 2; }
+template<> constexpr int type_index<short>() { return 3; }
+template<> constexpr int type_index<int>() { return 4; }
+template<> constexpr int type_index<long>() { return 5; }
+template<> constexpr int type_index<long long>() { return 6; }
+template<> constexpr int type_index<unsigned char>() { return 7; }
+template<> constexpr int type_index<unsigned short>() { return 8; }
+template<> constexpr int type_index<unsigned int>() { return 9; }
+template<> constexpr int type_index<unsigned long>() { return 10; }
+template<> constexpr int type_index<unsigned long long>() { return 11; }
+template<> constexpr int type_index<float>() { return 12; }
+template<> constexpr int type_index<double>() { return 13; }
+template<> constexpr int type_index<bool>() { return 14; }
+
+template<typename T>
+struct TypeId{ inline static int id; };
+
 struct TypeInfo;
 
 struct MemberInfo{
-    TypeInfo* type;
+    const TypeInfo* type;
     int offset;
 };
 
 struct TypeInfo{
     const char* name;
     int size;
-    int index;      // for basic types only
+    int index;
     std::map<StrName, MemberInfo> members;
-
-    TypeInfo(const char name[], int size, int index) : name(name), size(size), index(index) {}
-    TypeInfo(const char name[], int size, std::map<StrName, MemberInfo> members)
-        : name(name), size(size), index(-1), members(members) {}
-    TypeInfo() : name(nullptr), size(0), index(-1) {}
 };
 
-template<typename T>
-constexpr int type_index() { return -1; }
-template<> constexpr int type_index<void>() { return 0; }
-template<> constexpr int type_index<char>() { return 1; }
-template<> constexpr int type_index<short>() { return 2; }
-template<> constexpr int type_index<int>() { return 3; }
-template<> constexpr int type_index<long>() { return 4; }
-template<> constexpr int type_index<long long>() { return 5; }
-template<> constexpr int type_index<unsigned char>() { return 6; }
-template<> constexpr int type_index<unsigned short>() { return 7; }
-template<> constexpr int type_index<unsigned int>() { return 8; }
-template<> constexpr int type_index<unsigned long>() { return 9; }
-template<> constexpr int type_index<unsigned long long>() { return 10; }
-template<> constexpr int type_index<float>() { return 11; }
-template<> constexpr int type_index<double>() { return 12; }
-template<> constexpr int type_index<bool>() { return 13; }
-
 struct Vec2 {
     float x, y;
 };
 
-static std::map<std::string_view, TypeInfo> _type_infos;
+struct TypeDB{
+    std::vector<TypeInfo> _by_index;
+    std::map<std::string_view, int> _by_name;
+
+    template<typename T>
+    void register_type(const char name[], std::map<StrName, MemberInfo>&& members){
+        TypeInfo ti;
+        ti.name = name;
+        if constexpr(std::is_same_v<T, void>) ti.size = 1;
+        else ti.size = sizeof(T);
+        ti.members = std::move(members);
+        TypeId<T>::id = ti.index = _by_index.size()+1;    // 0 is reserved for NULL
+        _by_name[name] = ti.index;
+        _by_index.push_back(ti);
+    }
+
+    const TypeInfo* get(int index) const {
+        return index == 0 ? nullptr : &_by_index[index-1];
+    }
+
+    const TypeInfo* get(const char name[]) const {
+        auto it = _by_name.find(name);
+        if(it == _by_name.end()) return nullptr;
+        return get(it->second);
+    }
+
+    const TypeInfo* get(const Str& s) const {
+        return get(s.c_str());
+    }
+
+    template<typename T>
+    const TypeInfo* get() const {
+        return get(TypeId<T>::id);
+    }
+};
+
+static TypeDB _type_db;
+
 
 auto _ = [](){
-    #define REGISTER_BASIC_TYPE(T) _type_infos[#T] = TypeInfo(#T, sizeof(T), type_index<T>())
-    _type_infos["void"] = TypeInfo("void", 1, type_index<void>());
+    #define REGISTER_BASIC_TYPE(T) _type_db.register_type<T>(#T, {});
+    _type_db.register_type<void>("void", {});
     REGISTER_BASIC_TYPE(char);
     REGISTER_BASIC_TYPE(short);
     REGISTER_BASIC_TYPE(int);
@@ -95,15 +130,15 @@ auto _ = [](){
     REGISTER_BASIC_TYPE(bool);
     #undef REGISTER_BASIC_TYPE
 
-    _type_infos["Vec2"] = TypeInfo("Vec2", sizeof(Vec2), {
-        {"x", {&_type_infos["float"], offsetof(Vec2, x)}},
-        {"y", {&_type_infos["float"], offsetof(Vec2, y)}},
+    _type_db.register_type<Vec2>("Vec2", {
+        {"x", { _type_db.get<float>(), offsetof(Vec2, x) }},
+        {"y", { _type_db.get<float>(), offsetof(Vec2, y) }},
     });
     return 0;
 }();
 
 struct Pointer{
-    PY_CLASS(Pointer, c, ptr_)
+    PY_CLASS(Pointer, c, _ptr)
 
     const TypeInfo* ctype;      // this is immutable
     int level;                  // level of pointer
@@ -113,7 +148,7 @@ struct Pointer{
         return level == 1 ? ctype->size : sizeof(void*);
     }
 
-    Pointer() : ctype(&_type_infos["void"]), level(1), ptr(nullptr) {}
+    Pointer() : ctype(_type_db.get<void>()), level(1), ptr(nullptr) {}
     Pointer(const TypeInfo* ctype, int level, char* ptr): ctype(ctype), level(level), ptr(ptr) {}
     Pointer(const TypeInfo* ctype, char* ptr): ctype(ctype), level(1), ptr(ptr) {}
 
@@ -198,10 +233,10 @@ struct Pointer{
                 else break;
             }
             if(level == 0) vm->TypeError("expect a pointer type, such as 'int*'");
-            Str type = name.substr(0, name.size()-level);
-            auto it = _type_infos.find(type);
-            if(it == _type_infos.end()) vm->TypeError("unknown type: " + type.escape(true));
-            return VAR_T(Pointer, &it->second, level, self.ptr);
+            Str type_s = name.substr(0, name.size()-level);
+            const TypeInfo* type = _type_db.get(type_s);
+            if(type == nullptr) vm->TypeError("unknown type: " + type_s.escape(true));
+            return VAR_T(Pointer, type, level, self.ptr);
         });
     }
 
@@ -270,7 +305,7 @@ struct Pointer{
 
 
 struct Value {
-    PY_CLASS(Value, c, value_)
+    PY_CLASS(Value, c, _value)
 
     char* data;
     Pointer head;
@@ -305,15 +340,15 @@ struct CType{
 
     const TypeInfo* type;
 
-    CType() : type(&_type_infos["void"]) {}
+    CType() : type(_type_db.get<void>()) {}
     CType(const TypeInfo* type) : type(type) {}
 
     static void _register(VM* vm, PyVar mod, PyVar type){
         vm->bind_static_method<1>(type, "__new__", [](VM* vm, Args& args) {
             const Str& name = CAST(Str&, args[0]);
-            auto it = _type_infos.find(name);
-            if (it == _type_infos.end()) vm->TypeError("unknown type: " + name.escape(true));
-            return VAR_T(CType, &it->second);
+            const TypeInfo* type = _type_db.get(name);
+            if(type == nullptr) vm->TypeError("unknown type: " + name.escape(true));
+            return VAR_T(CType, type);
         });
 
         vm->bind_method<0>(type, "__call__", [](VM* vm, Args& args) {
@@ -333,7 +368,7 @@ void add_module_c(VM* vm){
 
     vm->bind_func<1>(mod, "malloc", [](VM* vm, Args& args) {
         i64 size = CAST(i64, args[0]);
-        return VAR_T(Pointer, &_type_infos["void"], (char*)malloc(size));
+        return VAR_T(Pointer, _type_db.get<void>(), (char*)malloc(size));
     });
 
     vm->bind_func<1>(mod, "free", [](VM* vm, Args& args) {
@@ -353,9 +388,9 @@ void add_module_c(VM* vm){
     vm->bind_func<1>(mod, "sizeof", [](VM* vm, Args& args) {
         const Str& name = CAST(Str&, args[0]);
         if(name.find('*') != Str::npos) return VAR(sizeof(void*));
-        auto it = _type_infos.find(name);
-        if(it == _type_infos.end()) vm->TypeError("unknown type: " + name.escape(true));
-        return VAR(it->second.size);
+        const TypeInfo* type = _type_db.get(name);
+        if(type == nullptr) vm->TypeError("unknown type: " + name.escape(true));
+        return VAR(type->size);
     });
 
     vm->bind_func<3>(mod, "memset", [](VM* vm, Args& args) {
@@ -368,11 +403,50 @@ void add_module_c(VM* vm){
 }
 
 PyVar py_var(VM* vm, void* p){
-    return VAR_T(Pointer, &_type_infos["void"], (char*)p);
+    return VAR_T(Pointer, _type_db.get<void>(), (char*)p);
 }
 
 PyVar py_var(VM* vm, char* p){
-    return VAR_T(Pointer, &_type_infos["char"], (char*)p);
+    return VAR_T(Pointer, _type_db.get<char>(), (char*)p);
+}
+
+/***********************************************/
+
+template<typename T>
+struct _pointer {
+    static constexpr int level = 0;
+    using baseT = T;
+};
+
+template<typename T>
+struct _pointer<T*> {
+    static constexpr int level = _pointer<T>::level + 1;
+    using baseT = typename _pointer<T>::baseT;
+};
+
+template<typename T>
+struct pointer {
+    static constexpr int level = _pointer<std::decay_t<T>>::level;
+    using baseT = typename _pointer<std::decay_t<T>>::baseT;
+};
+
+template<typename T>
+std::enable_if_t<std::is_pointer_v<T>, T>
+py_cast(VM* vm, const PyVar& var){
+    Pointer& p = CAST(Pointer&, var);
+    const TypeInfo* type = _type_db.get<typename pointer<T>::baseT>();
+    const int level = pointer<T>::level;
+    if(p.ctype != type || p.level != level){
+        vm->TypeError("invalid pointer cast");
+    }
+    return reinterpret_cast<T>(p.ptr);
+}
+
+template<typename T>
+std::enable_if_t<std::is_pointer_v<T>, PyVar>
+py_var(VM* vm, T p){
+    const TypeInfo* type = _type_db.get<typename pointer<T>::baseT>();
+    return VAR_T(Pointer, type, pointer<T>::level, (char*)p);
 }
 
 }   // namespace pkpy

+ 5 - 2
src/obj.h

@@ -3,6 +3,7 @@
 #include "common.h"
 #include "namedict.h"
 #include "tuplelist.h"
+#include <type_traits>
 
 namespace pkpy {
     
@@ -178,7 +179,8 @@ template<typename T>
 void _check_py_class(VM* vm, const PyVar& var);
 
 template<typename __T>
-__T py_cast(VM* vm, const PyVar& obj) {
+std::enable_if_t<!std::is_pointer_v<__T>, __T>
+py_cast(VM* vm, const PyVar& obj) {
     using T = std::decay_t<__T>;
     if constexpr(is_py_class<T>::value){
         _check_py_class<T>(vm, obj);
@@ -188,7 +190,8 @@ __T py_cast(VM* vm, const PyVar& obj) {
     }
 }
 template<typename __T>
-__T _py_cast(VM* vm, const PyVar& obj) {
+std::enable_if_t<!std::is_pointer_v<__T>, __T>
+_py_cast(VM* vm, const PyVar& obj) {
     using T = std::decay_t<__T>;
     if constexpr(is_py_class<T>::value){
         return OBJ_GET(T, obj);

+ 11 - 4
src/test.cpp

@@ -3,15 +3,22 @@
 
 using namespace pkpy;
 
-double add(int a, double b){
-    return a + b;
+float* f(int* a){
+    *a = 100;
+    return new float(3.5f);
 }
 
 int main(){
     VM* vm = pkpy_new_vm(true);
+    vm->bind_builtin_func<1>("f", NativeProxyFunc(&f));
 
-    vm->bind_builtin_func<2>("add", ProxyFunction(&add));
-    pkpy_vm_exec(vm, "print( add(1, 2.0) )");
+    pkpy_vm_exec(vm, R"(
+from c import *
+p = malloc(4).cast("int*")
+ret = f(p)
+print(p.get())          # 100
+print(ret, ret.get())   # 3.5
+)");
 
     pkpy_delete(vm);
     return 0;