Przeglądaj źródła

fix builtin modules and `super()`

blueloveTH 11 miesięcy temu
rodzic
commit
9ff3417621

+ 0 - 1
include/pocketpy/xmacros/magics.h

@@ -8,7 +8,6 @@ MAGIC_METHOD(__ge__)
 /////////////////////////////
 MAGIC_METHOD(__neg__)
 MAGIC_METHOD(__abs__)
-MAGIC_METHOD(__float__)
 MAGIC_METHOD(__int__)
 MAGIC_METHOD(__round__)
 MAGIC_METHOD(__divmod__)

+ 9 - 4
src/public/internal.c

@@ -171,9 +171,12 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
 
     if(name == __new__) {
         // __new__ acts like a @staticmethod
-        if(py_istype(self, tp_type)) {
+        if(self->type == tp_type) {
             // T.__new__(...)
             type = py_totype(self);
+        } else if(self->type == tp_super) {
+            // super(T, obj).__new__(...)
+            type = *(py_Type*)py_touserdata(self);
         } else {
             // invalid usage of `__new__`
             return false;
@@ -187,12 +190,16 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
         return false;
     }
 
+    py_TValue self_bak;  // to avoid overlapping
     // handle super() proxy
     if(py_istype(self, tp_super)) {
         type = *(py_Type*)py_touserdata(self);
-        *self = *py_getslot(self, 0);
+        // BUG: here we modify `self` which refers to the stack directly
+        // If `pk_loadmethod` fails, `self` will be corrupted
+        self_bak = *py_getslot(self, 0);
     } else {
         type = self->type;
+        self_bak = *self;
     }
 
     py_Ref cls_var = py_tpfindname(type, name);
@@ -200,8 +207,6 @@ bool pk_loadmethod(py_StackRef self, py_Name name) {
         switch(cls_var->type) {
             case tp_function:
             case tp_nativefunc: {
-                py_TValue self_bak = *self;
-                // `out` may overlap with `self`. If we assign `out`, `self` may be corrupted.
                 self[0] = *cls_var;
                 self[1] = self_bak;
                 break;

+ 1 - 1
src/public/py_dict.c

@@ -283,7 +283,7 @@ static bool dict__init__(int argc, py_Ref argv) {
     for(int i = 0; i < length; i++) {
         py_Ref tuple = &p[i];
         if(!py_istuple(tuple) || py_tuple_len(tuple) != 2) {
-            return TypeError("dict.__init__() argument must be a list of tuple-2");
+            return ValueError("dict.__init__() argument must be a list of tuple-2");
         }
         py_Ref key = py_tuple_getitem(tuple, 0);
         py_Ref val = py_tuple_getitem(tuple, 1);

+ 1 - 1
src/public/py_number.c

@@ -388,7 +388,7 @@ static bool float__new__(int argc, py_Ref argv) {
             py_newfloat(py_retval(), float_out);
             return true;
         }
-        default: return pk_callmagic(__float__, 1, argv + 1);
+        default: return TypeError("float() argument must be a string or a real number");
     }
 }
 

+ 44 - 40
tests/77_builtin_func_1.py

@@ -7,7 +7,7 @@ class TestSuperBase():
         return self.base_attr
     
     def error(self):
-        raise Exception('未能拦截错误')
+        raise RuntimeError('未能拦截错误')
     
 
 class TestSuperChild1(TestSuperBase):
@@ -20,7 +20,7 @@ class TestSuperChild1(TestSuperBase):
     def error_handling(self):
         try:
             super(TestSuperChild1, self).error()
-        except:
+        except RuntimeError:
             pass
 
 class TestSuperChild2(TestSuperBase):
@@ -54,18 +54,22 @@ class TestSuperNoBaseMethod(TestSuperBase):
     def __init__(self):
         super(TestSuperNoBaseMethod, self).append(1)
 
+class TestSuperNoParent():
+    def method(self):
+        super(TestSuperNoParent, self).method()
+
 try:
-    t = TestSuperNoParent()
-    print('未能拦截错误')
+    t = TestSuperNoParent().method()
+    print('未能拦截错误2')
     exit(2)
-except:
+except AttributeError:
     pass
 
 try:
     t = TestSuperNoBaseMethod()
-    print('未能拦截错误')
+    print('未能拦截错误3')
     exit(3)
-except:
+except AttributeError:
     pass
 
 class B():
@@ -82,17 +86,17 @@ class D():
 try:
     c = C()
     c.method()
-    print('未能拦截错误')
+    print('未能拦截错误4')
     exit(4)
-except:
+except AttributeError:
     pass
 
 try:
     d = D()
     d.method()
-    print('未能拦截错误')
+    print('未能拦截错误5')
     exit(5)
-except:
+except TypeError:
     pass
 
 # test hash:
@@ -133,16 +137,16 @@ assert type(hash(a)) is int
 # 测试不可哈希对象
 try:
     hash({1:1})
-    print('未能拦截错误')
+    print('未能拦截错误6')
     exit(6)
-except:
+except TypeError:
     pass
 
 try:
     hash([1])
-    print('未能拦截错误')
+    print('未能拦截错误7')
     exit(7)
-except:
+except TypeError:
     pass
 
 # test chr
@@ -165,24 +169,24 @@ repr(A())
 
 try:
     range(1,2,3,4)
-    print('未能拦截错误, 在测试 range')
+    print('未能拦截错误8, 在测试 range')
     exit(8)
-except:
+except TypeError:
     pass
 
 # /************ int ************/
 try:
     int('asad')
-    print('未能拦截错误, 在测试 int')
+    print('未能拦截错误9, 在测试 int')
     exit(9)
-except:
+except ValueError:
     pass
 
 try:
     int(123, 16)
-    print('未能拦截错误, 在测试 int')
+    print('未能拦截错误10, 在测试 int')
     exit(10)
-except:
+except TypeError:
     pass
 
 assert type(10//11) is int
@@ -191,16 +195,16 @@ assert type(11%2) is int
 
 try:
     float('asad')
-    print('未能拦截错误, 在测试 float')
+    print('未能拦截错误11, 在测试 float')
     exit(11)
-except:
+except ValueError:
     pass
 
 try:
     float([])
-    print('未能拦截错误, 在测试 float')
+    print('未能拦截错误12, 在测试 float')
     exit(12)
-except:
+except TypeError:
     pass
 
 # /************ str ************/
@@ -212,10 +216,10 @@ assert type(12 * '12') is str
 assert type('25363546'.index('63')) is int
 try:
     '25363546'.index('err')
-    print('未能拦截错误, 在测试 str.index')
+    print('未能拦截错误13, 在测试 str.index')
     exit(13)
-except:
-    pass
+except ValueError as e:
+    assert str(e) == "substring not found"
 
 
 # 未完全测试准确性-----------------------------------------------
@@ -227,9 +231,9 @@ assert '25363546'.find('err') == -1
 # /************ list ************/
 try:
     list(1,2)
-    print('未能拦截错误, 在测试 list')
+    print('未能拦截错误14, 在测试 list')
     exit(14)
-except:
+except TypeError:
     pass
 
 # 未完全测试准确性----------------------------------------------
@@ -237,10 +241,10 @@ except:
 assert type([1,2,3,4,5].index(4)) is int
 try:
     [1,2,3,4,5].index(6)
-    print('未能拦截错误, 在测试 list.index')
+    print('未能拦截错误15, 在测试 list.index')
     exit(15)
-except:
-    pass
+except ValueError as e:
+    assert str(e) == "list.index(x): x not in list"
 
 
 
@@ -248,19 +252,19 @@ except:
 # test list.remove:
 try:
     [1,2,3,4,5].remove(6)
-    print('未能拦截错误, 在测试 list.remove')
+    print('未能拦截错误16, 在测试 list.remove')
     exit(16)
-except:
-    pass
+except ValueError as e:
+    assert str(e) == "list.remove(x): x not in list"
 
 
 # 未完全测试准确性----------------------------------------------
 # test list.pop:
 try:
     [1,2,3,4,5].pop(1,2,3,4)
-    print('未能拦截错误, 在测试 list.pop')
+    print('未能拦截错误17, 在测试 list.pop')
     exit(17)
-except:
+except TypeError:
     pass
 
 
@@ -274,9 +278,9 @@ assert type(12 * [12]) is list
 # test tuple:
 try:
     tuple(1,2)
-    print('未能拦截错误, 在测试 tuple')
+    print('未能拦截错误18, 在测试 tuple')
     exit(18)
-except:
+except TypeError:
     pass
 
 assert [1,2,2,3,3,3].count(3) == 3

+ 4 - 4
tests/77_builtin_func_2.py

@@ -70,21 +70,21 @@ try:
     dict([(1, 2, 3)])
     print('未能拦截错误, 在测试 dict')
     exit(1)
-except:
+except ValueError:
     pass
 
 try:
     dict([(1, 2)], 1)
     print('未能拦截错误, 在测试 dict')
     exit(1)
-except:
+except TypeError:
     pass
 
 try:
     hash(dict([(1,2)]))
     print('未能拦截错误, 在测试 dict.__hash__')
     exit(1)
-except:
+except TypeError:
     pass
 
 # test dict.__iter__
@@ -102,7 +102,7 @@ try:
     {1:2, 3:4}.get(1,1, 1)
     print('未能拦截错误, 在测试 dict.get')
     exit(1)
-except:
+except TypeError:
     pass
 
 # 未完全测试准确性-----------------------------------------------