Parcourir la source

add closure impl

blueloveTH il y a 3 ans
Parent
commit
2285300ed3
10 fichiers modifiés avec 65 ajouts et 29 suppressions
  1. 6 1
      src/ceval.h
  2. 1 1
      src/common.h
  3. 5 4
      src/compiler.h
  4. 9 2
      src/frame.h
  5. 2 2
      src/memory.h
  6. 11 8
      src/obj.h
  7. 2 0
      src/opcodes.h
  8. 0 1
      src/str.h
  9. 10 10
      src/vm.h
  10. 19 0
      tests/_closure.py

+ 6 - 1
src/ceval.h

@@ -14,9 +14,14 @@ PyVar VM::run_frame(Frame* frame){
         case OP_LOAD_CONST: frame->push(frame->co->consts[byte.arg]); continue;
         case OP_LOAD_FUNCTION: {
             PyVar obj = frame->co->consts[byte.arg];
-            setattr(obj, __module__, frame->_module);
+            auto& f = PyFunction_AS_C(obj);
+            f->_module = frame->_module;
             frame->push(obj);
         } continue;
+        case OP_SETUP_CLOSURE: {
+            auto& f = PyFunction_AS_C(frame->top());
+            f->_closure = frame->_locals;
+        } continue;
         case OP_LOAD_NAME_REF: {
             frame->push(PyRef(NameRef(frame->co->names[byte.arg])));
         } continue;

+ 1 - 1
src/common.h

@@ -34,7 +34,7 @@
 #define UNREACHABLE() throw std::runtime_error( __FILE__ + std::string(":") + std::to_string(__LINE__) + " UNREACHABLE()!");
 #endif
 
-#define PK_VERSION "0.8.7"
+#define PK_VERSION "0.8.8"
 
 typedef int64_t i64;
 typedef double f64;

+ 5 - 4
src/compiler.h

@@ -976,7 +976,7 @@ __LISTCOMP:
 
             consume(TK("@id"));
             const Str& name = parser->prev.str();
-            if(func->hasName(name)) SyntaxError("duplicate argument name");
+            if(func->has_name(name)) SyntaxError("duplicate argument name");
 
             // eat type hints
             if(enable_type_hints && match(TK(":"))) consume(TK("@id"));
@@ -986,15 +986,15 @@ __LISTCOMP:
             switch (state)
             {
                 case 0: func->args.push_back(name); break;
-                case 1: func->starredArg = name; state+=1; break;
+                case 1: func->starred_arg = name; state+=1; break;
                 case 2: {
                     consume(TK("="));
                     PyVarOrNull value = read_literal();
                     if(value == nullptr){
                         SyntaxError(Str("expect a literal, not ") + TK_STR(parser->curr.type));
                     }
-                    func->kwArgs[name] = value;
-                    func->kwArgsOrder.push_back(name);
+                    func->kwargs[name] = value;
+                    func->kwargs_order.push_back(name);
                 } break;
                 case 3: SyntaxError("**kwargs is not supported yet"); break;
             }
@@ -1021,6 +1021,7 @@ __LISTCOMP:
         func->code->optimize(vm);
         this->codes.pop();
         emit(OP_LOAD_FUNCTION, co()->add_const(vm->PyFunction(func)));
+        if(name_scope() == NAME_LOCAL) emit(OP_SETUP_CLOSURE);
         if(!is_compiling_class) emit(OP_STORE_NAME, co()->add_name(func->name, name_scope()));
     }
 

+ 9 - 2
src/frame.h

