blueloveTH vor 2 Jahren
Ursprung
Commit
651a51dc49
2 geänderte Dateien mit 127 neuen und 62 gelöschten Zeilen
  1. 94 40
      python/pickle.py
  2. 33 22
      tests/81_pickle.py

+ 94 - 40
python/pickle.py

@@ -1,6 +1,8 @@
 import json
 import builtins
 
+_BASIC_TYPES = [int, float, str, bool, type(None)]
+
 def _find_class(path: str):
     if "." not in path:
         g = globals()
@@ -16,47 +18,92 @@ def _find__new__(cls):
         if "__new__" in d:
             return d["__new__"]
         cls = cls.__base__
-    raise PickleError(f"cannot find __new__ for {cls.__name__}")
+    assert False
 
-def _wrap(o):
-    if type(o) in (int, float, str, bool, type(None)):
-        return o
-    if type(o) is list:
-        return ["list", [_wrap(i) for i in o]]
-    if type(o) is tuple:
-        return ["tuple", [_wrap(i) for i in o]]
-    if type(o) is dict:
-        return ["dict", [[_wrap(k), _wrap(v)] for k,v in o.items()]]
-    if type(o) is bytes:
-        return ["bytes", [o[j] for j in range(len(o))]]
-    
-    _0 = o.__class__.__name__
-    if hasattr(o, "__getnewargs__"):
-        _1 = o.__getnewargs__()     # an iterable
-        _1 = [_wrap(i) for i in _1]
-    else:
-        _1 = None
-    if hasattr(o, "__getstate__"):
-        _2 = o.__getstate__()
-    else:
-        if o.__dict__ is None:
-            _2 = None
+class _Pickler:  # flattens an object graph into a memo table of JSON-friendly lists
+    def __init__(self) -> None:
+        self.raw_memo = {}  # id -> (int, object); holding the object keeps its id from being reused
+        self.memo = []      # int -> wrapped form (a list) of the object memoized at that index
+
+    def wrap(self, o):  # returns basic values unchanged, anything else as a ["$", index] reference
+        if type(o) in _BASIC_TYPES:
+            return o
+        
+        entry = self.raw_memo.get(id(o), None)
+        if entry is not None:
+            return ["$", entry[0]]   # already memoized: emit a reference to the same slot
+        
+        ret = []
+        index = len(self.memo)
+        self.memo.append(ret)
+        self.raw_memo[id(o)] = (index, o)   # registered before the children of o are wrapped
+
+        if type(o) is list:
+            ret.append("list")
+            ret.append([self.wrap(i) for i in o])
+            return ["$", index]
+
+        if type(o) is tuple:
+            ret.append("tuple")
+            ret.append([self.wrap(i) for i in o])
+            return ["$", index]
+        
+        if type(o) is dict:
+            ret.append("dict")
+            ret.append([[self.wrap(k), self.wrap(v)] for k,v in o.items()])
+            return ["$", index]
+        
+        if type(o) is bytes:
+            ret.append("bytes")
+            ret.append([o[j] for j in range(len(o))])
+            return ["$", index]
+        
+        _0 = o.__class__.__name__   # generic object: stored as [class name, newargs, state]
+        if hasattr(o, "__getnewargs__"):
+            _1 = o.__getnewargs__()     # an iterable
+            _1 = [self.wrap(i) for i in _1]
+        else:
+            _1 = None
+        if hasattr(o, "__getstate__"):
+            _2 = o.__getstate__()   # stored exactly as returned; not passed through wrap()
         else:
-            _2 = {}
-            for k,v in o.__dict__.items():
-                _2[k] = _wrap(v)
-    return [_0, _1, _2]
+            if o.__dict__ is None:
+                _2 = None
+            else:
+                _2 = {}
+                for k,v in o.__dict__.items():
+                    _2[k] = self.wrap(v)
+        ret.append(_0)
+        ret.append(_1)
+        ret.append(_2)
+        return ["$", index]
 
-def _unwrap(o):
-    if type(o) in (int, float, str, bool, type(None)):
-        return o
-    if isinstance(o, list):
+class _Unpickler:
+    def __init__(self, memo: list) -> None:
+        self.memo = memo
+        self.unwrapped = [None] * len(memo)
+
+    def unwrap_ref(self, i: int):
+        if self.unwrapped[i] is None:
+            o = self.memo[i]
+            assert type(o) is list
+            assert o[0] != '$'
+            self.unwrapped[i] = self.unwrap(o)
+        return self.unwrapped[i]
+
+    def unwrap(self, o):
+        if type(o) in _BASIC_TYPES:
+            return o
+        assert type(o) is list
+        if o[0] == '$':
+            index = o[1]
+            return self.unwrap_ref(index)
         if o[0] == "list":
