فهرست منبع

add function unpack call

blueloveTH 3 سال پیش
والد
کامیت
5e15d526a0
6فایلهای تغییر یافته به همراه79 افزوده شده و 6 حذف شده
  1. 19 3
      src/ceval.h
  2. 10 2
      src/compiler.h
  3. 6 0
      src/obj.h
  4. 4 0
      src/opcodes.h
  5. 19 1
      src/vm.h
  6. 21 0
      tests/_star.py

+ 19 - 3
src/ceval.h

@@ -209,17 +209,33 @@ PyVar VM::run_frame(Frame* frame){
             frame->push(obj);
         } continue;
         case OP_DUP_TOP_VALUE: frame->push(frame->top_value(this)); continue;
-        case OP_CALL: {
+        case OP_UNARY_STAR: {
+            if(byte.arg > 0){   // rvalue
+                frame->top() = PyStarWrapper({frame->top_value(this), true});
+            }else{
+                PyRef_AS_C(frame->top()); // check ref
+                frame->top() = PyStarWrapper({frame->top(), false});
+            }
+        } continue;
+        case OP_CALL_KWARGS_UNPACK: case OP_CALL_KWARGS: {
             int ARGC = byte.arg & 0xFFFF;
             int KWARGC = (byte.arg >> 16) & 0xFFFF;
-            pkpy::Args kwargs(0);
-            if(KWARGC > 0) kwargs = frame->pop_n_values_reversed(this, KWARGC*2);
+            pkpy::Args kwargs = frame->pop_n_values_reversed(this, KWARGC*2);
             pkpy::Args args = frame->pop_n_values_reversed(this, ARGC);
+            if(byte.op == OP_CALL_KWARGS_UNPACK) unpack_args(args);
             PyVar callable = frame->pop_value(this);
             PyVar ret = call(callable, std::move(args), kwargs, true);
             if(ret == _py_op_call) return ret;
             frame->push(std::move(ret));
         } continue;
+        case OP_CALL_UNPACK: case OP_CALL: {
+            pkpy::Args args = frame->pop_n_values_reversed(this, byte.arg);
+            if(byte.op == OP_CALL_UNPACK) unpack_args(args);
+            PyVar callable = frame->pop_value(this);
+            PyVar ret = call(callable, std::move(args), pkpy::no_arg(), true);
+            if(ret == _py_op_call) return ret;
+            frame->push(std::move(ret));
+        } continue;
         case OP_JUMP_ABSOLUTE: frame->jump_abs(byte.arg); continue;
         case OP_SAFE_JUMP_ABSOLUTE: frame->jump_abs_safe(byte.arg); continue;
         case OP_GOTO: {

+ 10 - 2
src/compiler.h

@@ -500,7 +500,7 @@ private:
         switch (op) {
             case TK("-"):     emit(OP_UNARY_NEGATIVE); break;
             case TK("not"):   emit(OP_UNARY_NOT);      break;
-            case TK("*"):     SyntaxError("cannot use '*' as unary operator"); break;
+            case TK("*"):     emit(OP_UNARY_STAR, co()->_rvalue);   break;
             default: UNREACHABLE();
         }
     }
@@ -594,6 +594,7 @@ __LISTCOMP:
     void exprCall() {
         int ARGC = 0;
         int KWARGC = 0;
+        bool need_unpack = false;
         do {
             match_newlines(mode()==REPL_MODE);
             if (peek() == TK(")")) break;
@@ -607,12 +608,19 @@ __LISTCOMP:
             } else{
                 if(KWARGC > 0) SyntaxError("positional argument follows keyword argument");
                 co()->_rvalue += 1; EXPR(); co()->_rvalue -= 1;
+                if(co()->codes.back().op == OP_UNARY_STAR) need_unpack = true;
                 ARGC++;
             }
             match_newlines(mode()==REPL_MODE);
         } while (match(TK(",")));
         consume(TK(")"));
-        emit(OP_CALL, (KWARGC << 16) | ARGC);
+        if(ARGC > 32767) SyntaxError("too many positional arguments");
+        if(KWARGC > 32767) SyntaxError("too many keyword arguments");
+        if(KWARGC > 0){
+            emit(need_unpack ? OP_CALL_KWARGS_UNPACK : OP_CALL_KWARGS, (KWARGC << 16) | ARGC);
+        }else{
+            emit(need_unpack ? OP_CALL_UNPACK : OP_CALL, ARGC);
+        }
     }
 
     void exprName(){ _exprName(false); }

