# dataclasses.py
  1. def _wrapped__init__(self, *args, **kwargs):
  2. cls = type(self)
  3. cls_d = cls.__dict__
  4. fields: tuple[str] = cls.__annotations__
  5. i = 0 # index into args
  6. for field in fields:
  7. if field in kwargs:
  8. setattr(self, field, kwargs.pop(field))
  9. else:
  10. if i < len(args):
  11. setattr(self, field, args[i])
  12. ++i
  13. elif field in cls_d: # has default value
  14. setattr(self, field, cls_d[field])
  15. else:
  16. raise TypeError(f"{cls.__name__} missing required argument {field!r}")
  17. if len(args) > i:
  18. raise TypeError(f"{cls.__name__} takes {len(field)} positional arguments but {len(args)} were given")
  19. if len(kwargs) > 0:
  20. raise TypeError(f"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}")
  21. def _wrapped__repr__(self):
  22. fields: tuple[str] = type(self).__annotations__
  23. obj_d = self.__dict__
  24. args: list = [f"{field}={obj_d[field]!r}" for field in fields]
  25. return f"{type(self).__name__}({', '.join(args)})"
  26. def _wrapped__eq__(self, other):
  27. if type(self) is not type(other):
  28. return False
  29. fields: tuple[str] = type(self).__annotations__
  30. for field in fields:
  31. if getattr(self, field) != getattr(other, field):
  32. return False
  33. return True
  34. def dataclass(cls: type):
  35. assert type(cls) is type
  36. cls_d = cls.__dict__
  37. if '__init__' not in cls_d:
  38. cls.__init__ = _wrapped__init__
  39. if '__repr__' not in cls_d:
  40. cls.__repr__ = _wrapped__repr__
  41. if '__eq__' not in cls_d:
  42. cls.__eq__ = _wrapped__eq__
  43. fields: tuple[str] = cls.__annotations__
  44. has_default = False
  45. for field in fields:
  46. if field in cls_d:
  47. has_default = True
  48. else:
  49. if has_default:
  50. raise TypeError(f"non-default argument {field!r} follows default argument")
  51. return cls
  52. def asdict(obj) -> dict:
  53. fields: tuple[str] = type(obj).__annotations__
  54. obj_d = obj.__dict__
  55. return {field: obj_d[field] for field in fields}