dataclasses.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. def _get_annotations(cls: type):
  2. inherits = []
  3. while cls is not object:
  4. inherits.append(cls)
  5. cls = cls.__base__
  6. inherits.reverse()
  7. res = {}
  8. for cls in inherits:
  9. res.update(cls.__annotations__)
  10. return res.keys()
  11. def _wrapped__init__(self, *args, **kwargs):
  12. cls = type(self)
  13. cls_d = cls.__dict__
  14. fields = _get_annotations(cls)
  15. i = 0 # index into args
  16. for field in fields:
  17. if field in kwargs:
  18. setattr(self, field, kwargs.pop(field))
  19. else:
  20. if i < len(args):
  21. setattr(self, field, args[i])
  22. i += 1
  23. elif field in cls_d: # has default value
  24. setattr(self, field, cls_d[field])
  25. else:
  26. raise TypeError(f"{cls.__name__} missing required argument {field!r}")
  27. if len(args) > i:
  28. raise TypeError(f"{cls.__name__} takes {len(fields)} positional arguments but {len(args)} were given")
  29. if len(kwargs) > 0:
  30. raise TypeError(f"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}")
  31. def _wrapped__repr__(self):
  32. fields = _get_annotations(type(self))
  33. obj_d = self.__dict__
  34. args: list = [f"{field}={obj_d[field]!r}" for field in fields]
  35. return f"{type(self).__name__}({', '.join(args)})"
  36. def _wrapped__eq__(self, other):
  37. if type(self) is not type(other):
  38. return False
  39. fields = _get_annotations(type(self))
  40. for field in fields:
  41. if getattr(self, field) != getattr(other, field):
  42. return False
  43. return True
  44. def _wrapped__ne__(self, other):
  45. return not self.__eq__(other)
  46. def dataclass(cls: type):
  47. assert type(cls) is type
  48. cls_d = cls.__dict__
  49. if '__init__' not in cls_d:
  50. cls.__init__ = _wrapped__init__
  51. if '__repr__' not in cls_d:
  52. cls.__repr__ = _wrapped__repr__
  53. if '__eq__' not in cls_d:
  54. cls.__eq__ = _wrapped__eq__
  55. if '__ne__' not in cls_d:
  56. cls.__ne__ = _wrapped__ne__
  57. fields = _get_annotations(cls)
  58. has_default = False
  59. for field in fields:
  60. if field in cls_d:
  61. has_default = True
  62. else:
  63. if has_default:
  64. raise TypeError(f"non-default argument {field!r} follows default argument")
  65. return cls
  66. def asdict(obj) -> dict:
  67. fields = _get_annotations(type(obj))
  68. obj_d = obj.__dict__
  69. return {field: obj_d[field] for field in fields}