Prechádzať zdrojové kódy

refactor defaultdict

blueloveTH 1 rok pred
rodič
commit
c77fef35a2

+ 1 - 2
docs/modules/collections.md

@@ -14,5 +14,4 @@ A double-ended queue.
 
 
 ### `collections.defaultdict`
 ### `collections.defaultdict`
 
 
-A `dict` wrapper that calls a factory function to supply missing values.
-It is not a subclass of `dict`.
+A dictionary that returns a default value when a key is not found.

Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 0 - 0
include/pocketpy/_generated.h


+ 1 - 0
include/pocketpy/str.h

@@ -224,6 +224,7 @@ const StrName __all__ = StrName::get("__all__");
 const StrName __package__ = StrName::get("__package__");
 const StrName __package__ = StrName::get("__package__");
 const StrName __path__ = StrName::get("__path__");
 const StrName __path__ = StrName::get("__path__");
 const StrName __class__ = StrName::get("__class__");
 const StrName __class__ = StrName::get("__class__");
+const StrName __missing__ = StrName::get("__missing__");
 
 
 const StrName pk_id_add = StrName::get("add");
 const StrName pk_id_add = StrName::get("add");
 const StrName pk_id_set = StrName::get("set");
 const StrName pk_id_set = StrName::get("set");

+ 0 - 1
include/pocketpy/vm.h

@@ -82,7 +82,6 @@ struct PyTypeInfo{
     void (*m__setattr__)(VM* vm, PyObject*, StrName, PyObject*) = nullptr;
     void (*m__setattr__)(VM* vm, PyObject*, StrName, PyObject*) = nullptr;
     PyObject* (*m__getattr__)(VM* vm, PyObject*, StrName) = nullptr;
     PyObject* (*m__getattr__)(VM* vm, PyObject*, StrName) = nullptr;
     bool (*m__delattr__)(VM* vm, PyObject*, StrName) = nullptr;
     bool (*m__delattr__)(VM* vm, PyObject*, StrName) = nullptr;
-
 };
 };
 
 
 typedef void(*PrintFunc)(const char*, int);
 typedef void(*PrintFunc)(const char*, int);

+ 10 - 54
python/collections.py

@@ -7,63 +7,19 @@ def Counter(iterable):
             a[x] = 1
             a[x] = 1
     return a
     return a
 
 
-class defaultdict:
-    def __init__(self, default_factory) -> None:
+class defaultdict(dict):
+    def __init__(self, default_factory, *args):
+        super().__init__(*args)
+        self._enable_instance_dict()
         self.default_factory = default_factory
         self.default_factory = default_factory
-        self._a = {}
 
 
-    def __getitem__(self, key):
-        if key not in self._a:
-            self._a[key] = self.default_factory()
-        return self._a[key]
-        
-    def __setitem__(self, key, value):
-        self._a[key] = value
-
-    def __delitem__(self, key):
-        del self._a[key]
+    def __missing__(self, key):
+        self[key] = self.default_factory()
+        return self[key]
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
-        return f"defaultdict({self.default_factory}, {self._a})"
-    
-    def __eq__(self, __o: object) -> bool:
-        if not isinstance(__o, defaultdict):
-            return False
-        if self.default_factory != __o.default_factory:
-            return False
-        return self._a == __o._a
-    
-    def __iter__(self):
-        return iter(self._a)
-
-    def __contains__(self, key):
-        return key in self._a
-    
-    def __len__(self):
-        return len(self._a)
-
-    def keys(self):
-        return self._a.keys()
-    
-    def values(self):
-        return self._a.values()
-    
-    def items(self):
-        return self._a.items()
-
-    def pop(self, *args):
-        return self._a.pop(*args)
-
-    def clear(self):
-        self._a.clear()
+        return f"defaultdict({self.default_factory}, {super().__repr__()})"
 
 
     def copy(self):
     def copy(self):
-        new_dd = defaultdict(self.default_factory)
-        new_dd._a = self._a.copy()
-        return new_dd
-    
-    def get(self, key, default):
-        return self._a.get(key, default)
-    
-    def update(self, other):
-        self._a.update(other)
+        return defaultdict(self.default_factory, self)
+

+ 30 - 14
src/pocketpy.cpp

