blueloveTH 1 year ago
parent
commit
52b210b016

+ 5 - 0
include/pocketpy/common/strname.h

@@ -39,6 +39,7 @@ extern uint16_t __next__;
 extern uint16_t __neg__;
 // logical operators
 extern uint16_t __eq__;
+extern uint16_t __ne__;
 extern uint16_t __lt__;
 extern uint16_t __le__;
 extern uint16_t __gt__;
@@ -52,9 +53,13 @@ extern uint16_t __rsub__;
 extern uint16_t __mul__;
 extern uint16_t __rmul__;
 extern uint16_t __truediv__;
+extern uint16_t __rtruediv__;
 extern uint16_t __floordiv__;
+extern uint16_t __rfloordiv__;
 extern uint16_t __mod__;
+extern uint16_t __rmod__;
 extern uint16_t __pow__;
+extern uint16_t __rpow__;
 extern uint16_t __matmul__;
 extern uint16_t __lshift__;
 extern uint16_t __rshift__;

+ 4 - 0
include/pocketpy/pocketpy.h

@@ -186,10 +186,14 @@ py_Error* py_lasterror();
 void py_Error__print(py_Error*);
 
 /************* Operators *************/
+bool py_bool(const py_Ref);
 int py_eq(const py_Ref, const py_Ref);
 int py_le(const py_Ref, const py_Ref);
 bool py_hash(const py_Ref, int64_t* out);
 
+/// Compare two objects without using magic methods.
+bool py_isidentical(const py_Ref, const py_Ref);
+
 /// A stack operation that calls a function.
 /// It assumes `argc + kwargc` arguments are already pushed to the stack.
 /// The result will be set to `vm->last_retval`.

+ 1 - 11
include/pocketpy/xmacros/opcodes.h

@@ -53,18 +53,8 @@ OPCODE(BUILD_SLICE)
 OPCODE(BUILD_STRING)
 /**************************/
 OPCODE(BINARY_OP)
-
-OPCODE(COMPARE_LT)
-OPCODE(COMPARE_LE)
-OPCODE(COMPARE_EQ)
-OPCODE(COMPARE_NE)
-OPCODE(COMPARE_GT)
-OPCODE(COMPARE_GE)
-
 OPCODE(IS_OP)
-OPCODE(IS_NOT_OP)
-OPCODE(IN_OP)
-OPCODE(NOT_IN_OP)
+OPCODE(CONTAINS_OP)
 /**************************/
 OPCODE(JUMP_FORWARD)
 OPCODE(POP_JUMP_IF_FALSE)

+ 15 - 5
src/common/strname.c

