소스 검색

reimpl `round` in cpp

blueloveTH 2 년 전
부모
커밋
7741d592da
3개의 변경된 파일37개의 추가작업 그리고 9개의 파일을 삭제
  1. 0 9
      python/builtins.py
  2. 22 0
      src/pocketpy.cpp
  3. 15 0
      tests/70_builtins.py

+ 0 - 9
python/builtins.py

@@ -4,15 +4,6 @@ def print(*args, sep=' ', end='\n'):
     s = sep.join([str(i) for i in args])
     _sys.stdout.write(s + end)
 
-def round(x, ndigits=0):
-    assert ndigits >= 0
-    if ndigits == 0:
-        return int(x + 0.5) if x >= 0 else int(x - 0.5)
-    if x >= 0:
-        return int(x * 10**ndigits + 0.5) / 10**ndigits
-    else:
-        return int(x * 10**ndigits - 0.5) / 10**ndigits
-
 def abs(x):
     return -x if x < 0 else x
 

+ 22 - 0
src/pocketpy.cpp

@@ -123,6 +123,28 @@ void init_builtins(VM* _vm) {
         return VAR(MappingProxy(mod));
     });
 
+    // def round(x, ndigits=0):
+    //     assert ndigits >= 0
+    //     if ndigits == 0:
+    //         return int(x + 0.5) if x >= 0 else int(x - 0.5)
+    //     if x >= 0:
+    //         return int(x * 10**ndigits + 0.5) / 10**ndigits
+    //     else:
+    //         return int(x * 10**ndigits - 0.5) / 10**ndigits
+    _vm->bind(_vm->builtins, "round(x, ndigits=0)", [](VM* vm, ArgsView args) {
+        f64 x = CAST(f64, args[0]);
+        int ndigits = CAST(int, args[1]);
+        if(ndigits == 0){
+            return x >= 0 ? VAR((i64)(x + 0.5)) : VAR((i64)(x - 0.5));
+        }
+        if(ndigits < 0) vm->ValueError("ndigits should be non-negative");
+        if(x >= 0){
+            return VAR((i64)(x * std::pow(10, ndigits) + 0.5) / std::pow(10, ndigits));
+        }else{
+            return VAR((i64)(x * std::pow(10, ndigits) - 0.5) / std::pow(10, ndigits));
+        }
+    });
+
     _vm->bind_builtin_func<3>("pow", [](VM* vm, ArgsView args) {
         i64 lhs = CAST(i64, args[0]);   // assume lhs>=0
         i64 rhs = CAST(i64, args[1]);   // assume rhs>=0

+ 15 - 0
tests/70_builtins.py

@@ -2,6 +2,21 @@ assert round(23.2) == 23
 assert round(23.8) == 24
 assert round(-23.2) == -23
 assert round(-23.8) == -24
+# round with precision
+assert round(23.2, 1) == 23.2
+assert round(23.8, 1) == 23.8
+assert round(-23.2, 1) == -23.2
+assert round(-23.8, 1) == -23.8
+assert round(3.14159, 4) == 3.1416
+assert round(3.14159, 3) == 3.142
+assert round(3.14159, 2) == 3.14
+assert round(3.14159, 1) == 3.1
+assert round(3.14159, 0) == 3
+assert round(-3.14159, 4) == -3.1416
+assert round(-3.14159, 3) == -3.142
+assert round(-3.14159, 2) == -3.14
+assert round(-3.14159, 1) == -3.1
+assert round(-3.14159, 0) == -3
 
 a = [1,2,3,-1]
 assert sorted(a) == [-1,1,2,3]