dataclasses.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #include "pocketpy/dataclasses.h"
  2. namespace pkpy{
  3. static void patch__init__(VM* vm, Type cls){
  4. vm->bind(vm->_t(cls), "__init__(self, *args, **kwargs)", [](VM* vm, ArgsView _view){
  5. PyVar self = _view[0];
  6. const Tuple& args = CAST(Tuple&, _view[1]);
  7. const Dict& kwargs_ = CAST(Dict&, _view[2]);
  8. NameDict kwargs;
  9. kwargs_.apply([&](PyVar k, PyVar v){
  10. kwargs.set(CAST(Str&, k), v);
  11. });
  12. Type cls = vm->_tp(self);
  13. const PyTypeInfo* cls_info = &vm->_all_types[cls];
  14. NameDict& cls_d = cls_info->obj->attr();
  15. const auto& fields = cls_info->annotated_fields;
  16. int i = 0; // index into args
  17. for(StrName field: fields){
  18. if(kwargs.contains(field)){
  19. self->attr().set(field, kwargs[field]);
  20. kwargs.del(field);
  21. }else{
  22. if(i < args.size()){
  23. self->attr().set(field, args[i]);
  24. ++i;
  25. }else if(cls_d.contains(field)){ // has default value
  26. self->attr().set(field, cls_d[field]);
  27. }else{
  28. vm->TypeError(_S(cls_info->name, " missing required argument ", field.escape()));
  29. }
  30. }
  31. }
  32. if(args.size() > i){
  33. vm->TypeError(_S(cls_info->name, " takes ", fields.size(), " positional arguments but ", args.size(), " were given"));
  34. }
  35. if(kwargs.size() > 0){
  36. StrName unexpected_key = kwargs.items()[0].first;
  37. vm->TypeError(_S(cls_info->name, " got an unexpected keyword argument ", unexpected_key.escape()));
  38. }
  39. return vm->None;
  40. });
  41. }
  42. static void patch__repr__(VM* vm, Type cls){
  43. vm->bind__repr__(cls, [](VM* vm, PyVar _0) -> Str{
  44. const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)];
  45. const auto& fields = cls_info->annotated_fields;
  46. const NameDict& obj_d = _0->attr();
  47. SStream ss;
  48. ss << cls_info->name << "(";
  49. bool first = true;
  50. for(StrName field: fields){
  51. if(first) first = false;
  52. else ss << ", ";
  53. ss << field << "=" << vm->py_repr(obj_d[field]);
  54. }
  55. ss << ")";
  56. return ss.str();
  57. });
  58. }
  59. static void patch__eq__(VM* vm, Type cls){
  60. vm->bind__eq__(cls, [](VM* vm, PyVar _0, PyVar _1){
  61. if(vm->_tp(_0) != vm->_tp(_1)) return vm->NotImplemented;
  62. const PyTypeInfo* cls_info = &vm->_all_types[vm->_tp(_0)];
  63. const auto& fields = cls_info->annotated_fields;
  64. for(StrName field: fields){
  65. PyVar lhs = _0->attr(field);
  66. PyVar rhs = _1->attr(field);
  67. if(vm->py_ne(lhs, rhs)) return vm->False;
  68. }
  69. return vm->True;
  70. });
  71. }
  72. void add_module_dataclasses(VM* vm){
  73. PyObject* mod = vm->new_module("dataclasses");
  74. vm->bind_func(mod, "dataclass", 1, [](VM* vm, ArgsView args){
  75. vm->check_type(args[0], VM::tp_type);
  76. Type cls = PK_OBJ_GET(Type, args[0]);
  77. NameDict& cls_d = args[0]->attr();
  78. if(!cls_d.contains(__init__)) patch__init__(vm, cls);
  79. if(!cls_d.contains(__repr__)) patch__repr__(vm, cls);
  80. if(!cls_d.contains(__eq__)) patch__eq__(vm, cls);
  81. const auto& fields = vm->_all_types[cls].annotated_fields;
  82. bool has_default = false;
  83. for(StrName field: fields){
  84. if(cls_d.contains(field)){
  85. has_default = true;
  86. }else{
  87. if(has_default){
  88. vm->TypeError(_S("non-default argument ", field.escape(), " follows default argument"));
  89. }
  90. }
  91. }
  92. return args[0];
  93. });
  94. vm->bind_func(mod, "asdict", 1, [](VM* vm, ArgsView args){
  95. const auto& fields = vm->_tp_info(args[0])->annotated_fields;
  96. const NameDict& obj_d = args[0]->attr();
  97. Dict d;
  98. for(StrName field: fields){
  99. d.set(vm, VAR(field.sv()), obj_d[field]);
  100. }
  101. return VAR(std::move(d));
  102. });
  103. }
  104. } // namespace pkpy