pickle.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import json
  2. import builtins
  # Value types JSON can carry directly; the pickler emits them unchanged.
_BASIC_TYPES = [
    int,
    float,
    str,
    bool,
    type(None),
]
  # Name resolution used when rebuilding classes and ``["type", ...]`` entries.
def _find_class(path: str):
    """Resolve ``path`` to the object it names.

    A dotted ``module.name`` path imports ``module`` with ``__import__`` and
    returns ``name`` from the module's ``__dict__``. Exactly one dot is
    supported; more dots raise ValueError when the path is unpacked.

    A bare name is looked up first in this module's ``globals()`` and then
    in ``builtins``. KeyError is raised if neither namespace has it.
    """
    if "." in path:
        module_name, attr = path.split(".")
        return __import__(module_name).__dict__[attr]
    namespace = globals()
    if path in namespace:
        return namespace[path]
    return builtins.__dict__[path]
  # Look up the allocator used to create instances without calling __init__.
def _find__new__(cls):
    """Return the ``__new__`` entry of ``cls`` or of its nearest ancestor.

    The search follows the ``__base__`` chain, so only the first base is
    considered at each step. It reads each class's own ``__dict__``, so the
    raw namespace entry is returned. The function asserts if no class on the
    chain defines ``__new__``.
    """
    klass = cls
    while klass is not None:
        own_attrs = klass.__dict__
        if "__new__" not in own_attrs:
            klass = klass.__base__
            continue
        return own_attrs["__new__"]
    assert False
  # _Pickler: turns an object graph into JSON-compatible nested lists.
class _Pickler:
    """Flatten an object graph into JSON-compatible nested lists.

    Values whose type is in ``_BASIC_TYPES`` are emitted as-is, and classes
    become ``["type", name]``. Every other object gets one slot in
    ``memo`` and is referred to by a one-element list ``[index]``. This is
    what lets shared and cyclic references be restored.
    ``run_pipe`` returns ``[root, memo]``.
    """

    def __init__(self, obj) -> None:
        self.obj = obj
        self.raw_memo = {}  # id -> int
        self.memo = []  # int -> object
        # Strong references to every memoized object. ``raw_memo`` is keyed
        # by id(), and an id can be handed out again once its object has
        # been freed. A temporary element of the iterable returned by
        # ``__getnewargs__`` is one example. Without these references, a
        # later unrelated object would be mistaken for a repeat and encoded
        # as a wrong back-reference. Keeping the objects alive for the
        # pickler's lifetime keeps their ids unique.
        self._keep_alive = []

    def wrap(self, o):
        """Return the JSON-compatible encoding of ``o``.

        Memo entries take one of these forms:
        ``["tuple", items]``, ``["bytes", byte_values]``,
        ``["list", items]``, ``["dict", [[key, value], ...]]``, or, for any
        other object, ``[class_name, newargs_or_None, state_or_None]``.
        Nested values inside an entry are themselves encoded by ``wrap``.
        """
        if type(o) in _BASIC_TYPES:
            return o
        if type(o) is type:
            return ["type", o.__name__]
        index = self.raw_memo.get(id(o), None)
        if index is not None:
            return [index]
        # Reserve the memo slot before recursing so that a cycle back to
        # ``o`` resolves to ``[index]`` instead of recursing forever.
        ret = []
        index = len(self.memo)
        self.memo.append(ret)
        self.raw_memo[id(o)] = index
        self._keep_alive.append(o)
        if type(o) is tuple:
            ret.append("tuple")
            ret.append([self.wrap(i) for i in o])
            return [index]
        if type(o) is bytes:
            ret.append("bytes")
            # one entry per byte, read by indexing
            ret.append([o[j] for j in range(len(o))])
            return [index]
        if type(o) is list:
            ret.append("list")
            ret.append([self.wrap(i) for i in o])
            return [index]
        if type(o) is dict:
            ret.append("dict")
            # A list of [key, value] pairs rather than a JSON object, so keys
            # go through wrap() like any other value instead of being forced
            # to strings.
            ret.append([[self.wrap(k), self.wrap(v)] for k, v in o.items()])
            return [index]
        # Generic object: class name, constructor arguments, instance state.
        # NOTE(review): only ``__name__`` is recorded (no module), so classes
        # outside the unpickler's name lookup may not round-trip. Confirm.
        cls_name = o.__class__.__name__
        if hasattr(o, "__getnewargs__"):
            # any iterable; each element is encoded individually
            newargs = [self.wrap(i) for i in o.__getnewargs__()]
        else:
            newargs = None
        if o.__dict__ is None:
            state = None
        else:
            state = {}
            for k, v in o.__dict__.items():
                state[k] = self.wrap(v)
        ret.append(cls_name)
        ret.append(newargs)
        ret.append(state)
        return [index]

    def run_pipe(self):
        """Encode ``self.obj`` and return ``[root, memo]``."""
        o = self.wrap(self.obj)
        return [o, self.memo]
  # _Unpickler: rebuilds the object graph encoded by _Pickler.
