Просмотр исходного кода

impl *args and **kwargs (partially)

blueloveTH 3 лет назад
Родитель
Сommit
553d02592f
3 измененных файлов с 79 добавлено и 20 удалено
  1. 41 13
      src/compiler.h
  2. 12 1
      src/obj.h
  3. 26 6
      src/vm.h

+ 41 - 13
src/compiler.h

@@ -298,7 +298,7 @@ public:
     }
 
     void exprLambda() {
-
+        throw SyntaxError(path, parser->previous, "lambda is not implemented yet");
     }
 
     void exprAssign() {
@@ -690,34 +690,62 @@ public:
             if(match(TK("pass"))) return;
             consume(TK("def"));
         }
+        _Func func;
         consume(TK("@id"));
-        const _Str& name = parser->previous.str();
+        func.name = parser->previous.str();
 
-        std::vector<_Str> argNames;
         if (match(TK("(")) && !match(TK(")"))) {
+            int state = 0;      // 0 for args, 1 for *args, 2 for k=v, 3 for **kwargs
             do {
+                if(state == 3){
+                    throw SyntaxError(path, parser->previous, "**kwargs should be the last argument");
+                }
+
                 matchNewLines();
+                if(match(TK("*"))){
+                    if(state < 1) state = 1;
+                    else throw SyntaxError(path, parser->previous, "*args should be placed before **kwargs");
+                }
+                else if(match(TK("**"))){
+                    state = 3;
+                }
+
                 consume(TK("@id"));
-                const _Str& argName = parser->previous.str();
-                if (std::find(argNames.begin(), argNames.end(), argName) != argNames.end()) {
-                    throw SyntaxError(path, parser->previous, "duplicate argument in function definition");
+                const _Str& name = parser->previous.str();
+                if(func.hasName(name)) throw SyntaxError(path, parser->previous, "duplicate argument name");
+
+                switch (state)
+                {
+                    case 0: func.args.push_back(name); break;
+                    case 1: func.starredArg = name; state+=1; break;
+                    case 2: consume(TK("=")); func.kwArgs[name] = consumeLiteral(); break;
+                    case 3: func.doubleStarredArg = name; break;
                 }
-                argNames.push_back(argName);
             } while (match(TK(",")));
             consume(TK(")"));
         }
 
-        _Code fnCode = std::make_shared<CodeObject>();
-        fnCode->co_name = name;
-        fnCode->co_filename = path;
-        this->codes.push(fnCode);
+        func.code = std::make_shared<CodeObject>();
+        func.code->co_name = func.name;
+        func.code->co_filename = path;
+        this->codes.push(func.code);
         compileBlockBody();
         this->codes.pop();
-        PyVar fn = vm->PyFunction(_Func{name, fnCode, argNames});
-        emitCode(OP_LOAD_CONST, getCode()->addConst(fn));
+        emitCode(OP_LOAD_CONST, getCode()->addConst(vm->PyFunction(func)));
         if(!isCompilingClass) emitCode(OP_STORE_FUNCTION);
     }
 
+    PyVar consumeLiteral(){
+        if(match(TK("@num"))) goto __LITERAL_EXIT;
+        if(match(TK("@str"))) goto __LITERAL_EXIT;
+        if(match(TK("True"))) goto __LITERAL_EXIT;
+        if(match(TK("False"))) goto __LITERAL_EXIT;
+        if(match(TK("None"))) goto __LITERAL_EXIT;
+        throw SyntaxError(path, parser->previous, "expect a literal");
+__LITERAL_EXIT:
+        return parser->previous.value;
+    }
+
     void compileTopLevelStatement() {
         if (match(TK("class"))) {
             compileClass();

+ 12 - 1
src/obj.h

@@ -27,7 +27,18 @@ typedef std::shared_ptr<CodeObject> _Code;
 struct _Func {
     _Str name;
     _Code code;
-    std::vector<_Str> argNames;
+    std::vector<_Str> args;
+    _Str starredArg;        // empty if no *arg
+    StlDict kwArgs;         // empty if no k=v
+    _Str doubleStarredArg;  // empty if no **kwargs
+
+    bool hasName(const _Str& val) const {
+        bool _0 = std::find(args.begin(), args.end(), val) != args.end();
+        bool _1 = starredArg == val;
+        bool _2 = kwArgs.find(val) != kwArgs.end();
+        bool _3 = doubleStarredArg == val;
+        return _0 || _1 || _2 || _3;
+    }
 };
 
 struct BoundedMethod {

+ 26 - 6
src/vm.h

@@ -111,13 +111,33 @@ public:
             return f(this, args);
         } else if(callable->isType(_tp_function)){
             _Func fn = PyFunction_AS_C(callable);
-            if(args.size() != fn.argNames.size()){
-                _error("TypeError", "expected " + std::to_string(fn.argNames.size()) + " arguments, but got " + std::to_string(args.size()));
-            }
             StlDict locals;
-            for(int i=0; i<fn.argNames.size(); i++){
-                locals[fn.argNames[i]] = args[i];
+            int i = 0;
+            for(const auto& name : fn.args){
+                if(i < args.size()) {
+                    locals[name] = args[i++];
+                }else{
+                    _error("TypeError", "missing positional argument '" + name + "'");
+                }
             }
+            // handle *args
+            if(!fn.starredArg.empty()){
+                PyVarList vargs;
+                while(i < args.size()) vargs.push_back(args[i++]);
+                locals[fn.starredArg] = PyTuple(vargs);
+            }
+            // handle keyword arguments
+            for(const auto& [name, value] : fn.kwArgs){
+                if(i < args.size()) {
+                    locals[name] = args[i++];
+                }else{
+                    locals[name] = value;
+                }
+            }
+
+            if(i < args.size()) _error("TypeError", "too many arguments");
+
+            // TODO: handle **kwargs
             return exec(fn.code, locals);
         }
         _error("TypeError", "'" + callable->getTypeName() + "' object is not callable");
@@ -132,7 +152,7 @@ public:
         callstack.push(frame);
         while(!frame->isEnd()){
             const ByteCode& byte = frame->readCode();
-            printf("%s (%d) stack_size: %d\n", OP_NAMES[byte.op], byte.arg, frame->stackSize());
+            //printf("%s (%d) stack_size: %d\n", OP_NAMES[byte.op], byte.arg, frame->stackSize());
 
             switch (byte.op)
             {