-            return [_unwrap(i) for i in o[1]]
+            return [self.unwrap(i) for i in o[1]]
         if o[0] == "tuple":
-            return tuple([_unwrap(i) for i in o[1]])
+            return tuple([self.unwrap(i) for i in o[1]])
         if o[0] == "dict":
-            return {_unwrap(k): _unwrap(v) for k,v in o[1]}
+            return {self.unwrap(k): self.unwrap(v) for k,v in o[1]}
         if o[0] == "bytes":
             return bytes(o[1])
         # generic object
@@ -65,7 +112,7 @@ def _unwrap(o):
         # create uninitialized instance
         new_f = _find__new__(cls)
         if newargs is not None:
-            newargs = [_unwrap(i) for i in newargs]
+            newargs = [self.unwrap(i) for i in newargs]
             inst = new_f(cls, *newargs)
         else:
             inst = new_f(cls)
@@ -75,14 +122,21 @@ def _unwrap(o):
         else:
             if state is not None:
                 for k,v in state.items():
-                    setattr(inst, k, _unwrap(v))
+                    setattr(inst, k, self.unwrap(v))
         return inst
-    raise PickleError(f"cannot unpickle {type(o).__name__} object")
 
+def _wrap(o):  # returns [wrapped root, memo table] produced by a fresh _Pickler
+    pickler = _Pickler()
+    root = pickler.wrap(o)
+    return [root, pickler.memo]
 
-def dumps(o) -> bytes:
-    return json.dumps(_wrap(o)).encode()
+def _unwrap(packed: list):  # packed is a [wrapped root, memo table] pair
+    root, table = packed
+    return _Unpickler(table).unwrap(root)
 
+def dumps(o) -> bytes:  # JSON-encodes the packed form built by _wrap and returns it as bytes
+    packed = _wrap(o)
+    return json.dumps(packed).encode()
 
 def loads(b) -> object:
     assert type(b) is bytes

+ 33 - 22
tests/81_pickle.py

@@ -1,22 +1,24 @@
 from pickle import dumps, loads, _wrap, _unwrap
 
-def test(x, y):
-    _0 = _wrap(x)
-    _1 = _unwrap(y)
-    assert _0 == y, f"{_0} != {y}"
-    assert _1 == x, f"{_1} != {x}"
-    assert x == loads(dumps(x))
-
-test(1, 1)
-test(1.0, 1.0)
-test("hello", "hello")
-test(True, True)
-test(False, False)
-test(None, None)
-
-test([1, 2, 3], ["list", [1, 2, 3]])
-test((1, 2, 3), ["tuple", [1, 2, 3]])
-test({1: 2, 3: 4}, ["dict", [[1, 2], [3, 4]]])
+def test(x):  # round-trips x through dumps/loads; on mismatch prints both debug forms, then fails
+    ok = x == loads(dumps(x))
+    if not ok:
+        _0 = _wrap(x)
+        _1 = _unwrap(_0)   # unwrap the packed form built above (was the literal 0)
+        print(_0)
+        print(_1)
+        assert False
+
+test(1)
+test(1.0)
+test("hello")
+test(True)
+test(False)
+test(None)
+
+test([1, 2, 3])
+test((1, 2, 3))
+test({1: 2, 3: 4})
 
 class Foo:
     def __init__(self, x, y):
@@ -31,15 +33,24 @@ class Foo:
     def __repr__(self) -> str:
         return f"Foo({self.x}, {self.y})"
     
-foo = Foo(1, 2)
-test(foo, ["__main__.Foo", None, {"x": 1, "y": 2}])
+test(Foo(1, 2))
+
+a = [1,2]
+test(Foo([1, 2], a))
 
 from linalg import vec2
 
-test(vec2(1, 2), ["linalg.vec2", [1, 2], None])
+test(vec2(1, 2))
 
 a = {1, 2, 3, 4}
-test(a, ['set', None, {'_a': ['dict', [[1, None], [2, None], [3, None], [4, None]]]}])
+test(a)
 
 a = bytes([1, 2, 3, 4])
-assert loads(dumps(a)) == a
+test(a)
+
+a = [1, 2]
+d = {'k': a, 'j': a}
+c = loads(dumps(d))
+
+assert c['k'] is c['j']
+assert c == d