class _Unpickler:
    """Rebuild an object graph from a ``root`` value and its ``memo`` list.

    Each memo slot is decoded at most once. Every reference to the slot
    yields the same rebuilt object, which preserves sharing and cycles.
    """

    def __init__(self, obj, memo: list) -> None:
        self.obj = obj
        self.memo = memo
        # slot index -> rebuilt object; None means "not rebuilt yet"
        self._unwrapped = [None] * len(memo)

    def tag(self, index, o):
        """Record ``o`` as the rebuilt object for memo slot ``index`` (once)."""
        assert self._unwrapped[index] is None
        self._unwrapped[index] = o

    def unwrap(self, o, index=None):
        """Decode one encoded value.

        ``o`` is a basic value, ``["type", name]``, or a reference
        ``[index]`` into ``self.memo``. When ``index`` is given, ``o`` is
        the memo entry for that slot itself.
        """
        if type(o) in _BASIC_TYPES:
            return o
        assert type(o) is list
        if o[0] == "type":
            return _find_class(o[1])
        # reference: decode the slot on first use, then reuse the result
        if type(o[0]) is int:
            assert index is None  # a memo entry is never a bare reference
            index = o[0]
            if self._unwrapped[index] is None:
                entry = self.memo[index]
                assert type(entry) is list
                assert type(entry[0]) is str
                self.unwrap(entry, index)
            assert self._unwrapped[index] is not None
            return self._unwrapped[index]
        # concrete reference types
        if o[0] == "tuple":
            items = [self.unwrap(i) for i in o[1]]
            # A cycle through a mutable element (e.g. ``t = ([],);
            # t[0].append(t)``) re-enters this slot while ``items`` is being
            # decoded, and that inner call already built and tagged the
            # tuple. Reuse it so every reference points at the same object.
            if self._unwrapped[index] is not None:
                return self._unwrapped[index]
            ret = tuple(items)
            self.tag(index, ret)
            return ret
        if o[0] == "bytes":
            ret = bytes(o[1])
            self.tag(index, ret)
            return ret
        if o[0] == "list":
            # tag before filling so elements can refer back to the list
            ret = []
            self.tag(index, ret)
            for i in o[1]:
                ret.append(self.unwrap(i))
            return ret
        if o[0] == "dict":
            ret = {}
            self.tag(index, ret)
            for k, v in o[1]:
                ret[self.unwrap(k)] = self.unwrap(v)
            return ret
        # generic object: [class_name, newargs_or_None, state_or_None]
        name, newargs, state = o
        cls = _find_class(name)
        # create an uninitialized instance via __new__ (no __init__ call)
        new_f = _find__new__(cls)
        if newargs is not None:
            args = [self.unwrap(i) for i in newargs]
            # Same cycle guard as for tuples: the arguments may refer back
            # to this object, in which case it was already rebuilt.
            if self._unwrapped[index] is not None:
                return self._unwrapped[index]
            inst = new_f(cls, *args)
        else:
            inst = new_f(cls)
        self.tag(index, inst)
        # restore state after tagging so attributes can refer back to inst
        if state is not None:
            for k, v in state.items():
                setattr(inst, k, self.unwrap(v))
        return inst

    def run_pipe(self):
        """Decode and return the root object."""
        return self.unwrap(self.obj)
  # Internal: object graph -> [root, memo].
def _wrap(o):
    """Return the ``[root, memo]`` pair a fresh ``_Pickler`` produces for ``o``."""
    pickler = _Pickler(o)
    return pickler.run_pipe()
  # Internal: [root, memo] -> object graph.
def _unwrap(packed: list):
    """Rebuild the object graph described by ``packed``.

    ``packed`` is spread positionally into ``_Unpickler(obj, memo)``.
    """
    unpickler = _Unpickler(*packed)
    return unpickler.run_pipe()
  # Public entry point: serialize an object graph to bytes.
def dumps(o) -> bytes:
    """Serialize ``o`` and return it as encoded JSON bytes.

    The JSON document is the ``[root, memo]`` pair built by ``_wrap``. The
    text is encoded with ``str.encode()``'s default codec.
    """
    payload = _wrap(o)
    text = json.dumps(payload)
    return text.encode()
  # Public entry point: rebuild an object graph from bytes made by dumps().
def loads(b) -> object:
    """Decode ``b`` (must be exactly ``bytes``) back into an object graph.

    NOTE(review): the payload names the classes and modules to resolve and
    import while rebuilding, so do not feed untrusted bytes here.
    """
    assert type(b) is bytes
    payload = json.loads(b.decode())
    return _unwrap(payload)