blueloveTH 3 年 前
コミット
74ed211f7b
5 ファイル変更43 行追加42 行削除
  1. 3 3
      src/frame.h
  2. 1 1
      src/memory.h
  3. 23 16
      src/namedict.h
  4. 1 1
      src/ref.h
  5. 15 21
      src/vm.h

+ 3 - 3
src/frame.h

@@ -9,7 +9,7 @@ struct Frame {
     int _ip = -1;
     int _next_ip = 0;
 
-    const CodeObject_ co;
+    const CodeObject* co;
     PyVar _module;
     pkpy::shared_ptr<pkpy::NameDict> _locals;
     pkpy::shared_ptr<pkpy::NameDict> _closure;
@@ -24,9 +24,9 @@ struct Frame {
         return _closure->try_get(name);
     }
 
-    Frame(const CodeObject_ co, PyVar _module,
+    Frame(const CodeObject_& co, const PyVar& _module,
         pkpy::shared_ptr<pkpy::NameDict> _locals=nullptr, pkpy::shared_ptr<pkpy::NameDict> _closure=nullptr)
-        : co(co), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { }
+        : co(co.get()), _module(_module), _locals(_locals), _closure(_closure), id(kFrameGlobalId++) { }
 
     inline const Bytecode& next_bytecode() {
         _ip = _next_ip++;

+ 1 - 1
src/memory.h

@@ -80,7 +80,7 @@ namespace pkpy{
             return reinterpret_cast<__VAL>(counter);
         }
 
-        inline bool is_tagged() const {
+        inline constexpr bool is_tagged() const {
             if constexpr(!std::is_same_v<T, PyObject>) return false;
             return (reinterpret_cast<i64>(counter) & 0b11) != 0b00;
         }

+ 23 - 16
src/namedict.h

@@ -36,22 +36,22 @@ namespace pkpy{
 
         int size() const { return _size; }
 
-    //https://github.com/python/cpython/blob/main/Objects/dictobject.c#L175
-    #define HASH_PROBE(key, ok, i) \
-        int i = (key).index % _capacity; \
-        bool ok = false; \
-        while(!_a[i].empty()) { \
-            if(_a[i].first == (key)) { ok = true; break; } \
-            i = (5*i + 1) % _capacity; \
-        }
-
-    #define HASH_PROBE_OVERRIDE(key, ok, i) \
-        i = (key).index % _capacity; \
-        ok = false; \
-        while(!_a[i].empty()) { \
-            if(_a[i].first == (key)) { ok = true; break; } \
-            i = (5*i + 1) % _capacity; \
-        }
+//https://github.com/python/cpython/blob/main/Objects/dictobject.c#L175
+#define HASH_PROBE(key, ok, i) \
+    int i = (key).index % _capacity; \
+    bool ok = false; \
+    while(!_a[i].empty()) { \
+        if(_a[i].first == (key)) { ok = true; break; } \
+        i = (5*i + 1) % _capacity; \
+    }
+
+#define HASH_PROBE_OVERRIDE(key, ok, i) \
+    i = (key).index % _capacity; \
+    ok = false; \
+    while(!_a[i].empty()) { \
+        if(_a[i].first == (key)) { ok = true; break; } \
+        i = (5*i + 1) % _capacity; \
+    }
 
         const PyVar& operator[](StrName key) const {
             HASH_PROBE(key, ok, i);
@@ -95,6 +95,13 @@ namespace pkpy{
             return &_a[i].second;
         }
 
+        inline bool try_set(StrName key, PyVar&& value){
+            HASH_PROBE(key, ok, i);
+            if(!ok) return false;
+            _a[i].second = std::move(value);
+            return true;
+        }
+
         inline bool contains(StrName key) const {
             HASH_PROBE(key, ok, i);
             return ok;

+ 1 - 1
src/ref.h

@@ -20,7 +20,7 @@ struct NameRef : BaseRef {
     const std::pair<StrName, NameScope> pair;
     inline StrName name() const { return pair.first; }
     inline NameScope scope() const { return pair.second; }
-    NameRef(std::pair<StrName, NameScope>& pair) : pair(pair) {}
+    NameRef(const std::pair<StrName, NameScope>& pair) : pair(pair) {}
 
     PyVar get(VM* vm, Frame* frame) const;
     void set(VM* vm, Frame* frame, PyVar val) const;

+ 15 - 21
src/vm.h

@@ -151,30 +151,29 @@ public:
             return f(this, args);
         } else if(is_type(*callable, tp_function)){
             const pkpy::Function& fn = PyFunction_AS_C(*callable);
-            auto _locals = pkpy::make_shared<pkpy::NameDict>(
+            auto locals = pkpy::make_shared<pkpy::NameDict>(
                 fn.code->ideal_locals_capacity, kLocalsLoadFactor
             );
-            pkpy::NameDict& locals = *_locals;
 
             int i = 0;
             for(StrName name : fn.args){
                 if(i < args.size()){
-                    locals.emplace(name, args[i++]);
+                    locals->emplace(name, args[i++]);
                     continue;
                 }
                 TypeError("missing positional argument " + name.str().escape(true));
             }
 
-            locals.insert(fn.kwargs.begin(), fn.kwargs.end());
+            locals->insert(fn.kwargs.begin(), fn.kwargs.end());
 
             if(!fn.starred_arg.empty()){
                 pkpy::List vargs;        // handle *args
                 while(i < args.size()) vargs.push_back(args[i++]);
-                locals.emplace(fn.starred_arg, PyTuple(std::move(vargs)));
+                locals->emplace(fn.starred_arg, PyTuple(std::move(vargs)));
             }else{
                 for(StrName key : fn.kwargs_order){
                     if(i < args.size()){
-                        locals.emplace(key, args[i++]);
+                        locals->emplace(key, args[i++]);
                     }else{
                         break;
                     }
@@ -187,10 +186,10 @@ public:
                 if(!fn.kwargs.contains(key)){
                     TypeError(key.escape(true) + " is an invalid keyword argument for " + fn.name + "()");
                 }
-                locals.emplace(key, kwargs[i+1]);
+                locals->emplace(key, kwargs[i+1]);
             }
-            PyVar _module = fn._module != nullptr ? fn._module : top_frame()->_module;
-            auto _frame = _new_frame(fn.code, _module, _locals, fn._closure);
+            const PyVar& _module = fn._module != nullptr ? fn._module : top_frame()->_module;
+            auto _frame = _new_frame(fn.code, _module, locals, fn._closure);
             if(fn.code->is_generator){
                 return PyIter(pkpy::make_shared<BaseIter, Generator>(
                     this, std::move(_frame)));
@@ -742,13 +741,13 @@ public:
 PyVar NameRef::get(VM* vm, Frame* frame) const{
     PyVar* val;
     val = frame->f_locals().try_get(name());
-    if(val) return *val;
+    if(val != nullptr) return *val;
     val = frame->f_closure_try_get(name());
-    if(val) return *val;
+    if(val != nullptr) return *val;
     val = frame->f_globals().try_get(name());
-    if(val) return *val;
+    if(val != nullptr) return *val;
     val = vm->builtins->attr().try_get(name());
-    if(val) return *val;
+    if(val != nullptr) return *val;
     vm->NameError(name());
     return nullptr;
 }
@@ -757,14 +756,9 @@ void NameRef::set(VM* vm, Frame* frame, PyVar val) const{
     switch(scope()) {
         case NAME_LOCAL: frame->f_locals()[name()] = std::move(val); break;
         case NAME_GLOBAL:
-        {
-            PyVar* existing = frame->f_locals().try_get(name());
-            if(existing != nullptr){
-                *existing = std::move(val);
-            }else{
-                frame->f_globals()[name()] = std::move(val);
-            }
-        } break;
+            if(frame->f_locals().try_set(name(), std::move(val))) return;
+            frame->f_globals()[name()] = std::move(val);
+            break;
         default: UNREACHABLE();
     }
 }