+ 6 - 0
src/obj.h

@@ -52,6 +52,12 @@ struct Range {
     i64 step = 1;
 };
 
+struct StarWrapper {
+    PyVar obj;
+    bool rvalue;
+    StarWrapper(const PyVar& obj, bool rvalue): obj(obj), rvalue(rvalue) {}
+};
+
 struct Slice {
     int start = 0;
     int stop = 0x7fffffff; 

+ 4 - 0
src/opcodes.h

@@ -4,6 +4,9 @@ OPCODE(NO_OP)
 OPCODE(POP_TOP)
 OPCODE(DUP_TOP_VALUE)
 OPCODE(CALL)
+OPCODE(CALL_UNPACK)
+OPCODE(CALL_KWARGS)
+OPCODE(CALL_KWARGS_UNPACK)
 OPCODE(RETURN_VALUE)
 
 OPCODE(BINARY_OP)
@@ -14,6 +17,7 @@ OPCODE(CONTAINS_OP)
 
 OPCODE(UNARY_NEGATIVE)
 OPCODE(UNARY_NOT)
+OPCODE(UNARY_STAR)
 
 OPCODE(BUILD_LIST)
 OPCODE(BUILD_MAP)

+ 19 - 1
src/vm.h

@@ -519,7 +519,7 @@ public:
     Type tp_list, tp_tuple;
     Type tp_function, tp_native_function, tp_native_iterator, tp_bound_method;
     Type tp_slice, tp_range, tp_module, tp_ref;
-    Type tp_super, tp_exception;
+    Type tp_super, tp_exception, tp_star_wrapper;
 
     template<typename P>
     inline PyVarRef PyRef(P&& value) {
@@ -592,6 +592,7 @@ public:
     DEF_NATIVE(Range, pkpy::Range, tp_range)
     DEF_NATIVE(Slice, pkpy::Slice, tp_slice)
     DEF_NATIVE(Exception, pkpy::Exception, tp_exception)
+    DEF_NATIVE(StarWrapper, pkpy::StarWrapper, tp_star_wrapper)
     
     // there is only one True/False, so no need to copy them!
     inline bool PyBool_AS_C(const PyVar& obj){return obj == True;}
@@ -619,6 +620,7 @@ public:
         tp_range = _new_type_object("range");
         tp_module = _new_type_object("module");
         tp_ref = _new_type_object("_ref");
+        tp_star_wrapper = _new_type_object("_star_wrapper");
         
         tp_function = _new_type_object("function");
         tp_native_function = _new_type_object("native_function");
@@ -743,6 +745,22 @@ public:
         return OBJ_GET(T, obj);
     }
 
+    void unpack_args(pkpy::Args& args){
+        pkpy::List unpacked;
+        for(int i=0; i<args.size(); i++){
+            if(is_type(args[i], tp_star_wrapper)){
+                auto& star = PyStarWrapper_AS_C(args[i]);
+                if(!star.rvalue) UNREACHABLE();
+                PyVar list = asList(star.obj);
+                pkpy::List& list_c = PyList_AS_C(list);
+                unpacked.insert(unpacked.end(), list_c.begin(), list_c.end());
+            }else{
+                unpacked.push_back(args[i]);
+            }
+        }
+        args = std::move(unpacked);
+    }
+
     ~VM() {
         if(!use_stdio){
             delete _stdout;

+ 21 - 0
tests/_star.py

@@ -0,0 +1,21 @@
+def f(a, b, *args):
+    return a + b + sum(args)
+
+assert f(1, 2, 3, 4) == 10
+
+a = [5, 6, 7, 8]
+assert f(*a) == 26
+
+def g(*args):
+    return f(*args)
+
+assert g(1, 2, 3, 4) == 10
+assert g(*a) == 26
+
+def f(a, b, *args, c=16):
+    return a + b + sum(args) + c
+
+assert f(1, 2, 3, 4) == 26
+assert f(1, 2, 3, 4, c=32) == 42
+
+assert f(*a, c=-26) == 0