blueloveTH 3 лет назад
Родитель
Сommit
3bb3b0d584
6 измененных файлов с 49 добавлено и 42 удалено
  1. 8 8
      src/ceval.h
  2. 18 17
      src/compiler.h
  3. 0 2
      src/obj.h
  4. 1 1
      src/pocketpy.h
  5. 13 13
      src/vm.h
  6. 9 1
      tests/_closure.py

+ 8 - 8
src/ceval.h

@@ -13,14 +13,14 @@ PyVar VM::run_frame(Frame* frame){
         case OP_NO_OP: continue;
         case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue;
         case OP_LOAD_FUNCTION: {
-            PyVar obj = frame->co->consts[byte.arg];
-            auto& f = PyFunction_AS_C(obj);
-            f->_module = frame->_module;
-            frame->push(obj);
+            const PyVar obj = frame->co->consts[byte.arg];
+            pkpy::Function f = PyFunction_AS_C(obj);  // copy
+            f._module = frame->_module;
+            frame->push(PyFunction(f));
         } continue;
         case OP_SETUP_CLOSURE: {
-            auto& f = PyFunction_AS_C(frame->top());
-            f->_closure = frame->_locals;
+            pkpy::Function& f = PyFunction_AS_C(frame->top());    // reference
+            f._closure = frame->_locals;
         } continue;
         case OP_LOAD_NAME_REF: {
             frame->push(PyRef(NameRef(frame->co->names[byte.arg])));
@@ -98,8 +98,8 @@ PyVar VM::run_frame(Frame* frame){
             while(true){
                 PyVar fn = frame->pop_value(this);
                 if(fn == None) break;
-                const pkpy::Function_& f = PyFunction_AS_C(fn);
-                setattr(cls, f->name, fn);
+                const pkpy::Function& f = PyFunction_AS_C(fn);
+                setattr(cls, f.name, fn);
             }
         } continue;
         case OP_RETURN_VALUE: return frame->pop_value(this);

+ 18 - 17
src/compiler.h

@@ -384,19 +384,20 @@ private:
     }
 
     void exprLambda() {
-        pkpy::Function_ func = pkpy::make_shared<pkpy::Function>();
-        func->name = "<lambda>";
+        pkpy::Function func;
+        func.name = "<lambda>";
         if(!match(TK(":"))){
             _compile_f_args(func, false);
             consume(TK(":"));
         }
-        func->code = pkpy::make_shared<CodeObject>(parser->src, func->name);
-        this->codes.push(func->code);
+        func.code = pkpy::make_shared<CodeObject>(parser->src, func.name);
+        this->codes.push(func.code);
         EXPR_TUPLE();
         emit(OP_RETURN_VALUE);
-        func->code->optimize(vm);
+        func.code->optimize(vm);
         this->codes.pop();
         emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func)));
+        if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE);
     }
 
     void exprAssign() {
@@ -961,7 +962,7 @@ __LISTCOMP:
         emit(OP_BUILD_CLASS, cls_name_idx);
     }
 
-    void _compile_f_args(pkpy::Function_ func, bool enable_type_hints){
+    void _compile_f_args(pkpy::Function& func, bool enable_type_hints){
         int state = 0;      // 0 for args, 1 for *args, 2 for k=v, 3 for **kwargs
         do {
             if(state == 3) SyntaxError("**kwargs should be the last argument");
@@ -976,7 +977,7 @@ __LISTCOMP:
 
             consume(TK("@id"));
             const Str& name = parser->prev.str();
-            if(func->has_name(name)) SyntaxError("duplicate argument name");
+            if(func.has_name(name)) SyntaxError("duplicate argument name");
 
             // eat type hints
             if(enable_type_hints && match(TK(":"))) consume(TK("@id"));
@@ -985,16 +986,16 @@ __LISTCOMP:
 
             switch (state)
             {
-                case 0: func->args.push_back(name); break;
-                case 1: func->starred_arg = name; state+=1; break;
+                case 0: func.args.push_back(name); break;
+                case 1: func.starred_arg = name; state+=1; break;
                 case 2: {
                     consume(TK("="));
                     PyVarOrNull value = read_literal();
                     if(value == nullptr){
                         SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type));
                     }
-                    func->kwargs[name] = value;
-                    func->kwargs_order.push_back(name);
+                    func.kwargs[name] = value;
+                    func.kwargs_order.push_back(name);
                 } break;
                 case 3: SyntaxError("**kwargs is not supported yet"); break;
             }
@@ -1006,23 +1007,23 @@ __LISTCOMP:
             if(match(TK("pass"))) return;
             consume(TK("def"));
         }
-        pkpy::Function_ func = pkpy::make_shared<pkpy::Function>();
+        pkpy::Function func;
         consume(TK("@id"));
-        func->name = parser->prev.str();
+        func.name = parser->prev.str();
         consume(TK("("));
         if (!match(TK(")"))) {
             _compile_f_args(func, true);
             consume(TK(")"));
         }
         if(match(TK("->"))) consume(TK("@id")); // eat type hints
-        func->code = pkpy::make_shared<CodeObject>(parser->src, func->name);
-        this->codes.push(func->code);
+        func.code = pkpy::make_shared<CodeObject>(parser->src, func.name);
+        this->codes.push(func.code);
         compile_block_body();