@@ -1228,37 +1228,53 @@ void init_builtins(VM* _vm) {
 
 
     // tp_dict
     // tp_dict
     _vm->bind_constructor<-1>(_vm->_t(VM::tp_dict), [](VM* vm, ArgsView args){
     _vm->bind_constructor<-1>(_vm->_t(VM::tp_dict), [](VM* vm, ArgsView args){
-        return VAR(Dict(vm));
+        Type cls_t = PK_OBJ_GET(Type, args[0]);
+        return vm->heap.gcnew<Dict>(cls_t, vm);
     });
     });
 
 
     _vm->bind_method<-1>(VM::tp_dict, "__init__", [](VM* vm, ArgsView args){
     _vm->bind_method<-1>(VM::tp_dict, "__init__", [](VM* vm, ArgsView args){
         if(args.size() == 1+0) return vm->None;
         if(args.size() == 1+0) return vm->None;
         if(args.size() == 1+1){
         if(args.size() == 1+1){
             auto _lock = vm->heap.gc_scope_lock();
             auto _lock = vm->heap.gc_scope_lock();
-            Dict& self = _CAST(Dict&, args[0]);
-            List& list = CAST(List&, args[1]);
-            for(PyObject* item : list){
-                Tuple& t = CAST(Tuple&, item);
-                if(t.size() != 2){
-                    vm->ValueError("dict() takes an iterable of tuples (key, value)");
-                    return vm->None;
+            Dict& self = PK_OBJ_GET(Dict, args[0]);
+            if(is_non_tagged_type(args[1], vm->tp_dict)){
+                Dict& other = CAST(Dict&, args[1]);
+                self.update(other);
+                return vm->None;
+            }
+            if(is_non_tagged_type(args[1], vm->tp_list)){
+                List& list = PK_OBJ_GET(List, args[1]);
+                for(PyObject* item : list){
+                    Tuple& t = CAST(Tuple&, item);
+                    if(t.size() != 2){
+                        vm->ValueError("dict() takes an iterable of tuples (key, value)");
+                        return vm->None;
+                    }
+                    self.set(t[0], t[1]);
                 }
                 }
-                self.set(t[0], t[1]);
             }
             }
             return vm->None;
             return vm->None;
         }
         }
         vm->TypeError("dict() takes at most 1 argument");
         vm->TypeError("dict() takes at most 1 argument");
-        return vm->None;
+        PK_UNREACHABLE()
     });
     });
 
 
     _vm->bind__len__(VM::tp_dict, [](VM* vm, PyObject* _0) {
     _vm->bind__len__(VM::tp_dict, [](VM* vm, PyObject* _0) {
-        return (i64)_CAST(Dict&, _0).size();
+        return (i64)PK_OBJ_GET(Dict, _0).size();
     });
     });
 
 
     _vm->bind__getitem__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
     _vm->bind__getitem__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
-        Dict& self = _CAST(Dict&, _0);
+        Dict& self = PK_OBJ_GET(Dict, _0);
         PyObject* ret = self.try_get(_1);
         PyObject* ret = self.try_get(_1);
-        if(ret == nullptr) vm->KeyError(_1);
+        if(ret == nullptr){
+            // try __missing__
+            PyObject* self;
+            PyObject* f_missing = vm->get_unbound_method(_0, __missing__, &self, false);
+            if(f_missing != nullptr){
+                return vm->call_method(self, f_missing, _1);
+            }
+            vm->KeyError(_1);
+        }
         return ret;
         return ret;
     });
     });
 
 
