blueloveTH 2 ani în urmă
părinte
comite
f59f0c373e
11 a modificat fișierele cu 122 adăugiri și 49 ștergeri
  1. 13 1
      docs/features/long.md
  2. 19 15
      python/_long.py
  3. 29 8
      src/ceval.h
  4. 4 0
      src/common.h
  5. 5 1
      src/compiler.h
  6. 14 0
      src/expr.h
  7. 21 16
      src/lexer.h
  8. 6 6
      src/pocketpy.h
  9. 3 0
      src/str.h
  10. 2 0
      src/vm.h
  11. 6 2
      tests/09_long.py

+ 13 - 1
docs/features/long.md

@@ -10,8 +10,20 @@ in 64 bit platforms, it is 62 bit.
 For arbitrary sized integers, we provide a builtin `long` type, just like python2's `long`.
 `long` is implemented via pure python in [_long.py](https://github.com/blueloveTH/pocketpy/blob/main/python/_long.py).
 
+### Create a long object
+
+You can use `L` suffix to create a `long` literal from a decimal literal.
+Also, you can use `long()` function to create a `long` object from a `int` object or a `str` object.
+
+```python
+a = 1000L
+b = long(1000)
+c = long('1000')
+assert a == b == c
+```
+
 ```python
-a = long(2)         # use long() to create a long explicitly
+a = 2L         # use `L` suffix to create a `long` object
 print(a ** 1000)
 # 10715086071862673209484250490600018105614048117055336074437503883703510511249361224931983788156958581275946729175531468251871452856923140435984577574698574803934567774824230985421074605062371141877954182153046474983581941267398767559165543946077062914571196477686542167660429831652624386837205668069376L
 ```

+ 19 - 15
python/_long.py

@@ -143,8 +143,9 @@ def ulong_repr(x: list) -> str:
     return s
 
 def ulong_fromstr(s: str):
-    res = [0]
-    base = [1]
+    if s[-1] == 'L':
+        s = s[:-1]
+    res, base = [0], [1]
     if s[0] == '-':
         sign = -1
         s = s[1:]
@@ -172,8 +173,8 @@ class long:
     def __add__(self, other):
         if type(other) is int:
             other = long(other)
-        else:
-            assert type(other) is long
+        elif type(other) is not long:
+            return NotImplemented
         if self.sign == other.sign:
             return long((ulong_add(self.digits, other.digits), self.sign))
         else:
@@ -191,21 +192,24 @@ class long:
     def __sub__(self, other):
         if type(other) is int:
             other = long(other)
-        else:
-            assert type(other) is long
+        elif type(other) is not long:
+            return NotImplemented
         if self.sign != other.sign:
             return long((ulong_add(self.digits, other.digits), self.sign))
+        cmp = ulong_cmp(self.digits, other.digits)
+        if cmp == 0:
+            return long(0)
+        if cmp > 0:
+            return long((ulong_sub(self.digits, other.digits), self.sign))
         else:
-            cmp = ulong_cmp(self.digits, other.digits)
-            if cmp == 0:
-                return long(0)
-            if cmp > 0:
-                return long((ulong_sub(self.digits, other.digits), self.sign))
-            else:
-                return long((ulong_sub(other.digits, self.digits), -other.sign))
+            return long((ulong_sub(other.digits, self.digits), -other.sign))
             
     def __rsub__(self, other):
-        return self.__sub__(other)
+        if type(other) is int:
+            other = long(other)
+        elif type(other) is not long:
+            return NotImplemented
+        return other.__sub__(self)
     
     def __mul__(self, other):
         if type(other) is int:
@@ -218,7 +222,7 @@ class long:
                 ulong_mul(self.digits, other.digits),
                 self.sign * other.sign
             ))
-        raise TypeError('unsupported operand type(s) for *')
+        return NotImplemented
     
     def __rmul__(self, other):
         return self.__mul__(other)

+ 29 - 8
src/ceval.h

