Просмотр исходного кода

optimize `int()` and `float()`

blueloveTH 1 год назад
Родитель
Сommit
d7545071e5
4 измененных файлов с 63 добавлено и 57 удалено
  1. 53 41
      src/pocketpy.cpp
  2. 4 0
      tests/01_int.py
  3. 6 0
      tests/02_float.py
  4. 0 16
      tests/99_builtin_func.py

+ 53 - 41
src/pocketpy.cpp

@@ -399,31 +399,37 @@ void init_builtins(VM* _vm) {
         if(args.size() == 1+0) return VAR(0);
         // 1 arg
         if(args.size() == 1+1){
-            if (is_type(args[1], vm->tp_float)) return VAR((i64)CAST(f64, args[1]));
-            if (is_type(args[1], vm->tp_int)) return args[1];
-            if (is_type(args[1], vm->tp_bool)) return VAR(_CAST(bool, args[1]) ? 1 : 0);
+            switch(vm->_tp(args[1]).index){
+                case VM::tp_float.index:
+                    return VAR((i64)_CAST(f64, args[1]));
+                case VM::tp_int.index:
+                    return args[1];
+                case VM::tp_bool.index:
+                    return VAR(args[1]==vm->True ? 1 : 0);
+                case VM::tp_str.index:
+                    break;
+                default:
+                    vm->TypeError("invalid arguments for int()");
+            }
         }
+        // 2+ args -> error
         if(args.size() > 1+2) vm->TypeError("int() takes at most 2 arguments");
-        // 2 args
-        if (is_type(args[1], vm->tp_str)) {
-            int base = 10;
-            if(args.size() == 1+2) base = CAST(i64, args[2]);
-            const Str& s = CAST(Str&, args[1]);
-            std::string_view sv = s.sv();
-            bool negative = false;
-            if(!sv.empty() && (sv[0] == '+' || sv[0] == '-')){
-                negative = sv[0] == '-';
-                sv.remove_prefix(1);
-            }
-            i64 val;
-            if(parse_int(sv, &val, base) != IntParsingResult::Success){
-                vm->ValueError(_S("invalid literal for int() with base ", base, ": ", s.escape()));
-            }
-            if(negative) val = -val;
-            return VAR(val);
+        // 1 or 2 args with str
+        int base = 10;
+        if(args.size() == 1+2) base = CAST(i64, args[2]);
+        const Str& s = CAST(Str&, args[1]);
+        std::string_view sv = s.sv();
+        bool negative = false;
+        if(!sv.empty() && (sv[0] == '+' || sv[0] == '-')){
+            negative = sv[0] == '-';
+            sv.remove_prefix(1);
         }
-        vm->TypeError("invalid arguments for int()");
-        return vm->None;
+        i64 val;
+        if(parse_int(sv, &val, base) != IntParsingResult::Success){
+            vm->ValueError(_S("invalid literal for int() with base ", base, ": ", s.escape()));
+        }
+        if(negative) val = -val;
+        return VAR(val);
     });
 
     _vm->bind__floordiv__(VM::tp_int, [](VM* vm, PyObject* _0, PyObject* _1) {
@@ -460,26 +466,32 @@ void init_builtins(VM* _vm) {
         if(args.size() == 1+0) return VAR(0.0);
         if(args.size() > 1+1) vm->TypeError("float() takes at most 1 argument");
         // 1 arg
-        if (is_type(args[1], vm->tp_int)) return VAR((f64)CAST(i64, args[1]));
-        if (is_type(args[1], vm->tp_float)) return args[1];
-        if (is_type(args[1], vm->tp_bool)) return VAR(_CAST(bool, args[1]) ? 1.0 : 0.0);
-        if (is_type(args[1], vm->tp_str)) {
-            const Str& s = CAST(Str&, args[1]);
-            if(s == "inf") return VAR(INFINITY);
-            if(s == "-inf") return VAR(-INFINITY);
-
-            double float_out;
-            char* p_end;
-            try{
-                float_out = std::strtod(s.data, &p_end);
-                PK_ASSERT(p_end == s.end());
-            }catch(...){
-                vm->ValueError("invalid literal for float(): " + s.escape());
-            }
-            return VAR(float_out);
+        switch(vm->_tp(args[1]).index){
+            case VM::tp_int.index:
+                return VAR((f64)CAST(i64, args[1]));
+            case VM::tp_float.index:
+                return args[1];
+            case VM::tp_bool.index:
+                return VAR(args[1]==vm->True ? 1.0 : 0.0);
+            case VM::tp_str.index:
+                break;
+            default:
+                vm->TypeError("invalid arguments for float()");
         }
-        vm->TypeError("invalid arguments for float()");
-        return vm->None;
+        // str to float
+        const Str& s = PK_OBJ_GET(Str, args[1]);
+        if(s == "inf") return VAR(INFINITY);
+        if(s == "-inf") return VAR(-INFINITY);
+
+        double float_out;
+        char* p_end;
+        try{
+            float_out = std::strtod(s.data, &p_end);
+            PK_ASSERT(p_end == s.end());
+        }catch(...){
+            vm->ValueError("invalid literal for float(): " + s.escape());
+        }
+        return VAR(float_out);
     });
 
     _vm->bind__hash__(VM::tp_float, [](VM* vm, PyObject* _0) {

+ 4 - 0
tests/01_int.py

@@ -53,6 +53,10 @@ assert str(1) == '1'
 assert repr(1) == '1'
 
 # test int()
+assert int() == 0
+assert int(True) == 1
+assert int(False) == 0
+
 assert int(1) == 1
 assert int(1.0) == 1
 assert int(1.1) == 1

+ 6 - 0
tests/02_float.py

@@ -36,6 +36,12 @@ assert str(1.0) == '1.0'
 assert repr(1.0) == '1.0'
 
 # test float()
+assert float() == 0.0
+assert float(True) == 1.0
+assert float(False) == 0.0
+assert float(1) == 1.0
+assert float(-2) == -2.0
+
 assert eq(float(1), 1.0)
 assert eq(float(1.0), 1.0)
 assert eq(float(1.1), 1.1)

+ 0 - 16
tests/99_builtin_func.py

@@ -160,22 +160,6 @@ class A():
 
 repr(A())
 
-
-# 未完全测试准确性-----------------------------------------------
-#     33600:  318:    _vm->bind_constructor<-1>("range", [](VM* vm, ArgsView args) {
-#     16742:  319:        args._begin += 1;   // skip cls
-#     16742:  320:        Range r;
-#     16742:  321:        switch (args.size()) {
-#      8735:  322:            case 1: r.stop = CAST(i64, args[0]); break;
-#      3867:  323:            case 2: r.start = CAST(i64, args[0]); r.stop = CAST(i64, args[1]); break;
-#      4140:  324:            case 3: r.start = CAST(i64, args[0]); r.stop = CAST(i64, args[1]); r.step = CAST(i64, args[2]); break;
-#     #####:  325:            default: vm->TypeError("expected 1-3 arguments, got " + std::to_string(args.size()));
-#     #####:  326:        }
-#     33484:  327:        return VAR(r);
-#     16742:  328:    });
-#         -:  329:
-# test range:
-
 try:
     range(1,2,3,4)
     print('未能拦截错误, 在测试 range')