@@ -1372,7 +1388,7 @@ void init_builtins(VM* _vm) {
 
 
     _vm->bind__eq__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
     _vm->bind__eq__(VM::tp_dict, [](VM* vm, PyObject* _0, PyObject* _1) {
         Dict& self = _CAST(Dict&, _0);
         Dict& self = _CAST(Dict&, _0);
-        if(!is_non_tagged_type(_1, vm->tp_dict)) return vm->NotImplemented;
+        if(!vm->isinstance(_1, vm->tp_dict)) return vm->NotImplemented;
         Dict& other = _CAST(Dict&, _1);
         Dict& other = _CAST(Dict&, _1);
         if(self.size() != other.size()) return vm->False;
         if(self.size() != other.size()) return vm->False;
         for(int i=0; i<self._capacity; i++){
         for(int i=0; i<self._capacity; i++){

+ 3 - 3
src/vm.cpp

@@ -714,7 +714,7 @@ void VM::init_builtin_types(){
     if(tp_exception != _new_type_object("Exception", 0, true)) exit(-3);
     if(tp_exception != _new_type_object("Exception", 0, true)) exit(-3);
     if(tp_bytes != _new_type_object("bytes")) exit(-3);
     if(tp_bytes != _new_type_object("bytes")) exit(-3);
     if(tp_mappingproxy != _new_type_object("mappingproxy")) exit(-3);
     if(tp_mappingproxy != _new_type_object("mappingproxy")) exit(-3);
-    if(tp_dict != _new_type_object("dict")) exit(-3);
+    if(tp_dict != _new_type_object("dict", 0, true)) exit(-3);  // dict can be subclassed
     if(tp_property != _new_type_object("property")) exit(-3);
     if(tp_property != _new_type_object("property")) exit(-3);
     if(tp_star_wrapper != _new_type_object("_star_wrapper")) exit(-3);
     if(tp_star_wrapper != _new_type_object("_star_wrapper")) exit(-3);
 
 
@@ -1301,7 +1301,7 @@ void VM::bind__hash__(Type type, i64 (*f)(VM*, PyObject*)){
     PyObject* obj = _t(type);
     PyObject* obj = _t(type);
     _all_types[type].m__hash__ = f;
     _all_types[type].m__hash__ = f;
     PyObject* nf = bind_method<0>(obj, "__hash__", [](VM* vm, ArgsView args){
     PyObject* nf = bind_method<0>(obj, "__hash__", [](VM* vm, ArgsView args){
-        i64 ret = lambda_get_userdata<i64(*)(VM*, PyObject*)>(args.begin())(vm, args[0]);
+        i64 ret = lambda_get_userdata<decltype(f)>(args.begin())(vm, args[0]);
         return VAR(ret);
         return VAR(ret);
     });
     });
     PK_OBJ_GET(NativeFunc, nf).set_userdata(f);
     PK_OBJ_GET(NativeFunc, nf).set_userdata(f);
@@ -1311,7 +1311,7 @@ void VM::bind__len__(Type type, i64 (*f)(VM*, PyObject*)){
     PyObject* obj = _t(type);
     PyObject* obj = _t(type);
     _all_types[type].m__len__ = f;
     _all_types[type].m__len__ = f;
     PyObject* nf = bind_method<0>(obj, "__len__", [](VM* vm, ArgsView args){
     PyObject* nf = bind_method<0>(obj, "__len__", [](VM* vm, ArgsView args){
-        i64 ret = lambda_get_userdata<i64(*)(VM*, PyObject*)>(args.begin())(vm, args[0]);
+        i64 ret = lambda_get_userdata<decltype(f)>(args.begin())(vm, args[0]);
         return VAR(ret);
         return VAR(ret);
     });
     });
     PK_OBJ_GET(NativeFunc, nf).set_userdata(f);
     PK_OBJ_GET(NativeFunc, nf).set_userdata(f);

+ 9 - 9
tests/70_collections.py

@@ -2,15 +2,15 @@ from collections import Counter, deque, defaultdict
 import random
 import random
 import pickle
 import pickle
 import gc
 import gc
-import builtins
-
-dd_dict_keys = sorted(defaultdict.__dict__.keys())
-d_dict_keys = sorted(dict.__dict__.keys())
-d_dict_keys.remove('__new__')
-if dd_dict_keys != d_dict_keys:
-    print("dd_dict_keys:", dd_dict_keys)
-    print("d_dict_keys:", d_dict_keys)
-    raise Exception("dd_dict_keys != d_dict_keys")
+
+# test defaultdict
+assert issubclass(defaultdict, dict)
+a = defaultdict(int)
+a['1'] += 1
+assert a == {'1': 1}
+a = defaultdict(list)
+a['1'].append(1)
+assert a == {'1': [1]}
 
 
 q = deque()
 q = deque()
 q.append(1)
 q.append(1)

Niektoré súbory nie sú zobrazené, pretože je v týchto rozdielových dátach zmenené mnoho súborov