blueloveTH 2 лет назад
Родитель
Сommit
da760c301c
6 измененных файлов с 110 добавлено и 51 удалено
  1. 48 12
      src/ceval.h
  2. 16 16
      src/expr.h
  3. 19 1
      src/opcodes.h
  4. 8 8
      src/pocketpy.h
  5. 19 8
      src/str.h
  6. 0 6
      src/vm.h

+ 48 - 12
src/ceval.h

@@ -180,18 +180,54 @@ __NEXT_STEP:;
         args[0] = frame->top();     // rhs
         frame->top() = fast_call(BINARY_SPECIAL_METHODS[byte.arg], std::move(args));
     } DISPATCH();
-    case OP_COMPARE_OP: {
-        Args args(2);
-        args[1] = frame->popx();    // lhs
-        args[0] = frame->top();     // rhs
-        frame->top() = fast_call(COMPARE_SPECIAL_METHODS[byte.arg], std::move(args));
-    } DISPATCH();
-    case OP_BITWISE_OP: {
-        Args args(2);
-        args[1] = frame->popx();    // lhs
-        args[0] = frame->top();     // rhs
-        frame->top() = fast_call(BITWISE_SPECIAL_METHODS[byte.arg], std::move(args));
-    } DISPATCH();
+
+#define INT_BINARY_OP(op, func) \
+        if(is_both_int(frame->top(), frame->top_1())){      \
+            i64 b = _CAST(i64, frame->top());               \
+            i64 a = _CAST(i64, frame->top_1());             \
+            frame->pop();                                   \
+            frame->top() = VAR(a op b);                     \
+        }else{                                              \
+            Args args(2);                                   \
+            args[1] = frame->popx();                        \
+            args[0] = frame->top();                         \
+            frame->top() = fast_call(func, std::move(args));\
+        }                                                   \
+        DISPATCH();
+
+    case OP_BINARY_ADD:
+        INT_BINARY_OP(+, __add__)
+    case OP_BINARY_SUB:
+        INT_BINARY_OP(-, __sub__)
+    case OP_BINARY_MUL:
+        INT_BINARY_OP(*, __mul__)
+    case OP_BINARY_FLOORDIV:
+        INT_BINARY_OP(/, __floordiv__)
+    case OP_BINARY_MOD:
+        INT_BINARY_OP(%, __mod__)
+    case OP_COMPARE_LT:
+        INT_BINARY_OP(<, __lt__)
+    case OP_COMPARE_LE:
+        INT_BINARY_OP(<=, __le__)
+    case OP_COMPARE_EQ:
+        INT_BINARY_OP(==, __eq__)
+    case OP_COMPARE_NE:
+        INT_BINARY_OP(!=, __ne__)
+    case OP_COMPARE_GT:
+        INT_BINARY_OP(>, __gt__)
+    case OP_COMPARE_GE:
+        INT_BINARY_OP(>=, __ge__)
+    case OP_BITWISE_LSHIFT:
+        INT_BINARY_OP(<<, __lshift__)
+    case OP_BITWISE_RSHIFT:
+        INT_BINARY_OP(>>, __rshift__)
+    case OP_BITWISE_AND:
+        INT_BINARY_OP(&, __and__)
+    case OP_BITWISE_OR:
+        INT_BINARY_OP(|, __or__)
+    case OP_BITWISE_XOR:
+        INT_BINARY_OP(^, __xor__)
+#undef INT_BINARY_OP
     case OP_IS_OP: {
         PyObject* rhs = frame->popx();
         PyObject* lhs = frame->top();

+ 16 - 16
src/expr.h

@@ -650,30 +650,30 @@ struct BinaryExpr: Expr{
         lhs->emit(ctx);
         rhs->emit(ctx);
         switch (op) {
-            case TK("+"):   ctx->emit(OP_BINARY_OP, 0, line);  break;
-            case TK("-"):   ctx->emit(OP_BINARY_OP, 1, line);  break;
-            case TK("*"):   ctx->emit(OP_BINARY_OP, 2, line);  break;
+            case TK("+"):   ctx->emit(OP_BINARY_ADD, BC_NOARG, line);  break;
+            case TK("-"):   ctx->emit(OP_BINARY_SUB, BC_NOARG, line);  break;
+            case TK("*"):   ctx->emit(OP_BINARY_MUL, BC_NOARG, line);  break;
             case TK("/"):   ctx->emit(OP_BINARY_OP, 3, line);  break;
-            case TK("//"):  ctx->emit(OP_BINARY_OP, 4, line);  break;
-            case TK("%"):   ctx->emit(OP_BINARY_OP, 5, line);  break;
+            case TK("//"):  ctx->emit(OP_BINARY_FLOORDIV, BC_NOARG, line);  break;
+            case TK("%"):   ctx->emit(OP_BINARY_MOD, BC_NOARG, line);  break;
             case TK("**"):  ctx->emit(OP_BINARY_OP, 6, line);  break;
 
-            case TK("<"):   ctx->emit(OP_COMPARE_OP, 0, line);    break;
-            case TK("<="):  ctx->emit(OP_COMPARE_OP, 1, line);    break;
-            case TK("=="):  ctx->emit(OP_COMPARE_OP, 2, line);    break;
-            case TK("!="):  ctx->emit(OP_COMPARE_OP, 3, line);    break;
-            case TK(">"):   ctx->emit(OP_COMPARE_OP, 4, line);    break;
-            case TK(">="):  ctx->emit(OP_COMPARE_OP, 5, line);    break;
+            case TK("<"):   ctx->emit(OP_COMPARE_LT, BC_NOARG, line);  break;
+            case TK("<="):  ctx->emit(OP_COMPARE_LE, BC_NOARG, line);  break;
+            case TK("=="):  ctx->emit(OP_COMPARE_EQ, BC_NOARG, line);  break;
+            case TK("!="):  ctx->emit(OP_COMPARE_NE, BC_NOARG, line);  break;
+            case TK(">"):   ctx->emit(OP_COMPARE_GT, BC_NOARG, line);  break;
+            case TK(">="):  ctx->emit(OP_COMPARE_GE, BC_NOARG, line);  break;
             case TK("in"):      ctx->emit(OP_CONTAINS_OP, 0, line);   break;
             case TK("not in"):  ctx->emit(OP_CONTAINS_OP, 1, line);   break;
             case TK("is"):      ctx->emit(OP_IS_OP, 0, line);         break;
             case TK("is not"):  ctx->emit(OP_IS_OP, 1, line);         break;
 
-            case TK("<<"):  ctx->emit(OP_BITWISE_OP, 0, line);    break;
-            case TK(">>"):  ctx->emit(OP_BITWISE_OP, 1, line);    break;
-            case TK("&"):   ctx->emit(OP_BITWISE_OP, 2, line);    break;
-            case TK("|"):   ctx->emit(OP_BITWISE_OP, 3, line);    break;
-            case TK("^"):   ctx->emit(OP_BITWISE_OP, 4, line);    break;
+            case TK("<<"):  ctx->emit(OP_BITWISE_LSHIFT, BC_NOARG, line);  break;
+            case TK(">>"):  ctx->emit(OP_BITWISE_RSHIFT, BC_NOARG, line);  break;
+            case TK("&"):   ctx->emit(OP_BITWISE_AND, BC_NOARG, line);  break;
+            case TK("|"):   ctx->emit(OP_BITWISE_OR, BC_NOARG, line);  break;
+            case TK("^"):   ctx->emit(OP_BITWISE_XOR, BC_NOARG, line);  break;
             default: UNREACHABLE();
         }
     }

+ 19 - 1
src/opcodes.h

@@ -41,8 +41,26 @@ OPCODE(BUILD_TUPLE)
 OPCODE(BUILD_STRING)
 /**************************/
 OPCODE(BINARY_OP)
+OPCODE(BINARY_ADD)
+OPCODE(BINARY_SUB)
+OPCODE(BINARY_MUL)
+OPCODE(BINARY_FLOORDIV)
+OPCODE(BINARY_MOD)
+
 OPCODE(COMPARE_OP)
-OPCODE(BITWISE_OP)
+OPCODE(COMPARE_LT)
+OPCODE(COMPARE_LE)
+OPCODE(COMPARE_EQ)
+OPCODE(COMPARE_NE)
+OPCODE(COMPARE_GT)
+OPCODE(COMPARE_GE)
+
+OPCODE(BITWISE_LSHIFT)
+OPCODE(BITWISE_RSHIFT)
+OPCODE(BITWISE_AND)
+OPCODE(BITWISE_OR)
+OPCODE(BITWISE_XOR)
+
 OPCODE(IS_OP)
 OPCODE(CONTAINS_OP)
 /**************************/

+ 8 - 8
src/pocketpy.h

@@ -25,23 +25,23 @@ inline CodeObject_ VM::compile(Str source, Str filename, CompileMode mode) {
 }
 
 #define BIND_NUM_ARITH_OPT(name, op)                                                                    \
-    _vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){                         \
+    _vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){                               \
         if(is_both_int(args[0], args[1])){                                                              \
-            return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1]));                     \
+            return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1]));                                     \
         }else{                                                                                          \
-            return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1]));                 \
+            return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1]));                         \
         }                                                                                               \
     });
 
 #define BIND_NUM_LOGICAL_OPT(name, op, is_eq)                                                           \