@@ -29,6 +29,7 @@ void pk_StrName__initialize() {
     __neg__ = pk_StrName__map("__neg__");
     // logical operators
     __eq__ = pk_StrName__map("__eq__");
+    __ne__ = pk_StrName__map("__ne__");
     __lt__ = pk_StrName__map("__lt__");
     __le__ = pk_StrName__map("__le__");
     __gt__ = pk_StrName__map("__gt__");
@@ -42,9 +43,13 @@ void pk_StrName__initialize() {
     __mul__ = pk_StrName__map("__mul__");
     __rmul__ = pk_StrName__map("__rmul__");
     __truediv__ = pk_StrName__map("__truediv__");
+    __rtruediv__ = pk_StrName__map("__rtruediv__");
     __floordiv__ = pk_StrName__map("__floordiv__");
+    __rfloordiv__ = pk_StrName__map("__rfloordiv__");
     __mod__ = pk_StrName__map("__mod__");
+    __rmod__ = pk_StrName__map("__rmod__");
     __pow__ = pk_StrName__map("__pow__");
+    __rpow__ = pk_StrName__map("__rpow__");
     __matmul__ = pk_StrName__map("__matmul__");
     __lshift__ = pk_StrName__map("__lshift__");
     __rshift__ = pk_StrName__map("__rshift__");
@@ -71,15 +76,15 @@ void pk_StrName__initialize() {
     __class__ = pk_StrName__map("__class__");
     __missing__ = pk_StrName__map("__missing__");
 
+    // print all names
+    for(int i = 0; i < _interned.count; i++) {
+        printf("%d: %s\n", i+1, c11__getitem(char*, &_r_interned, i));
+    }
+
     pk_id_add = pk_StrName__map("add");
     pk_id_set = pk_StrName__map("set");
     pk_id_long = pk_StrName__map("long");
     pk_id_complex = pk_StrName__map("complex");
-
-    // // print all names
-    // for(int i = 0; i < _interned.count; i++) {
-    //     printf("%d: %s\n", i+1, c11__getitem(char*, &_r_interned, i));
-    // }
 }
 
 void pk_StrName__finalize() {
@@ -138,6 +143,7 @@ uint16_t __next__;
 uint16_t __neg__;
 // logical operators
 uint16_t __eq__;
+uint16_t __ne__;
 uint16_t __lt__;
 uint16_t __le__;
 uint16_t __gt__;
@@ -151,9 +157,13 @@ uint16_t __rsub__;
 uint16_t __mul__;
 uint16_t __rmul__;
 uint16_t __truediv__;
+uint16_t __rtruediv__;
 uint16_t __floordiv__;
+uint16_t __rfloordiv__;
 uint16_t __mod__;
+uint16_t __rmod__;
 uint16_t __pow__;
+uint16_t __rpow__;
 uint16_t __matmul__;
 uint16_t __lshift__;
 uint16_t __rshift__;

+ 46 - 30
src/compiler/compiler.c

@@ -849,19 +849,19 @@ static void BinaryExpr__dtor(Expr* self_) {
     vtdelete(self->rhs);
 }
 
-static Opcode cmp_token2op(TokenIndex token) {
+static Opcode cmp_token2name(TokenIndex token) {
     switch(token) {
-        case TK_LT: return OP_COMPARE_LT; break;
-        case TK_LE: return OP_COMPARE_LE; break;
-        case TK_EQ: return OP_COMPARE_EQ; break;
-        case TK_NE: return OP_COMPARE_NE; break;
-        case TK_GT: return OP_COMPARE_GT; break;
-        case TK_GE: return OP_COMPARE_GE; break;
-        default: return OP_NO_OP;  // 0
+        case TK_LT: return __lt__;
+        case TK_LE: return __le__;
+        case TK_EQ: return __eq__;
+        case TK_NE: return __ne__;
+        case TK_GT: return __gt__;
+        case TK_GE: return __ge__;
+        default: return 0;
     }
 }
 
-#define is_compare_expr(e) ((e)->vt->is_binary && cmp_token2op(((BinaryExpr*)(e))->op))
+#define is_compare_expr(e) ((e)->vt->is_binary && cmp_token2name(((BinaryExpr*)(e))->op))
 
 static void _emit_compare(BinaryExpr* self, Ctx* ctx, c11_vector* jmps) {
     if(is_compare_expr(self->lhs)) {
@@ -872,8 +872,7 @@ static void _emit_compare(BinaryExpr* self, Ctx* ctx, c11_vector* jmps) {
     vtemit_(self->rhs, ctx);                              // [a, b]
     Ctx__emit_(ctx, OP_DUP_TOP, BC_NOARG, self->line);    // [a, b, b]
     Ctx__emit_(ctx, OP_ROT_THREE, BC_NOARG, self->line);  // [b, a, b]
-    Opcode opcode = cmp_token2op(self->op);
-    Ctx__emit_(ctx, opcode, BC_NOARG, self->line);
+    Ctx__emit_(ctx, OP_BINARY_OP, cmp_token2name(self->op), self->line);
     // [b, RES]
     int index = Ctx__emit_(ctx, OP_JUMP_IF_FALSE_OR_POP, BC_NOARG, self->line);
     c11_vector__push(int, jmps, index);
@@ -883,7 +882,7 @@ static void BinaryExpr__emit_(Expr* self_, Ctx* ctx) {
     BinaryExpr* self = (BinaryExpr*)self_;
     c11_vector /*T=int*/ jmps;
     c11_vector__ctor(&jmps, sizeof(int));
-    if(cmp_token2op(self->op) && is_compare_expr(self->lhs)) {
+    if(cmp_token2name(self->op) && is_compare_expr(self->lhs)) {
         // (a < b) < c
         BinaryExpr* e = (BinaryExpr*)self->lhs;
         _emit_compare(e, ctx, &jmps);
@@ -906,22 +905,34 @@ static void BinaryExpr__emit_(Expr* self_, Ctx* ctx) {
         case TK_ADD: arg = __add__ | (__radd__ << 8); break;
         case TK_SUB: arg = __sub__ | (__rsub__ << 8); break;
         case TK_MUL: arg = __mul__ | (__rmul__ << 8); break;
-        case TK_DIV: arg = __truediv__; break;
-        case TK_FLOORDIV: arg = __floordiv__; break;
-        case TK_MOD: arg = __mod__; break;
-        case TK_POW: arg = __pow__; break;
-
-        case TK_LT: opcode = OP_COMPARE_LT; break;
-        case TK_LE: opcode = OP_COMPARE_LE; break;
-        case TK_EQ: opcode = OP_COMPARE_EQ; break;
-        case TK_NE: opcode = OP_COMPARE_NE; break;
-        case TK_GT: opcode = OP_COMPARE_GT; break;
-        case TK_GE: opcode = OP_COMPARE_GE; break;
-
-        case TK_IN: opcode = OP_IN_OP; break;
-        case TK_NOT_IN: opcode = OP_NOT_IN_OP; break;
-        case TK_IS: opcode = OP_IS_OP; break;
-        case TK_IS_NOT: opcode = OP_IS_NOT_OP; break;
+        case TK_DIV: arg = __truediv__ | (__rtruediv__ << 8); break;
+        case TK_FLOORDIV: arg = __floordiv__ | (__rfloordiv__ << 8); break;
+        case TK_MOD: arg = __mod__ | (__rmod__ << 8); break;
+        case TK_POW: arg = __pow__ | (__rpow__ << 8); break;
+
+        case TK_LT: arg = __lt__ | (__gt__ << 8); break;
+        case TK_LE: arg = __le__ | (__ge__ << 8); break;
+        case TK_EQ: arg = __eq__ | (__eq__ << 8); break;
+        case TK_NE: arg = __ne__ | (__ne__ << 8); break;
+        case TK_GT: arg = __gt__ | (__lt__ << 8); break;
+        case TK_GE: arg = __ge__ | (__le__ << 8); break;
+
+        case TK_IN:
+            opcode = OP_CONTAINS_OP;
+            arg = 0;
+            break;
+        case TK_NOT_IN:
+            opcode = OP_CONTAINS_OP;
+            arg = 1;
+            break;
+        case TK_IS:
+            opcode = OP_IS_OP;
+            arg = 0;
+            break;
+        case TK_IS_NOT:
+            opcode = OP_IS_OP;
+            arg = 1;
+            break;
 
         case TK_LSHIFT: arg = __lshift__; break;
         case TK_RSHIFT: arg = __rshift__; break;
@@ -1734,8 +1745,13 @@ static Error* exprBinaryOp(Compiler* self) {
     TokenIndex op = prev()->type;
     check(parse_expression(self, rules[op].precedence + 1, false));
     BinaryExpr* e = BinaryExpr__new(line, op, false);
-    e->rhs = Ctx__s_popx(ctx());
-    e->lhs = Ctx__s_popx(ctx());
+    if(op == TK_IN || op == TK_NOT_IN) {
+        e->lhs = Ctx__s_popx(ctx());
+        e->rhs = Ctx__s_popx(ctx());
+    } else {
+        e->rhs = Ctx__s_popx(ctx());
+        e->lhs = Ctx__s_popx(ctx());
+    }
     Ctx__s_push(ctx(), (Expr*)e);
     return NULL;
 }

+ 92 - 1
src/interpreter/ceval.c

@@ -545,10 +545,101 @@ pk_FrameResult pk_VM__run_top_frame(pk_VM* self) {
                         }
                     }
                 }
+                // eq/ne op never fails
+                if(op == __eq__ || op == __ne__) {
+                    POP();
+                    *TOP() = (op == __eq__) ? self->False : self->True;
+                    DISPATCH();
+                }
                 BinaryOptError(byte.arg);
                 goto __ERROR;
             }
-
+            case OP_IS_OP: {
+                bool res = py_isidentical(SECOND(), TOP());
+                POP();
+                if(byte.arg) res = !res;
+                *TOP() = res ? self->True : self->False;
+                DISPATCH();
+            }
+            case OP_CONTAINS_OP: {
+                // [b, a] -> b __contains__ a (a in b)
+                py_Ref magic = py_tpfindmagic(SECOND()->type, __contains__);
+                if(magic) {
+                    if(magic->type == tp_nativefunc) {
+                        bool ok = magic->_cfunc(2, SECOND(), SECOND());
+                        if(!ok) goto __ERROR;
+                        POP();
+                        *TOP() = self->last_retval;
+                    } else {
+                        INSERT_THIRD();     // [?, b, a]
+                        *THIRD() = *magic;  // [__contains__, a, b]
+                        vectorcall_opcall(2);
+                    }
+                    bool res = py_tobool(TOP());
+                    if(byte.arg) py_newbool(TOP(), !res);
+                    DISPATCH();
+                }
+                TypeError();
+                goto __ERROR;
+            }
+                /*****************************************/
+            case OP_JUMP_FORWARD: DISPATCH_JUMP((int16_t)byte.arg);
+            case OP_POP_JUMP_IF_FALSE: {
+                bool res = py_bool(TOP());
+                POP();
+                if(!res) DISPATCH_JUMP((int16_t)byte.arg);
+                DISPATCH();
+            }
+            case OP_POP_JUMP_IF_TRUE: {
+                bool res = py_bool(TOP());
+                POP();
+                if(res) DISPATCH_JUMP((int16_t)byte.arg);
+                DISPATCH();
+            }
+            case OP_JUMP_IF_TRUE_OR_POP:
+                if(py_bool(TOP())) {
+                    DISPATCH_JUMP((int16_t)byte.arg);
+                } else {
+                    POP();
+                    DISPATCH();
+                }
+            case OP_JUMP_IF_FALSE_OR_POP:
+                if(!py_bool(TOP())) {
+                    DISPATCH_JUMP((int16_t)byte.arg);
+                } else {
+                    POP();
+                    DISPATCH();
+                }
+            case OP_SHORTCUT_IF_FALSE_OR_POP:
+                if(!py_bool(TOP())) {    // [b, False]
+                    STACK_SHRINK(2);     // []
+                    PUSH(&self->False);  // [False]
+                    DISPATCH_JUMP((int16_t)byte.arg);
+                } else {
+                    POP();  // [b]
+                    DISPATCH();
+                }
+            case OP_LOOP_CONTINUE:
+                // just an alias of OP_JUMP_FORWARD
+                DISPATCH_JUMP((int16_t)byte.arg);
+            case OP_LOOP_BREAK: {
+                int target = Frame__ip(frame) + byte.arg;
+                Frame__prepare_jump_break(frame, &self->stack, target);
+                DISPATCH_JUMP((int16_t)byte.arg);
+            }
+            case OP_JUMP_ABSOLUTE_TOP: {
+                int target = py_toint(TOP());
+                POP();
+                DISPATCH_JUMP_ABSOLUTE(target);
+            }
+                // case OP_GOTO: {
+                //     StrName _name(byte.arg);
+                //     int target = c11_smallmap_n2i__get(&frame->co->labels, byte.arg, -1);
+                //     if(target < 0) RuntimeError(_S("label ", _name.escape(), " not found"));
+                //     frame->prepare_jump_break(&s_data, target);
+                //     DISPATCH_JUMP_ABSOLUTE(target)
+                // }
+                /*****************************************/
             case OP_RETURN_VALUE: {
                 self->last_retval = byte.arg == BC_NOARG ? POPX() : self->None;
                 pk_VM__pop_frame(self);

+ 15 - 12
src/interpreter/py_number.c

@@ -39,17 +39,17 @@
 //     return 0;
 // }
 
-#define DEF_NUM_BINARY_OP(name, op)                                                                \
+#define DEF_NUM_BINARY_OP(name, op, rint, rfloat)                                                  \
     static bool _py_int##name(int argc, py_Ref argv, py_Ref out) {                                 \
         py_checkargc(2);                                                                           \
         if(py_isint(&argv[1])) {                                                                   \
             int64_t lhs = py_toint(&argv[0]);                                                      \
             int64_t rhs = py_toint(&argv[1]);                                                      \
-            py_newint(out, lhs op rhs);                                                            \
+            rint(out, lhs op rhs);                                                                 \
         } else if(py_isfloat(&argv[1])) {                                                          \
             int64_t lhs = py_toint(&argv[0]);                                                      \
             double rhs = py_tofloat(&argv[1]);                                                     \
-            py_newfloat(out, lhs op rhs);                                                          \
+            rfloat(out, lhs op rhs);                                                               \
         } else {                                                                                   \
             py_newnotimplemented(out);                                                             \
         }                                                                                          \
@@ -60,22 +60,23 @@
         double lhs = py_tofloat(&argv[0]);                                                         \
         double rhs;                                                                                \
         if(py_castfloat(&argv[1], &rhs)) {                                                         \
-            py_newfloat(out, lhs op rhs);                                                          \
+            rfloat(out, lhs op rhs);                                                               \
         } else {                                                                                   \
             py_newnotimplemented(out);                                                             \
         }                                                                                          \
         return true;                                                                               \
     }
 
-DEF_NUM_BINARY_OP(__add__, +)
-DEF_NUM_BINARY_OP(__sub__, -)
-DEF_NUM_BINARY_OP(__mul__, *)
+DEF_NUM_BINARY_OP(__add__, +, py_newint, py_newfloat)
+DEF_NUM_BINARY_OP(__sub__, -, py_newint, py_newfloat)
+DEF_NUM_BINARY_OP(__mul__, *, py_newint, py_newfloat)
 
-DEF_NUM_BINARY_OP(__eq__, ==)
-DEF_NUM_BINARY_OP(__lt__, <)
-DEF_NUM_BINARY_OP(__le__, <=)
-DEF_NUM_BINARY_OP(__gt__, >)
-DEF_NUM_BINARY_OP(__ge__, >=)
+DEF_NUM_BINARY_OP(__eq__, ==, py_newbool, py_newbool)
+DEF_NUM_BINARY_OP(__ne__, ==, py_newbool, py_newbool)
+DEF_NUM_BINARY_OP(__lt__, <, py_newbool, py_newbool)
+DEF_NUM_BINARY_OP(__le__, <=, py_newbool, py_newbool)
+DEF_NUM_BINARY_OP(__gt__, >, py_newbool, py_newbool)
+DEF_NUM_BINARY_OP(__ge__, >=, py_newbool, py_newbool)
 
 #undef DEF_NUM_BINARY_OP
 
@@ -229,6 +230,8 @@ void pk_VM__init_builtins(pk_VM* self) {
 
     py_bindmagic(tp_int, __eq__, _py_int__eq__);
     py_bindmagic(tp_float, __eq__, _py_float__eq__);
+    py_bindmagic(tp_int, __ne__, _py_int__ne__);
+    py_bindmagic(tp_float, __ne__, _py_float__ne__);
     py_bindmagic(tp_int, __lt__, _py_int__lt__);
     py_bindmagic(tp_float, __lt__, _py_float__lt__);
     py_bindmagic(tp_int, __le__, _py_int__le__);

+ 9 - 0
src/public/py_ops.c

@@ -5,6 +5,15 @@ int py_eq(const py_Ref lhs, const py_Ref rhs) { return 0; }
 
 int py_le(const py_Ref lhs, const py_Ref rhs) { return 0; }
 
+bool py_isidentical(const py_Ref lhs, const py_Ref rhs){
+    if(lhs->is_ptr && rhs->is_ptr){
+        return lhs->_obj == rhs->_obj;
+    }
+    return false;
+}
+
+bool py_bool(const py_Ref val) { return 0; }
+
 bool py_hash(const py_Ref val, int64_t* out) { return 0; }
 
 bool py_getattr(const py_Ref self, py_Name name, py_Ref out) { return true; }

+ 3 - 6
src2/main.c

@@ -25,7 +25,7 @@ int main(int argc, char** argv) {
 #endif
 
     py_initialize();
-    const char* source = "[1/2, 'a']";
+    const char* source = "1 < 2";
 
     py_Ref r0 = py_reg(0);
     if(py_eval(source, r0)){
@@ -33,11 +33,8 @@ int main(int argc, char** argv) {
         py_Error__print(err);
     }else{
         // handle the result
-        py_Ref _0 = py_list__getitem(r0, 0);
-        py_Ref _1 = py_list__getitem(r0, 1);
-        float _L0 = py_tofloat(_0);
-        const char* _L1 = py_tostr(_1);
-        printf("%f, %s\n", _L0, _L1);
+        bool _L0 = py_tobool(r0);
+        printf("%d\n", _L0);
     }
 
     py_finalize();