blueloveTH 3 лет назад
Родитель
Сommit
e227910dbd
10 измененных файлов с 51 добавлено и 34 удалено
  1. 1 1
      src/__stl__.h
  2. 9 15
      src/builtins.h
  3. 4 2
      src/compiler.h
  4. 1 7
      src/hash_table8.hpp
  5. 2 2
      src/main.cpp
  6. 10 0
      src/pocketpy.h
  7. 1 5
      src/safestl.h
  8. 1 2
      src/str.h
  9. 11 0
      tests/_basic.py
  10. 11 0
      tests/_functions.py

+ 1 - 1
src/__stl__.h

@@ -4,6 +4,7 @@
 #pragma warning (disable:4267)
 #pragma warning (disable:4101)
 #define _CRT_NONSTDC_NO_DEPRECATE
+#define strdup _strdup
 #endif
 
 #include <sstream>
@@ -18,7 +19,6 @@
 #include <string_view>
 #include <queue>
 #include <iomanip>
-#include <map>
 
 #include <atomic>
 #include <iostream>

+ 9 - 15
src/builtins.h

@@ -1,33 +1,27 @@
 #pragma once
 
 const char* __BUILTINS_CODE = R"(
-def len(x):
-    return x.__len__()
-
 def print(*args, sep=' ', end='\n'):
     s = sep.join([str(i) for i in args])
     __sys_stdout_write(s + end)
 
-def round(x):
+def round(x, ndigits=0):
+    assert ndigits >= 0
+    if ndigits == 0:
+        return x >= 0 ? int(x + 0.5) : int(x - 0.5)
     if x >= 0:
-        return int(x + 0.5)
+        return int(x * 10**ndigits + 0.5) / 10**ndigits
     else:
-        return int(x - 0.5)
+        return int(x * 10**ndigits - 0.5) / 10**ndigits
 
 def abs(x):
-    if x < 0:
-        return -x
-    return x
+    return x < 0 ? -x : x
 
 def max(a, b):
-    if a > b:
-        return a
-    return b
+    return a > b ? a : b
 
 def min(a, b):
-    if a < b:
-        return a
-    return b
+    return a < b ? a : b
 
 def sum(iterable):
     res = 0

+ 4 - 2
src/compiler.h

@@ -883,8 +883,10 @@ __LISTCOMP:
             emitCode(OP_DELETE_REF);
             consumeEndStatement();
         } else if(match(TK("global"))){
-            consume(TK("@id"));
-            getCode()->co_global_names.push_back(parser->previous.str());
+            do {
+                consume(TK("@id"));
+                getCode()->co_global_names.push_back(parser->previous.str());
+            } while (match(TK(",")));
             consumeEndStatement();
         } else if(match(TK("pass"))){
             consumeEndStatement();

+ 1 - 7
src/hash_table8.hpp

@@ -25,11 +25,6 @@
 
 #pragma once
 
-// Modification:
-// 1. Add #define EMH_WYHASH_HASH 1
-// 2. Add static for wymix
-#define EMH_WYHASH_HASH 1
-
 #include <cstring>
 #include <string>
 #include <cstdlib>
@@ -1665,7 +1660,7 @@ one-way search strategy.
 
 #if EMH_WYHASH_HASH
     //#define WYHASH_CONDOM 1
-    inline static uint64_t wymix(uint64_t A, uint64_t B)
+    inline uint64_t wymix(uint64_t A, uint64_t B)
     {
 #if defined(__SIZEOF_INT128__)
         __uint128_t r = A; r *= B;
@@ -1791,4 +1786,3 @@ private:
     size_type _etail;
 };
 } // namespace emhash
-

+ 2 - 2
src/main.cpp

@@ -3,8 +3,8 @@
 
 #include "pocketpy.h"
 
-#define PK_DEBUG_TIME
-//#define PK_DEBUG_THREADED
+//#define PK_DEBUG_TIME
+#define PK_DEBUG_THREADED
 
 struct Timer{
     const char* title;

+ 10 - 0
src/pocketpy.h

@@ -85,6 +85,11 @@ void __initializeBuiltinFunctions(VM* _vm) {
         return vm->PyInt(vm->hash(args[0]));
     });
 
+    _vm->bindBuiltinFunc("len", [](VM* vm, const pkpy::ArgList& args) {
+        vm->__checkArgSize(args, 1);
+        return vm->call(args[0], __len__, pkpy::noArg());
+    });
+
     _vm->bindBuiltinFunc("chr", [](VM* vm, const pkpy::ArgList& args) {
         vm->__checkArgSize(args, 1);
         _Int i = vm->PyInt_AS_C(args[0]);
@@ -146,6 +151,11 @@ void __initializeBuiltinFunctions(VM* _vm) {
         return args[0]->_type;
     });
 
+    _vm->bindMethod("type", "__eq__", [](VM* vm, const pkpy::ArgList& args) {
+        vm->__checkArgSize(args, 2, true);
+        return vm->PyBool(args[0] == args[1]);
+    });
+
     _vm->bindMethod("range", "__new__", [](VM* vm, const pkpy::ArgList& args) {
         _Range r;
         switch (args.size()) {

+ 1 - 5
src/safestl.h

@@ -35,11 +35,7 @@ public:
     using std::vector<PyVar>::vector;
 };
 
-
-class PyVarDict: public emhash8::HashMap<_Str, PyVar> {
-    using emhash8::HashMap<_Str, PyVar>::HashMap;
-};
-
+typedef emhash8::HashMap<_Str, PyVar> PyVarDict;
 
 namespace pkpy {
     const uint8_t MAX_POOLING_N = 10;

+ 1 - 2
src/str.h

@@ -46,8 +46,7 @@ public:
 
     size_t hash() const{
         if(!hash_initialized){
-            //_hash = std::hash<std::string>()(*this);
-            _hash = emhash8::HashMap<int,int>::wyhashstr(data(), size());
+            _hash = std::hash<std::string>()(*this);
             hash_initialized = true;
         }
         return _hash;

+ 11 - 0
tests/_basic.py

@@ -106,3 +106,14 @@ assert [1, 2, 3] * 3 == [1, 2, 3, 1, 2, 3, 1, 2, 3]
 a = 5
 assert ((a > 3) ? 1 : 0) == 1
 assert ((a < 3) ? 1 : 0) == 0
+
+assert eq(round(3.1415926, 2), 3.14)
+assert eq(round(3.1415926, 3), 3.142)
+assert eq(round(3.1415926, 4), 3.1416)
+assert eq(round(-3.1415926, 2), -3.14)
+assert eq(round(-3.1415926, 3), -3.142)
+assert eq(round(-3.1415926, 4), -3.1416)
+assert round(23.2) == 23
+assert round(23.8) == 24
+assert round(-23.2) == -23
+assert round(-23.8) == -24

+ 11 - 0
tests/_functions.py

@@ -38,3 +38,14 @@ assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, e=1) == 58
 assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) == 217
 assert f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, d=1, e=2) == 213
 
+a = 1
+b = 2
+
+def f():
+    global a, b
+    a = 3
+    b = 4
+
+f()
+assert a == 3
+assert b == 4