Prechádzať zdrojové kódy

fix https://github.com/blueloveTH/pocketpy/issues/131

blueloveTH 2 rokov pred
rodič
commit
60e666c12e
4 zmenil súbory, kde vykonal 148 pridanie a 9 odobranie
  1. 1 0
      include/pocketpy/str.h
  2. 5 2
      src/ceval.cpp
  3. 23 7
      src/pocketpy.cpp
  4. 119 0
      tests/40_class_ex.py

+ 1 - 0
include/pocketpy/str.h

@@ -190,6 +190,7 @@ const StrName __name__ = StrName::get("__name__");
 const StrName __all__ = StrName::get("__all__");
 const StrName __package__ = StrName::get("__package__");
 const StrName __path__ = StrName::get("__path__");
+const StrName __class__ = StrName::get("__class__");
 
 const StrName pk_id_add = StrName::get("add");
 const StrName pk_id_set = StrName::get("set");

+ 5 - 2
src/ceval.cpp

@@ -659,11 +659,14 @@ __NEXT_STEP:;
         _0 = POPX();
         _0->attr()._try_perfect_rehash();
         DISPATCH();
-    TARGET(STORE_CLASS_ATTR)
+    TARGET(STORE_CLASS_ATTR){
         _name = StrName(byte.arg);
         _0 = POPX();
+        if(is_non_tagged_type(_0, tp_function)){
+            _0->attr().set(__class__, TOP());
+        }
         TOP()->attr().set(_name, _0);
-        DISPATCH();
+    } DISPATCH();
     /*****************************************/
     TARGET(WITH_ENTER)
         call_method(POPX(), __enter__);

+ 23 - 7
src/pocketpy.cpp

@@ -108,16 +108,32 @@ void init_builtins(VM* _vm) {
 #undef BIND_NUM_ARITH_OPT
 #undef BIND_NUM_LOGICAL_OPT
 
-    _vm->bind_builtin_func<2>("super", [](VM* vm, ArgsView args) {
-        vm->check_non_tagged_type(args[0], vm->tp_type);
-        Type type = PK_OBJ_GET(Type, args[0]);
-        if(!vm->isinstance(args[1], type)){
-            Str _0 = obj_type_name(vm, PK_OBJ_GET(Type, vm->_t(args[1])));
+    _vm->bind_builtin_func<-1>("super", [](VM* vm, ArgsView args) {
+        PyObject* class_arg = nullptr;
+        PyObject* self_arg = nullptr;
+        if(args.size() == 2){
+            class_arg = args[0];
+            self_arg = args[1];
+        }else if(args.size() == 0){
+            FrameId frame = vm->top_frame();
+            if(frame->_callable != nullptr){
+                class_arg = frame->_callable->attr().try_get(__class__);
+                if(frame->_locals.size() > 0) self_arg = frame->_locals[0];
+            }
+            if(class_arg == nullptr || self_arg == nullptr){
+                vm->TypeError("super(): unable to determine the class context, use super(class, self) instead");
+            }
+        }else{
+            vm->TypeError("super() takes 0 or 2 arguments");
+        }
+        vm->check_non_tagged_type(class_arg, vm->tp_type);
+        Type type = PK_OBJ_GET(Type, class_arg);
+        if(!vm->isinstance(self_arg, type)){
+            Str _0 = obj_type_name(vm, PK_OBJ_GET(Type, vm->_t(self_arg)));
             Str _1 = obj_type_name(vm, type);
             vm->TypeError("super(): " + _0.escape() + " is not an instance of " + _1.escape());
         }
-        Type base = vm->_all_types[type].base;
-        return vm->heap.gcnew<Super>(vm->tp_super, args[1], base);
+        return vm->heap.gcnew<Super>(vm->tp_super, self_arg, vm->_all_types[type].base);
     });
 
     _vm->bind_builtin_func<2>("isinstance", [](VM* vm, ArgsView args) {

+ 119 - 0
tests/40_class_ex.py

@@ -0,0 +1,119 @@
+class A:
+    def __init__(self, a, b):
+        self.a = a
+        self.b = b
+
+    def add(self):
+        return self.a + self.b
+
+    def sub(self):
+        return self.a - self.b
+    
+a = A(1, 2)
+assert a.add() == 3
+assert a.sub() == -1
+
+assert A.__base__ is object
+
+class B(A):
+    def __init__(self, a, b, c):
+        super().__init__(a, b)
+        self.c = c
+
+    def add(self):
+        return self.a + self.b + self.c
+
+    def sub(self):
+        return self.a - self.b - self.c
+
+assert B.__base__ is A    
+
+b = B(1, 2, 3)
+assert b.add() == 6
+assert b.sub() == -4
+
+class C(B):
+    def __init__(self, a, b, c, d):
+        super().__init__(a, b, c)
+        self.d = d
+
+    def add(self):
+        return self.a + self.b + self.c + self.d
+
+    def sub(self):
+        return self.a - self.b - self.c - self.d
+    
+assert C.__base__ is B
+
+c = C(1, 2, 3, 4)
+assert c.add() == 10
+assert c.sub() == -8
+
+class D(C):
+    def __init__(self, a, b, c, d, e):
+        super().__init__(a, b, c, d)
+        self.e = e
+
+    def add(self):
+        return super().add() + self.e
+
+    def sub(self):
+        return super().sub() - self.e
+    
+assert D.__base__ is C
+
+d = D(1, 2, 3, 4, 5)
+assert d.add() == 15
+assert d.sub() == -13
+
+assert isinstance(1, int)
+assert isinstance(1, object)
+assert isinstance(C, type)
+assert isinstance(C, object)
+assert isinstance(d, object)
+assert isinstance(d, C)
+assert isinstance(d, B)
+assert isinstance(d, A)
+assert isinstance(object, object)
+assert isinstance(type, object)
+
+assert isinstance(1, (float, int))
+assert isinstance(1, (float, object))
+assert not isinstance(1, (float, str))
+assert isinstance(object, (int, type, float))
+assert not isinstance(object, (int, float, str))
+
+try:
+    isinstance(1, (1, 2))
+    exit(1)
+except TypeError:
+    pass
+
+try:
+    isinstance(1, 1)
+    exit(1)
+except TypeError:
+    pass
+
+class A:
+    a = 1
+    b = 2
+
+assert A.a == 1
+assert A.b == 2
+
+class B(A):
+    b = 3
+    c = 4
+
+# assert B.a == 1  ...bug here
+assert B.b == 3
+assert B.c == 4
+
+import c
+
+class A(c.void_p):
+    pass
+    
+a = A()
+assert repr(a).startswith('<void* at')