@@ -325,46 +325,59 @@ __NEXT_STEP:;
         if(_ti->m##func){                               \
             TOP() = VAR(_ti->m##func(this, _0, _1));    \
         }else{                                          \
-            TOP() = call_method(_0, func, _1);          \
+            PyObject* self;                                         \
+            _2 = get_unbound_method(_0, func, &self, false);        \
+            if(_2 != nullptr) TOP() = call_method(self, _2, _1);    \
+            else TOP() = NotImplemented;                            \
         }
 
-    TARGET(BINARY_TRUEDIV)
-        if(is_tagged(SECOND())){
-            f64 lhs = num_to_float(SECOND());
-            f64 rhs = num_to_float(TOP());
-            POP();
-            TOP() = VAR(lhs / rhs);
-            DISPATCH();
+#define BINARY_OP_RSPECIAL(op, func)                                \
+        if(TOP() == NotImplemented){                                \
+            PyObject* self;                                         \
+            _2 = get_unbound_method(_1, func, &self, false);        \
+            if(_2 != nullptr) TOP() = call_method(self, _2, _0);    \
+            else BinaryOptError(op);                                \
         }
+
+    TARGET(BINARY_TRUEDIV)
         BINARY_OP_SPECIAL(__truediv__);
+        if(TOP() == NotImplemented) BinaryOptError("/");
         DISPATCH();
     TARGET(BINARY_POW)
         BINARY_OP_SPECIAL(__pow__);
+        if(TOP() == NotImplemented) BinaryOptError("**");
         DISPATCH();
     TARGET(BINARY_ADD)
         PREDICT_INT_OP(+);
         BINARY_OP_SPECIAL(__add__);
+        BINARY_OP_RSPECIAL("+", __radd__);
         DISPATCH()
     TARGET(BINARY_SUB)
         PREDICT_INT_OP(-);
         BINARY_OP_SPECIAL(__sub__);
+        BINARY_OP_RSPECIAL("-", __rsub__);
         DISPATCH()
     TARGET(BINARY_MUL)
         BINARY_OP_SPECIAL(__mul__);
+        BINARY_OP_RSPECIAL("*", __rmul__);
         DISPATCH()
     TARGET(BINARY_FLOORDIV)
         PREDICT_INT_OP(/);
         BINARY_OP_SPECIAL(__floordiv__);
+        if(TOP() == NotImplemented) BinaryOptError("//");
         DISPATCH()
     TARGET(BINARY_MOD)
         PREDICT_INT_OP(%);
         BINARY_OP_SPECIAL(__mod__);
+        if(TOP() == NotImplemented) BinaryOptError("%");
         DISPATCH()
     TARGET(COMPARE_LT)
         BINARY_OP_SPECIAL(__lt__);
+        if(TOP() == NotImplemented) BinaryOptError("<");
         DISPATCH()
     TARGET(COMPARE_LE)
         BINARY_OP_SPECIAL(__le__);
+        if(TOP() == NotImplemented) BinaryOptError("<=");
         DISPATCH()
     TARGET(COMPARE_EQ)
         _1 = POPX();
@@ -378,32 +391,40 @@ __NEXT_STEP:;
         DISPATCH()
     TARGET(COMPARE_GT)
         BINARY_OP_SPECIAL(__gt__);
+        if(TOP() == NotImplemented) BinaryOptError(">");
         DISPATCH()
     TARGET(COMPARE_GE)
         BINARY_OP_SPECIAL(__ge__);
+        if(TOP() == NotImplemented) BinaryOptError(">=");
         DISPATCH()
     TARGET(BITWISE_LSHIFT)
         PREDICT_INT_OP(<<);
         BINARY_OP_SPECIAL(__lshift__);
+        if(TOP() == NotImplemented) BinaryOptError("<<");
         DISPATCH()
     TARGET(BITWISE_RSHIFT)
         PREDICT_INT_OP(>>);
         BINARY_OP_SPECIAL(__rshift__);
+        if(TOP() == NotImplemented) BinaryOptError(">>");
         DISPATCH()
     TARGET(BITWISE_AND)
         PREDICT_INT_OP(&);
         BINARY_OP_SPECIAL(__and__);
+        if(TOP() == NotImplemented) BinaryOptError("&");
         DISPATCH()
     TARGET(BITWISE_OR)
         PREDICT_INT_OP(|);
         BINARY_OP_SPECIAL(__or__);
+        if(TOP() == NotImplemented) BinaryOptError("|");
         DISPATCH()
     TARGET(BITWISE_XOR)
         PREDICT_INT_OP(^);
         BINARY_OP_SPECIAL(__xor__);
+        if(TOP() == NotImplemented) BinaryOptError("^");
         DISPATCH()
     TARGET(BINARY_MATMUL)
         BINARY_OP_SPECIAL(__matmul__);
+        if(TOP() == NotImplemented) BinaryOptError("@");
         DISPATCH();
 
 #undef BINARY_OP_SPECIAL

+ 4 - 0
src/common.h

@@ -133,6 +133,10 @@ inline bool is_both_int(PyObject* a, PyObject* b) noexcept {
     return is_int(a) && is_int(b);
 }
 
+inline bool is_both_float(PyObject* a, PyObject* b) noexcept {
+	return is_float(a) && is_float(b);
+}
+
 // special singals, is_tagged() for them is true
 inline PyObject* const PY_NULL = (PyObject*)0b000011;		// tagged null
 inline PyObject* const PY_OP_CALL = (PyObject*)0b100011;

+ 5 - 1
src/compiler.h

@@ -121,6 +121,7 @@ class Compiler {
         rules[TK("@num")] =     { METHOD(exprLiteral),   NO_INFIX };
         rules[TK("@str")] =     { METHOD(exprLiteral),   NO_INFIX };
         rules[TK("@fstr")] =    { METHOD(exprFString),   NO_INFIX };
+        rules[TK("@long")] =    { METHOD(exprLong),      NO_INFIX };
 #undef METHOD
 #undef NO_INFIX
     }
@@ -194,11 +195,14 @@ class Compiler {
         return expr;
     }
 
-    
     void exprLiteral(){
         ctx()->s_expr.push(make_expr<LiteralExpr>(prev().value));
     }
 
+    void exprLong(){
+        ctx()->s_expr.push(make_expr<LongExpr>(prev().str()));
+    }
+
     void exprFString(){
         ctx()->s_expr.push(make_expr<FStringExpr>(std::get<Str>(prev().value)));
     }

+ 14 - 0
src/expr.h

@@ -259,6 +259,20 @@ struct Literal0Expr: Expr{
     bool is_json_object() const override { return true; }
 };
 
+struct LongExpr: Expr{
+    Str s;
+    LongExpr(const Str& s): s(s) {}
+    std::string str() const override { return s.str(); }
+
+    void emit(CodeEmitContext* ctx) override {
+        VM* vm = ctx->vm;
+        PyObject* long_type = vm->builtins->attr().try_get("long");
+        PK_ASSERT(long_type != nullptr);
+        PyObject* obj = vm->call(long_type, VAR(s));
+        ctx->emit(OP_LOAD_CONST, ctx->add_const(obj), line);
+    }
+};
+
 // @num, @str which needs to invoke OP_LOAD_CONST
 struct LiteralExpr: Expr{
     TokenValue value;

+ 21 - 16
src/lexer.h

@@ -11,7 +11,7 @@ typedef uint8_t TokenIndex;
 constexpr const char* kTokens[] = {
     "is not", "not in", "yield from",
     "@eof", "@eol", "@sof",
-    "@id", "@num", "@str", "@fstr",
+    "@id", "@num", "@str", "@fstr", "@long",
     "@indent", "@dedent",
     /*****************************************/
     "+", "+=", "-", "-=",   // (INPLACE_OP - 1) can get '=' removed
@@ -342,29 +342,34 @@ struct Lexer {
     }
 
     void eat_number() {
-        static const std::regex pattern("^(0x)?[0-9a-fA-F]+(\\.[0-9]+)?");
+        static const std::regex pattern("^(0x)?[0-9a-fA-F]+(\\.[0-9]+)?(L)?");
         std::smatch m;
 
         const char* i = token_start;
         while(*i != '\n' && *i != '\0') i++;
         std::string s = std::string(token_start, i);
 
+        bool ok = std::regex_search(s, m, pattern);
+        PK_ASSERT(ok);
+        // here is m.length()-1, since the first char was eaten by lex_token()
+        for(int j=0; j<m.length()-1; j++) eatchar();
+
+        if(m[3].matched){
+            add_token(TK("@long"));
+            return;
+        }
+
         try{
-            if (std::regex_search(s, m, pattern)) {
-                // here is m.length()-1, since the first char was eaten by lex_token()
-                for(int j=0; j<m.length()-1; j++) eatchar();
-
-                int base = 10;
-                size_t size;
-                if (m[1].matched) base = 16;
-                if (m[2].matched) {
-                    if(base == 16) SyntaxError("hex literal should not contain a dot");
-                    add_token(TK("@num"), Number::stof(m[0], &size));
-                } else {
-                    add_token(TK("@num"), Number::stoi(m[0], &size, base));
-                }
-                if (size != m.length()) FATAL_ERROR();
+            int base = 10;
+            size_t size;
+            if (m[1].matched) base = 16;
+            if (m[2].matched) {
+                if(base == 16) SyntaxError("hex literal should not contain a dot");
+                add_token(TK("@num"), Number::stof(m[0], &size));
+            } else {
+                add_token(TK("@num"), Number::stoi(m[0], &size, base));
             }
+            PK_ASSERT(size == m.length());
         }catch(std::exception& _){
             SyntaxError("invalid number literal");
         } 

+ 6 - 6
src/pocketpy.h

@@ -36,14 +36,14 @@ inline CodeObject_ VM::compile(Str source, Str filename, CompileMode mode, bool
 inline void init_builtins(VM* _vm) {
 #define BIND_NUM_ARITH_OPT(name, op)                                                                    \
     _vm->bind##name(_vm->tp_int, [](VM* vm, PyObject* lhs, PyObject* rhs) {                             \
-        if(is_int(rhs)){                                                                                \
-            return VAR(_CAST(i64, lhs) op _CAST(i64, rhs));                                             \
-        }else{                                                                                          \
-            return VAR(_CAST(i64, lhs) op vm->num_to_float(rhs));                                       \
-        }                                                                                               \
+        if(is_int(rhs)) return VAR(_CAST(i64, lhs) op _CAST(i64, rhs));                                 \
+        if(is_float(rhs)) return VAR(_CAST(i64, lhs) op _CAST(f64, rhs));                               \
+        return vm->NotImplemented;                                                                      \
     });                                                                                                 \
     _vm->bind##name(_vm->tp_float, [](VM* vm, PyObject* lhs, PyObject* rhs) {                           \
-        return VAR(_CAST(f64, lhs) op vm->num_to_float(rhs));                                           \
+        if(is_float(rhs)) return VAR(_CAST(f64, lhs) op _CAST(f64, rhs));                               \
+        if(is_int(rhs)) return VAR(_CAST(f64, lhs) op _CAST(i64, rhs));                                 \
+        return vm->NotImplemented;                                                                      \
     });
 
     BIND_NUM_ARITH_OPT(__add__, +)

+ 3 - 0
src/str.h

@@ -426,8 +426,11 @@ const StrName __ge__ = StrName::get("__ge__");
 const StrName __contains__ = StrName::get("__contains__");
 // binary operators
 const StrName __add__ = StrName::get("__add__");
+const StrName __radd__ = StrName::get("__radd__");
 const StrName __sub__ = StrName::get("__sub__");
+const StrName __rsub__ = StrName::get("__rsub__");
 const StrName __mul__ = StrName::get("__mul__");
+const StrName __rmul__ = StrName::get("__rmul__");
 const StrName __truediv__ = StrName::get("__truediv__");
 const StrName __floordiv__ = StrName::get("__floordiv__");
 const StrName __mod__ = StrName::get("__mod__");

+ 2 - 0
src/vm.h

@@ -512,6 +512,7 @@ public:
     void ValueError(const Str& msg){ _error("ValueError", msg); }
     void NameError(StrName name){ _error("NameError", fmt("name ", name.escape() + " is not defined")); }
     void KeyError(PyObject* obj){ _error("KeyError", OBJ_GET(Str, py_repr(obj))); }
+    void BinaryOptError(const char* op) { TypeError(fmt("unsupported operand type(s) for ", op)); }
 
     void AttributeError(PyObject* obj, StrName name){
         // OBJ_NAME calls getattr, which may lead to a infinite recursion
@@ -1148,6 +1149,7 @@ inline void VM::init_builtin_types(){
     builtins->attr().set("dict", _t(tp_dict));
     builtins->attr().set("property", _t(tp_property));
     builtins->attr().set("StopIteration", StopIteration);
+    builtins->attr().set("NotImplemented", NotImplemented);
     builtins->attr().set("slice", _t(tp_slice));
 
     post_init();

+ 6 - 2
tests/09_long.py

@@ -1,4 +1,4 @@
-assert long(123) == long('123') == 123
+assert long(123) == long('123') == 123L == 123
 
 a = long(2)
 assert a ** 0 == 1
@@ -9,4 +9,8 @@ assert a - 1 == 1
 assert a * 2 == 4
 assert a // 2 == 1
 
-assert -a == -2
+assert -a == -2
+
+assert 1 + a == 3L
+assert 1 - a == -1L
+assert 2 * a == 4L