@@ -12,14 +12,21 @@ struct Frame {
     const CodeObject_ co;
     PyVar _module;
     pkpy::shared_ptr<pkpy::NameDict> _locals;
+    pkpy::shared_ptr<pkpy::NameDict> _closure;
     const i64 id;
     std::stack<std::pair<int, std::vector<PyVar>>> s_try_block;
 
     inline pkpy::NameDict& f_locals() noexcept { return *_locals; }
     inline pkpy::NameDict& f_globals() noexcept { return _module->attr(); }
 
-    Frame(const CodeObject_ co, PyVar _module, pkpy::shared_ptr<pkpy::NameDict> _locals)
-        : co(co), _module(_module), _locals(_locals), id(kFrameGlobalId++) { }
+    inline PyVar* f_closure_try_get(const Str& name) noexcept {
+        if(_closure == nullptr) return nullptr;
+        return _closure->try_get(name);
+    }
+
+    Frame(const CodeObject_ co, PyVar _module,
+        pkpy::shared_ptr<pkpy::NameDict> _locals, pkpy::shared_ptr<pkpy::NameDict> _closure=nullptr)
+        : co(co), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { }
 
     inline const Bytecode& next_bytecode() {
         _ip = _next_ip++;

+ 2 - 2
src/memory.h

@@ -18,14 +18,14 @@ namespace pkpy{
 
     template <typename T>
     class shared_ptr {
-        int* counter = nullptr;
+        int* counter;
 
 #define _t() ((T*)(counter + 1))
 #define _inc_counter() if(counter) ++(*counter)
 #define _dec_counter() if(counter && --(*counter) == 0){ SpAllocator<T>::dealloc(counter); }
 
     public:
-        shared_ptr() {}
+        shared_ptr() : counter(nullptr) {}
         shared_ptr(int* counter) : counter(counter) {}
         shared_ptr(const shared_ptr& other) : counter(other.counter) {
             _inc_counter();

+ 11 - 8
src/obj.h

@@ -24,14 +24,18 @@ struct Function {
     Str name;
     CodeObject_ code;
     std::vector<Str> args;
-    Str starredArg;                // empty if no *arg
-    pkpy::NameDict kwArgs;          // empty if no k=v
-    std::vector<Str> kwArgsOrder;
+    Str starred_arg;                // empty if no *arg
+    pkpy::NameDict kwargs;          // empty if no k=v
+    std::vector<Str> kwargs_order;
 
-    bool hasName(const Str& val) const {
+    // runtime settings
+    PyVar _module;
+    pkpy::shared_ptr<pkpy::NameDict> _closure;
+
+    bool has_name(const Str& val) const {
         bool _0 = std::find(args.begin(), args.end(), val) != args.end();
-        bool _1 = starredArg == val;
-        bool _2 = kwArgs.find(val) != kwArgs.end();
+        bool _1 = starred_arg == val;
+        bool _2 = kwargs.find(val) != kwargs.end();
         return _0 || _1 || _2;
     }
 };
@@ -99,8 +103,7 @@ struct Py_ : PyObject {
     Py_(Type type, T&& val): PyObject(type, sizeof(Py_<T>)), _value(std::move(val)) { _init(); }
 
     inline void _init() noexcept {
-        if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type>
-        || std::is_same_v<T, pkpy::Function_> || std::is_same_v<T, pkpy::NativeFunc>) {
+        if constexpr (std::is_same_v<T, Dummy> || std::is_same_v<T, Type>) {
             _attr = new pkpy::NameDict();
         }else{
             _attr = nullptr;

+ 2 - 0
src/opcodes.h

@@ -76,4 +76,6 @@ OPCODE(FAST_INDEX_REF)       // a[x]
 OPCODE(INPLACE_BINARY_OP)
 OPCODE(INPLACE_BITWISE_OP)
 
+OPCODE(SETUP_CLOSURE)
+
 #endif

+ 0 - 1
src/str.h

@@ -153,7 +153,6 @@ const Str __new__ = Str("__new__");
 const Str __iter__ = Str("__iter__");
 const Str __str__ = Str("__str__");
 const Str __repr__ = Str("__repr__");
-const Str __module__ = Str("__module__");
 const Str __getitem__ = Str("__getitem__");
 const Str __setitem__ = Str("__setitem__");
 const Str __delitem__ = Str("__delitem__");

+ 10 - 10
src/vm.h

@@ -163,15 +163,15 @@ public:
                 TypeError("missing positional argument '" + name + "'");
             }
 
-            locals.insert(fn->kwArgs.begin(), fn->kwArgs.end());
+            locals.insert(fn->kwargs.begin(), fn->kwargs.end());
 
             std::vector<Str> positional_overrided_keys;
-            if(!fn->starredArg.empty()){
+            if(!fn->starred_arg.empty()){
                 pkpy::List vargs;        // handle *args
                 while(i < args.size()) vargs.push_back(args[i++]);
-                locals.emplace(fn->starredArg, PyTuple(std::move(vargs)));
+                locals.emplace(fn->starred_arg, PyTuple(std::move(vargs)));
             }else{
-                for(const auto& key : fn->kwArgsOrder){
+                for(const auto& key : fn->kwargs_order){
                     if(i < args.size()){
                         locals[key] = args[i++];
                         positional_overrided_keys.push_back(key);
@@ -184,7 +184,7 @@ public:
             
             for(int i=0; i<kwargs.size(); i+=2){
                 const Str& key = PyStr_AS_C(kwargs[i]);
-                if(!fn->kwArgs.contains(key)){
+                if(!fn->kwargs.contains(key)){
                     TypeError(key.escape(true) + " is an invalid keyword argument for " + fn->name + "()");
                 }
                 const PyVar& val = kwargs[i+1];
@@ -196,10 +196,8 @@ public:
                 }
                 locals[key] = val;
             }
-
-            PyVar* _m = (*callable)->attr().try_get(__module__);
-            PyVar _module = _m != nullptr ? *_m : top_frame()->_module;
-            auto _frame = _new_frame(fn->code, _module, _locals);
+            PyVar _module = fn->_module != nullptr ? fn->_module : top_frame()->_module;
+            auto _frame = _new_frame(fn->code, _module, _locals, fn->_closure);
             if(fn->code->is_generator){
                 return PyIter(pkpy::make_shared<BaseIter, Generator>(
                     this, std::move(_frame)));
@@ -208,7 +206,7 @@ public:
             if(opCall) return _py_op_call;
             return _exec();
         }
-        TypeError("'" + OBJ_NAME(_t(*callable)) + "' object is not callable");
+        TypeError(OBJ_NAME(_t(*callable)).escape(true) + " object is not callable");
         return None;
     }
 
@@ -716,6 +714,8 @@ PyVar NameRef::get(VM* vm, Frame* frame) const{
     PyVar* val;
     val = frame->f_locals().try_get(name());
     if(val) return *val;
+    val = frame->f_closure_try_get(name());
+    if(val) return *val;
     val = frame->f_globals().try_get(name());
     if(val) return *val;
     val = vm->builtins->attr().try_get(name());

+ 19 - 0
tests/_closure.py

@@ -0,0 +1,19 @@
+# only one level nested closure is implemented
+
+def f0(a, b):
+    def f1():
+        return a + b
+    return f1
+
+a = f0(1, 2)
+assert a() == 3
+
+
+def f0(a, b):
+    def f1():
+        a = 5   # use this first
+        return a + b
+    return f1
+
+a = f0(1, 2)
+assert a() == 7