dataclasses.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #include "pocketpy/dataclasses.h"
  2. namespace pkpy{
  3. static void patch__init__(VM* vm, Type cls){
  4. vm->bind(vm->_t(cls), "__init__(self, *args, **kwargs)", [](VM* vm, ArgsView _view){
  5. PyObject* self = _view[0];
  6. const Tuple& args = CAST(Tuple&, _view[1]);
  7. const Dict& kwargs_ = CAST(Dict&, _view[2]);
  8. NameDict kwargs;
  9. kwargs_.apply([&](PyObject* k, PyObject* v){
  10. kwargs.set(CAST(Str&, k), v);
  11. });
  12. Type cls = vm->_tp(self);
  13. const PyTypeInfo* cls_info = &vm->_all_types[cls];
  14. NameDict& cls_d = cls_info->obj->attr();
  15. const auto& fields = cls_info->annotated_fields;
  16. int i = 0; // index into args
  17. for(StrName field: fields){
  18. if(kwargs.contains(field)){
  19. self->attr().set(field, kwargs[field]);
  20. kwargs.del(field);
  21. }else{
  22. if(i < args.size()){
  23. self->attr().set(field, args[i]);
  24. ++i;
  25. }else if(cls_d.contains(field)){ // has default value
  26. self->attr().set(field, cls_d[field]);
  27. }else{
  28. vm->TypeError(_S(cls_info->name, " missing required argument ", field.escape()));
  29. }
  30. }
  31. }
  32. if(args.size() > i){
  33. vm->TypeError(_S(cls_info->name, " takes ", fields.size(), " positional arguments but ", args.size(), " were given"));
  34. }
  35. if(kwargs.size() > 0){
  36. StrName unexpected_key = kwargs.items()[0].first;
  37. vm->TypeError(_S(cls_info->name, " got an unexpected keyword argument ", unexpected_key.escape()));
  38. }
  39. return vm->None;
  40. });
  41. }
  42. static void patch__repr__(VM* vm, Type cls){
  43. vm->bind__repr__(cls, [](VM* vm, PyObject* _0){
  44. auto _lock = vm->heap.gc_scope_lock();
  45. const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)];
  46. const auto& fields = cls_info->annotated_fields;
  47. const NameDict& obj_d = _0->attr();
  48. SStream ss;
  49. ss << cls_info->name << "(";
  50. bool first = true;
  51. for(StrName field: fields){
  52. if(first) first = false;
  53. else ss << ", ";
  54. ss << field << "=" << CAST(Str&, vm->py_repr(obj_d[field]));
  55. }
  56. ss << ")";
  57. return VAR(ss.str());
  58. });
  59. }
  60. static void patch__eq__(VM* vm, Type cls){
  61. vm->bind__eq__(cls, [](VM* vm, PyObject* _0, PyObject* _1){
  62. if(vm->_tp(_0) != vm->_tp(_1)) return vm->NotImplemented;
  63. const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)];
  64. const auto& fields = cls_info->annotated_fields;
  65. for(StrName field: fields){
  66. PyObject* lhs = _0->attr(field);
  67. PyObject* rhs = _1->attr(field);
  68. if(vm->py_ne(lhs, rhs)) return vm->False;
  69. }
  70. return vm->True;
  71. });
  72. }
  73. void add_module_dataclasses(VM* vm){
  74. PyObject* mod = vm->new_module("dataclasses");
  75. vm->bind_func<1>(mod, "dataclass", [](VM* vm, ArgsView args){
  76. vm->check_non_tagged_type(args[0], VM::tp_type);
  77. Type cls = PK_OBJ_GET(Type, args[0]);
  78. NameDict& cls_d = args[0]->attr();
  79. if(!cls_d.contains("__init__")) patch__init__(vm, cls);
  80. if(!cls_d.contains("__repr__")) patch__repr__(vm, cls);
  81. if(!cls_d.contains("__eq__")) patch__eq__(vm, cls);
  82. const auto& fields = vm->_all_types[cls].annotated_fields;
  83. bool has_default = false;
  84. for(StrName field: fields){
  85. if(cls_d.contains(field)){
  86. has_default = true;
  87. }else{
  88. if(has_default){
  89. vm->TypeError(_S("non-default argument ", field.escape(), " follows default argument"));
  90. }
  91. }
  92. }
  93. return args[0];
  94. });
  95. vm->bind_func<1>(mod, "asdict", [](VM* vm, ArgsView args){
  96. const auto& fields = vm->_inst_type_info(args[0])->annotated_fields;
  97. const NameDict& obj_d = args[0]->attr();
  98. Dict d(vm);
  99. for(StrName field: fields){
  100. d.set(VAR(field.sv()), obj_d[field]);
  101. }
  102. return VAR(std::move(d));
  103. });
  104. }
  105. } // namespace pkpy