Преглед изворни кода

add `__float__` and `__int__` and `__round__`

blueloveTH пре 1 година
родитељ
комит
bf208c3733
4 измењених фајлова са 63 додато и 34 уклоњено
  1. 3 0
      include/pocketpy/xmacros/magics.h
  2. 15 15
      src/public/modules.c
  3. 20 19
      src/public/py_number.c
  4. 25 0
      tests/99_extras.py

+ 3 - 0
include/pocketpy/xmacros/magics.h

@@ -59,6 +59,9 @@ MAGIC_METHOD(__package__)
 MAGIC_METHOD(__path__)
 MAGIC_METHOD(__path__)
 MAGIC_METHOD(__class__)
 MAGIC_METHOD(__class__)
 MAGIC_METHOD(__abs__)
 MAGIC_METHOD(__abs__)
+MAGIC_METHOD(__float__)
+MAGIC_METHOD(__int__)
+MAGIC_METHOD(__round__)
 MAGIC_METHOD(__getattr__)
 MAGIC_METHOD(__getattr__)
 MAGIC_METHOD(__missing__)
 MAGIC_METHOD(__missing__)
 
 

+ 15 - 15
src/public/modules.c

@@ -287,22 +287,23 @@ static bool builtins_round(int argc, py_Ref argv) {
         return TypeError("round() takes 1 or 2 arguments");
         return TypeError("round() takes 1 or 2 arguments");
     }
     }
 
 
-    if(py_isint(py_arg(0))) {
+    if(argv->type == tp_int) {
         py_assign(py_retval(), py_arg(0));
         py_assign(py_retval(), py_arg(0));
         return true;
         return true;
-    }
-
-    PY_CHECK_ARG_TYPE(0, tp_float);
-    py_f64 x = py_tofloat(py_arg(0));
-    py_f64 offset = x >= 0 ? 0.5 : -0.5;
-    if(ndigits == -1) {
-        py_newint(py_retval(), (py_i64)(x + offset));
+    } else if(argv->type == tp_float) {
+        PY_CHECK_ARG_TYPE(0, tp_float);
+        py_f64 x = py_tofloat(py_arg(0));
+        py_f64 offset = x >= 0 ? 0.5 : -0.5;
+        if(ndigits == -1) {
+            py_newint(py_retval(), (py_i64)(x + offset));
+            return true;
+        }
+        py_f64 factor = pow(10, ndigits);
+        py_newfloat(py_retval(), (py_i64)(x * factor + offset) / factor);
         return true;
         return true;
     }
     }
 
 
-    py_f64 factor = pow(10, ndigits);
-    py_newfloat(py_retval(), (py_i64)(x * factor + offset) / factor);
-    return true;
+    return pk_callmagic(__round__, argc, argv);
 }
 }
 
 
 static bool builtins_print(int argc, py_Ref argv) {
 static bool builtins_print(int argc, py_Ref argv) {
@@ -442,12 +443,11 @@ static bool builtins_ord(int argc, py_Ref argv) {
     PY_CHECK_ARG_TYPE(0, tp_str);
     PY_CHECK_ARG_TYPE(0, tp_str);
     c11_sv sv = py_tosv(py_arg(0));
     c11_sv sv = py_tosv(py_arg(0));
     if(c11_sv__u8_length(sv) != 1) {
     if(c11_sv__u8_length(sv) != 1) {
-        return TypeError("ord() expected a character, but string of length %d found", c11_sv__u8_length(sv));
+        return TypeError("ord() expected a character, but string of length %d found",
+                         c11_sv__u8_length(sv));
     }
     }
     int u8bytes = c11__u8_header(sv.data[0], true);
     int u8bytes = c11__u8_header(sv.data[0], true);
-    if (u8bytes == 0) {
-        return ValueError("invalid char: %c", sv.data[0]);
-    }
+    if(u8bytes == 0) { return ValueError("invalid char: %c", sv.data[0]); }
     int value = c11__u8_value(u8bytes, sv.data);
     int value = c11__u8_value(u8bytes, sv.data);
     py_newint(py_retval(), value);
     py_newint(py_retval(), value);
     return true;
     return true;

+ 20 - 19
src/public/py_number.c

@@ -297,7 +297,7 @@ static bool int__new__(int argc, py_Ref argv) {
                 return true;
                 return true;
             }
             }
             case tp_str: break;  // leave to the next block
             case tp_str: break;  // leave to the next block
-            default: return TypeError("invalid arguments for int()");
+            default: return pk_callmagic(__int__, 1, argv+1);
         }
         }
     }
     }
     // 2+ args -> error
     // 2+ args -> error
@@ -350,26 +350,27 @@ static bool float__new__(int argc, py_Ref argv) {
             py_newfloat(py_retval(), py_tobool(&argv[1]));
             py_newfloat(py_retval(), py_tobool(&argv[1]));
             return true;
             return true;
         }
         }
-        case tp_str: break;  // leave to the next block
-        default: return TypeError("invalid arguments for float()");
-    }
-    // str to float
-    c11_sv sv = py_tosv(py_arg(1));
+        case tp_str: {
+            // str to float
+            c11_sv sv = py_tosv(py_arg(1));
 
 
-    if(c11__sveq2(sv, "inf")) {
-        py_newfloat(py_retval(), INFINITY);
-        return true;
-    }
-    if(c11__sveq2(sv, "-inf")) {
-        py_newfloat(py_retval(), -INFINITY);
-        return true;
-    }
+            if(c11__sveq2(sv, "inf")) {
+                py_newfloat(py_retval(), INFINITY);
+                return true;
+            }
+            if(c11__sveq2(sv, "-inf")) {
+                py_newfloat(py_retval(), -INFINITY);
+                return true;
+            }
 
 
-    char* p_end;
-    py_f64 float_out = strtod(sv.data, &p_end);
-    if(p_end != sv.data + sv.size) return ValueError("invalid literal for float(): %q", sv);
-    py_newfloat(py_retval(), float_out);
-    return true;
+            char* p_end;
+            py_f64 float_out = strtod(sv.data, &p_end);
+            if(p_end != sv.data + sv.size) return ValueError("invalid literal for float(): %q", sv);
+            py_newfloat(py_retval(), float_out);
+            return true;
+        }
+        default: return pk_callmagic(__float__, 1, argv+1);
+    }
 }
 }
 
 
 // tp_bool
 // tp_bool

+ 25 - 0
tests/99_extras.py

@@ -51,3 +51,28 @@ assert A()[1:2, :A()[3:4, ::-1]] == (slice(1, 2, None), slice(None, (slice(3, 4,
 # test right associative
 # test right associative
 assert 2**2**3 == 256
 assert 2**2**3 == 256
 assert (2**2**3)**2 == 65536
 assert (2**2**3)**2 == 65536
+
+class Number:
+    def __float__(self):
+        return 1.0
+    
+    def __int__(self):
+        return 2
+    
+    def __divmod__(self, other):
+        return 3, 4
+    
+    def __round__(self, *args):
+        return args
+
+assert divmod(Number(), 0) == (3, 4)
+assert float(Number()) == 1.0
+assert int(Number()) == 2
+
+assert round(Number()) == tuple()
+assert round(Number(), 1) == (1,)
+
+
+
+
+