-        func->code->optimize(vm);
+        func.code->optimize(vm);
         this->codes.pop();
         emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func)));
         if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE);
-        if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope()));
+        if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func.name, name_scope()));
     }
 
     PyVarOrNull read_literal(){

+ 0 - 2
src/obj.h

@@ -63,8 +63,6 @@ struct Slice {
         if(stop < start) stop = start;
     }
 };
-
-typedef shared_ptr<Function> Function_;
 }
 
 class BaseIter {

+ 1 - 1
src/pocketpy.h

@@ -591,7 +591,7 @@ void add_module_math(VM* vm){
 void add_module_dis(VM* vm){
     PyVar mod = vm->new_module("dis");
     vm->bind_func<1>(mod, "dis", [](VM* vm, pkpy::Args& args) {
-        CodeObject_ code = vm->PyFunction_AS_C(args[0])->code;
+        CodeObject_ code = vm->PyFunction_AS_C(args[0]).code;
         (*vm->_stdout) << vm->disassemble(code);
         return vm->None;
     });

+ 13 - 13
src/vm.h

@@ -150,12 +150,12 @@ public:
             if(kwargs.size() != 0) TypeError("native_function does not accept keyword arguments");
             return f(this, args);
         } else if((*callable)->is_type(tp_function)){
-            const pkpy::Function_& fn = PyFunction_AS_C((*callable));
+            const pkpy::Function& fn = PyFunction_AS_C(*callable);
             pkpy::shared_ptr<pkpy::NameDict> _locals = pkpy::make_shared<pkpy::NameDict>();
             pkpy::NameDict& locals = *_locals;
 
             int i = 0;
-            for(const auto& name : fn->args){
+            for(const auto& name : fn.args){
                 if(i < args.size()){
                     locals.emplace(name, args[i++]);
                     continue;
@@ -163,15 +163,15 @@ public:
                 TypeError("missing positional argument '" + name + "'");
             }
 
-            locals.insert(fn->kwargs.begin(), fn->kwargs.end());
+            locals.insert(fn.kwargs.begin(), fn.kwargs.end());
 
             std::vector<Str> positional_overrided_keys;
-            if(!fn->starred_arg.empty()){
+            if(!fn.starred_arg.empty()){
                 pkpy::List vargs;        // handle *args
                 while(i < args.size()) vargs.push_back(args[i++]);
-                locals.emplace(fn->starred_arg, PyTuple(std::move(vargs)));
+                locals.emplace(fn.starred_arg, PyTuple(std::move(vargs)));
             }else{
-                for(const auto& key : fn->kwargs_order){
+                for(const auto& key : fn.kwargs_order){
                     if(i < args.size()){
                         locals[key] = args[i++];
                         positional_overrided_keys.push_back(key);
@@ -184,8 +184,8 @@ public:
             
             for(int i=0; i<kwargs.size(); i+=2){
                 const Str& key = PyStr_AS_C(kwargs[i]);
-                if(!fn->kwargs.contains(key)){
-                    TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()");
+                if(!fn.kwargs.contains(key)){
+                    TypeError(key.escape(true) + " is an invalid keyword argument for " + fn.name + "()");
                 }
                 const PyVar& val = kwargs[i+1];
                 if(!positional_overrided_keys.empty()){
@@ -196,9 +196,9 @@ public:
                 }
                 locals[key] = val;
             }
-            PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module;
-            auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure);
-            if(fn->code->is_generator){
+            PyVar _module = fn._module != nullptr ? fn._module : top_frame()->_module;
+            auto _frame = _new_frame(fn.code, _module, _locals, fn._closure);
+            if(fn.code->is_generator){
                 return PyIter(pkpy::make_shared<BaseIter, Generator>(
                     this, std::move(_frame)));
             }
@@ -512,7 +512,7 @@ public:
             PyVar obj = co->consts[i];
             if(obj->is_type(tp_function)){
                 const auto& f = PyFunction_AS_C(obj);
-                ss << disassemble(f->code);
+                ss << disassemble(f.code);
             }
         }
         return Str(ss.str());
@@ -554,7 +554,7 @@ public:
     DEF_NATIVE(Float, f64, tp_float)
     DEF_NATIVE(List, pkpy::List, tp_list)
     DEF_NATIVE(Tuple, pkpy::Tuple, tp_tuple)
-    DEF_NATIVE(Function, pkpy::Function_, tp_function)
+    DEF_NATIVE(Function, pkpy::Function, tp_function)
     DEF_NATIVE(NativeFunc, pkpy::NativeFunc, tp_native_function)
     DEF_NATIVE(Iter, pkpy::shared_ptr<BaseIter>, tp_native_iterator)
     DEF_NATIVE(BoundMethod, pkpy::BoundMethod, tp_bound_method)

+ 9 - 1
tests/_closure.py

@@ -6,7 +6,9 @@ def f0(a, b):
     return f1
 
 a = f0(1, 2)
+b = f0(3, 4)
 assert a() == 3
+assert b() == 7
 
 
 def f0(a, b):
@@ -16,4 +18,10 @@ def f0(a, b):
     return f1
 
 a = f0(1, 2)
-assert a() == 7
+assert a() == 7
+
+def f3(x, y):
+    return lambda z: x + y + z
+
+a = f3(1, 2)
+assert a(3) == 6