-    _vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){                         \
+    _vm->_bind_methods<1>({"int","float"}, #name, [](VM* vm, Args& args){                               \
+        if(is_both_int(args[0], args[1]))                                                               \
+            return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1]));                                     \
         if(!is_both_int_or_float(args[0], args[1])){                                                    \
-            if constexpr(is_eq) return VAR(args[0] op args[1]);                                  \
+            if constexpr(is_eq) return VAR(args[0] op args[1]);                                         \
             vm->TypeError("unsupported operand type(s) for " #op );                                     \
         }                                                                                               \
-        if(is_both_int(args[0], args[1]))                                                               \
-            return VAR(_CAST(i64, args[0]) op _CAST(i64, args[1]));                    \
-        return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1]));                      \
+        return VAR(vm->num_to_float(args[0]) op vm->num_to_float(args[1]));                             \
     });
     
 

+ 19 - 8
src/str.h

@@ -404,10 +404,13 @@ const StrName m_add = StrName::get("add");
 const StrName __enter__ = StrName::get("__enter__");
 const StrName __exit__ = StrName::get("__exit__");
 
-const StrName COMPARE_SPECIAL_METHODS[] = {
-    StrName::get("__lt__"), StrName::get("__le__"), StrName::get("__eq__"),
-    StrName::get("__ne__"), StrName::get("__gt__"), StrName::get("__ge__")
-};
+const StrName __add__ = StrName::get("__add__");
+const StrName __sub__ = StrName::get("__sub__");
+const StrName __mul__ = StrName::get("__mul__");
+// const StrName __truediv__ = StrName::get("__truediv__");
+const StrName __floordiv__ = StrName::get("__floordiv__");
+const StrName __mod__ = StrName::get("__mod__");
+// const StrName __pow__ = StrName::get("__pow__");
 
 const StrName BINARY_SPECIAL_METHODS[] = {
     StrName::get("__add__"), StrName::get("__sub__"), StrName::get("__mul__"),
@@ -415,9 +418,17 @@ const StrName BINARY_SPECIAL_METHODS[] = {
     StrName::get("__mod__"), StrName::get("__pow__")
 };
 
-const StrName BITWISE_SPECIAL_METHODS[] = {
-    StrName::get("__lshift__"), StrName::get("__rshift__"),
-    StrName::get("__and__"), StrName::get("__or__"), StrName::get("__xor__")
-};
+const StrName __lt__ = StrName::get("__lt__");
+const StrName __le__ = StrName::get("__le__");
+const StrName __eq__ = StrName::get("__eq__");
+const StrName __ne__ = StrName::get("__ne__");
+const StrName __gt__ = StrName::get("__gt__");
+const StrName __ge__ = StrName::get("__ge__");
+
+const StrName __lshift__ = StrName::get("__lshift__");
+const StrName __rshift__ = StrName::get("__rshift__");
+const StrName __and__ = StrName::get("__and__");
+const StrName __or__ = StrName::get("__or__");
+const StrName __xor__ = StrName::get("__xor__");
 
 } // namespace pkpy

+ 0 - 6
src/vm.h

@@ -592,12 +592,6 @@ inline Str VM::disassemble(CodeObject_ co){
             case OP_BINARY_OP:
                 argStr += fmt(" (", BINARY_SPECIAL_METHODS[byte.arg], ")");
                 break;
-            case OP_COMPARE_OP:
-                argStr += fmt(" (", COMPARE_SPECIAL_METHODS[byte.arg], ")");
-                break;
-            case OP_BITWISE_OP:
-                argStr += fmt(" (", BITWISE_SPECIAL_METHODS[byte.arg], ")");
-                break;
         }
         ss << pad(argStr, 40);      // may overflow
         ss << co->blocks[byte.block].type;