Ver Fonte

add `~` operator

BLUELOVETH há 2 anos atrás
pai
commit
d5fc2c8686

+ 2 - 0
include/linalg.pyi

@@ -96,6 +96,8 @@ class mat3x3:
     def transpose(self) -> mat3x3: ...
     def inverse(self) -> mat3x3: ...
 
+    def __invert__(self) -> mat3x3: ...
+
     @staticmethod
     def zeros() -> mat3x3: ...
     @staticmethod

+ 7 - 0
include/pocketpy/expr.h

@@ -78,6 +78,13 @@ struct NameExpr: Expr{
     bool emit_store(CodeEmitContext* ctx) override;
 };
 
+struct InvertExpr: Expr{
+    Expr_ child;
+    InvertExpr(Expr_&& child): child(std::move(child)) {}
+    std::string str() const override { return "Invert()"; }
+    void emit(CodeEmitContext* ctx) override;
+};
+
 struct StarredExpr: Expr{
     int level;
     Expr_ child;

+ 2 - 2
include/pocketpy/lexer.h

@@ -21,7 +21,7 @@ constexpr const char* kTokens[] = {
     /*****************************************/
     ".", ",", ":", ";", "#", "(", ")", "[", "]", "{", "}",
     "**", "=", ">", "<", "...", "->", "?", "@", "==", "!=", ">=", "<=",
-    "++", "--",
+    "++", "--", "~",
     /** SPEC_BEGIN **/
     "$goto", "$label",
     /** KW_BEGIN **/
@@ -92,7 +92,7 @@ enum Precedence {
   PREC_BITWISE_SHIFT, // << >>
   PREC_TERM,          // + -
   PREC_FACTOR,        // * / % // @
-  PREC_UNARY,         // - not
+  PREC_UNARY,         // - not ~
   PREC_EXPONENT,      // **
   PREC_CALL,          // ()
   PREC_SUBSCRIPT,     // []

+ 1 - 0
include/pocketpy/opcodes.h

@@ -99,6 +99,7 @@ OPCODE(SET_ADD)
 OPCODE(UNARY_NEGATIVE)
 OPCODE(UNARY_NOT)
 OPCODE(UNARY_STAR)
+OPCODE(UNARY_INVERT)
 /**************************/
 OPCODE(GET_ITER)
 OPCODE(FOR_ITER)

+ 1 - 0
include/pocketpy/str.h

@@ -281,6 +281,7 @@ const StrName __rshift__ = StrName::get("__rshift__");
 const StrName __and__ = StrName::get("__and__");
 const StrName __or__ = StrName::get("__or__");
 const StrName __xor__ = StrName::get("__xor__");
+const StrName __invert__ = StrName::get("__invert__");
 // indexer
 const StrName __getitem__ = StrName::get("__getitem__");
 const StrName __setitem__ = StrName::get("__setitem__");

+ 2 - 0
include/pocketpy/vm.h

@@ -65,6 +65,7 @@ struct PyTypeInfo{
     PyObject* (*m__json__)(VM* vm, PyObject*) = nullptr;
     PyObject* (*m__neg__)(VM* vm, PyObject*) = nullptr;
     PyObject* (*m__bool__)(VM* vm, PyObject*) = nullptr;
+    PyObject* (*m__invert__)(VM* vm, PyObject*) = nullptr;
 
     BinaryFuncC m__eq__ = nullptr;
     BinaryFuncC m__lt__ = nullptr;
@@ -221,6 +222,7 @@ public:
     BIND_UNARY_SPECIAL(__json__)
     BIND_UNARY_SPECIAL(__neg__)
     BIND_UNARY_SPECIAL(__bool__)
+    BIND_UNARY_SPECIAL(__invert__)
 
     void bind__hash__(Type type, i64 (*f)(VM* vm, PyObject*));
     void bind__len__(Type type, i64 (*f)(VM* vm, PyObject*));

+ 6 - 0
src/ceval.cpp

@@ -566,6 +566,12 @@ __NEXT_STEP:;
     TARGET(UNARY_STAR)
         TOP() = VAR(StarWrapper(byte.arg, TOP()));
         DISPATCH();
+    TARGET(UNARY_INVERT)
+        _ti = _inst_type_info(TOP());
+        if(_ti->m__invert__) _0 = _ti->m__invert__(this, TOP());
+        else _0 = call_method(TOP(), __invert__);
+        TOP() = _0;
+        DISPATCH();
     /*****************************************/
     TARGET(GET_ITER)
         TOP() = py_iter(TOP());

+ 4 - 0
src/compiler.cpp

@@ -52,6 +52,7 @@ namespace pkpy{
         rules[TK("+")] =        { nullptr,               METHOD(exprBinaryOp),       PREC_TERM };
         rules[TK("-")] =        { METHOD(exprUnaryOp),   METHOD(exprBinaryOp),       PREC_TERM };
         rules[TK("*")] =        { METHOD(exprUnaryOp),   METHOD(exprBinaryOp),       PREC_FACTOR };
+        rules[TK("~")] =        { METHOD(exprUnaryOp),   nullptr,                    PREC_UNARY };
         rules[TK("/")] =        { nullptr,               METHOD(exprBinaryOp),       PREC_FACTOR };
         rules[TK("//")] =       { nullptr,               METHOD(exprBinaryOp),       PREC_FACTOR };
         rules[TK("**")] =       { METHOD(exprUnaryOp),   METHOD(exprBinaryOp),       PREC_EXPONENT };
@@ -241,6 +242,9 @@ namespace pkpy{
             case TK("-"):
                 ctx()->s_expr.push(make_expr<NegatedExpr>(ctx()->s_expr.popx()));
                 break;
+            case TK("~"):
+                ctx()->s_expr.push(make_expr<InvertExpr>(ctx()->s_expr.popx()));
+                break;
             case TK("*"):
                 ctx()->s_expr.push(make_expr<StarredExpr>(1, ctx()->s_expr.popx()));
                 break;

+ 4 - 0
src/expr.cpp

@@ -139,6 +139,10 @@ namespace pkpy{
         return true;
     }
 
+    void InvertExpr::emit(CodeEmitContext* ctx) {
+        child->emit(ctx);
+        ctx->emit(OP_UNARY_INVERT, BC_NOARG, line);
+    }
 
     void StarredExpr::emit(CodeEmitContext* ctx) {
         child->emit(ctx);

+ 1 - 0
src/lexer.cpp

@@ -286,6 +286,7 @@ namespace pkpy{
             switch (c) {
                 case '\'': case '"': eat_string(c, NORMAL_STRING); return true;
                 case '#': skip_line_comment(); break;
+                case '~': add_token(TK("~")); return true;
                 case '{': add_token(TK("{")); return true;
                 case '}': add_token(TK("}")); return true;
                 case ',': add_token(TK(",")); return true;

+ 8 - 0
src/linalg.cpp

@@ -375,6 +375,14 @@ namespace pkpy{
             return VAR_T(PyMat3x3, ret);
         });
 
+        vm->bind__invert__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
+            PyMat3x3& self = _CAST(PyMat3x3&, obj);
+            Mat3x3 ret;
+            bool ok = self.inverse(ret);
+            if(!ok) vm->ValueError("matrix is not invertible");
+            return VAR_T(PyMat3x3, ret);
+        });
+
         vm->bind_func<0>(type, "zeros", [](VM* vm, ArgsView args){
             PK_UNUSED(args);
             return VAR_T(PyMat3x3, Mat3x3::zeros());

+ 2 - 0
src/pocketpy.cpp

@@ -343,6 +343,8 @@ void init_builtins(VM* _vm) {
 
     _vm->bind__hash__(_vm->tp_int, [](VM* vm, PyObject* obj) { return _CAST(i64, obj); });
 
+    _vm->bind__invert__(_vm->tp_int, [](VM* vm, PyObject* obj) { return VAR(~_CAST(i64, obj)); });
+
 #define INT_BITWISE_OP(name, op) \
     _vm->bind##name(_vm->tp_int, [](VM* vm, PyObject* lhs, PyObject* rhs) { \
         return VAR(_CAST(i64, lhs) op CAST(i64, rhs)); \

+ 5 - 1
tests/01_int.py

@@ -58,4 +58,8 @@ assert 7**21 == 558545864083284007
 assert 2**60 == 1152921504606846976
 assert -2**60 == -1152921504606846976
 assert 4**13 == 67108864
-assert (-4)**13 == -67108864
+assert (-4)**13 == -67108864
+
+assert ~3 == -4
+assert ~-3 == 2
+assert ~0 == -1