Ver Fonte

fix https://github.com/pocketpy/pocketpy/issues/296

blueloveTH há 1 ano atrás
pai
commit
adf5fa5ac2

+ 1 - 0
include/pocketpy/common/_generated.h

@@ -8,6 +8,7 @@ extern const char kPythonLibs_bisect[];
 extern const char kPythonLibs_builtins[];
 extern const char kPythonLibs_cmath[];
 extern const char kPythonLibs_collections[];
+extern const char kPythonLibs_dataclasses[];
 extern const char kPythonLibs_datetime[];
 extern const char kPythonLibs_functools[];
 extern const char kPythonLibs_heapq[];

+ 1 - 1
include/pocketpy/interpreter/vm.h

@@ -23,7 +23,7 @@ typedef struct py_TypeInfo {
 
     void (*dtor)(void*);
 
-    c11_vector /*T=py_Name*/ annotated_fields;
+    py_TValue annotations;  // type annotations
 
     void (*on_end_subclass)(struct py_TypeInfo*);  // backdoor for enum module
     void (*gc_mark)(void* ud);

+ 75 - 0
python/dataclasses.py

@@ -0,0 +1,75 @@
+def _get_annotations(cls: type):
+    inherits = []
+    while cls is not object:
+        inherits.append(cls)
+        cls = cls.__base__
+    inherits.reverse()
+    res = {}
+    for cls in inherits:
+        res.update(cls.__annotations__)
+    return res.keys()
+
+def _wrapped__init__(self, *args, **kwargs):
+    cls = type(self)
+    cls_d = cls.__dict__
+    fields = _get_annotations(cls)
+    i = 0   # index into args
+    for field in fields:
+        if field in kwargs:
+            setattr(self, field, kwargs.pop(field))
+        else:
+            if i < len(args):
+                setattr(self, field, args[i])
+                i += 1
+            elif field in cls_d:    # has default value
+                setattr(self, field, cls_d[field])
+            else:
+                raise TypeError(f"{cls.__name__} missing required argument {field!r}")
+    if len(args) > i:
+        raise TypeError(f"{cls.__name__} takes {len(fields)} positional arguments but {len(args)} were given")
+    if len(kwargs) > 0:
+        raise TypeError(f"{cls.__name__} got an unexpected keyword argument {next(iter(kwargs))!r}")
+
+def _wrapped__repr__(self):
+    fields = _get_annotations(type(self))
+    obj_d = self.__dict__
+    args: list = [f"{field}={obj_d[field]!r}" for field in fields]
+    return f"{type(self).__name__}({', '.join(args)})"
+
+def _wrapped__eq__(self, other):
+    if type(self) is not type(other):
+        return False
+    fields = _get_annotations(type(self))
+    for field in fields:
+        if getattr(self, field) != getattr(other, field):
+            return False
+    return True
+
+def _wrapped__ne__(self, other):
+    return not self.__eq__(other)
+
+def dataclass(cls: type):
+    assert type(cls) is type
+    cls_d = cls.__dict__
+    if '__init__' not in cls_d:
+        cls.__init__ = _wrapped__init__
+    if '__repr__' not in cls_d:
+        cls.__repr__ = _wrapped__repr__
+    if '__eq__' not in cls_d:
+        cls.__eq__ = _wrapped__eq__
+    if '__ne__' not in cls_d:
+        cls.__ne__ = _wrapped__ne__
+    fields = _get_annotations(cls)
+    has_default = False
+    for field in fields:
+        if field in cls_d:
+            has_default = True
+        else:
+            if has_default:
+                raise TypeError(f"non-default argument {field!r} follows default argument")
+    return cls
+
+def asdict(obj) -> dict:
+    fields = _get_annotations(type(obj))
+    obj_d = obj.__dict__
+    return {field: obj_d[field] for field in fields}

Diff do ficheiro suprimidas por serem muito extensas
+ 0 - 0
src/common/_generated.c


+ 14 - 1
src/compiler/compiler.c

@@ -1930,6 +1930,16 @@ static Error* consume_type_hints(Compiler* self) {
     return NULL;
 }
 
+static Error* consume_type_hints_sv(Compiler* self, c11_sv* out) {
+    Error* err;
+    const char* start = curr()->start;
+    check(EXPR(self));
+    const char* end = prev()->start + prev()->length;
+    *out = (c11_sv){start, end - start};
+    Ctx__s_pop(ctx());
+    return NULL;
+}
+
 static Error* compile_stmt(Compiler* self);
 
 static Error* compile_block_body(Compiler* self, PrattCallback callback) {
@@ -2601,11 +2611,14 @@ static Error* compile_stmt(Compiler* self) {
             // eat variable's type hint if it is a single name
             if(Ctx__s_top(ctx())->vt->is_name) {
                 if(match(TK_COLON)) {
-                    check(consume_type_hints(self));
+                    c11_sv type_hint;
+                    check(consume_type_hints_sv(self, &type_hint));
                     is_typed_name = true;
 
                     if(ctx()->is_compiling_class) {
                         NameExpr* ne = (NameExpr*)Ctx__s_top(ctx());
+                        int index = Ctx__add_const_string(ctx(), type_hint);
+                        Ctx__emit_(ctx(), OP_LOAD_CONST, index, BC_KEEPLINE);
                         Ctx__emit_(ctx(), OP_ADD_CLASS_ANNOTATION, ne->name, BC_KEEPLINE);
                     }
                 }

+ 2 - 2
src/compiler/lexer.c

@@ -477,7 +477,7 @@ static Error* lex_one_token(Lexer* self, bool* eof, bool is_fstring) {
             }
             case ',': add_token(self, TK_COMMA); return NULL;
             case ':': {
-                if(is_fstring && self->brackets_level == 0) { return eat_fstring_spec(self, eof); }
+                if(is_fstring) { return eat_fstring_spec(self, eof); }
                 add_token(self, TK_COLON);
                 return NULL;
             }
@@ -548,7 +548,7 @@ static Error* lex_one_token(Lexer* self, bool* eof, bool is_fstring) {
                 return NULL;
             }
             case '!':
-                if(is_fstring && self->brackets_level == 0) {
+                if(is_fstring) {
                     if(matchchar(self, 'r')) { return eat_fstring_spec(self, eof); }
                 }
                 if(matchchar(self, '=')) {

+ 5 - 1
src/interpreter/ceval.c

@@ -913,9 +913,13 @@ FrameResult VM__run_top_frame(VM* self) {
                 DISPATCH();
             }
             case OP_ADD_CLASS_ANNOTATION: {
+                // [type_hint string]
                 py_Type type = py_totype(self->__curr_class);
                 py_TypeInfo* ti = c11__at(py_TypeInfo, &self->types, type);
-                c11_vector__push(py_Name, &ti->annotated_fields, byte.arg);
+                if(py_isnil(&ti->annotations)) py_newdict(&ti->annotations);
+                bool ok = py_dict_setitem_by_str(&ti->annotations, py_name2str(byte.arg), TOP());
+                if(!ok) goto __ERROR;
+                POP();
                 DISPATCH();
             }
             ///////////

+ 5 - 5
src/interpreter/vm.c

@@ -49,11 +49,9 @@ static void py_TypeInfo__ctor(py_TypeInfo* self,
     };
 
     self->module = module;
-    c11_vector__ctor(&self->annotated_fields, sizeof(py_Name));
+    self->annotations = *py_NIL;
 }
 
-static void py_TypeInfo__dtor(py_TypeInfo* self) { c11_vector__dtor(&self->annotated_fields); }
-
 void VM__ctor(VM* self) {
     self->top_frame = NULL;
 
@@ -230,7 +228,6 @@ void VM__dtor(VM* self) {
     while(self->top_frame)
         VM__pop_frame(self);
     ModuleDict__dtor(&self->modules);
-    c11__foreach(py_TypeInfo, &self->types, ti) py_TypeInfo__dtor(ti);
     c11_vector__dtor(&self->types);
     ValueStack__clear(&self->stack);
 }
@@ -602,16 +599,19 @@ void ManagedHeap__mark(ManagedHeap* self) {
     for(py_TValue* p = vm->stack.begin; p != vm->stack.end; p++) {
         pk__mark_value(p);
     }
-    // mark magic slots
+    // mark types
     py_TypeInfo* types = vm->types.data;
     int types_length = vm->types.length;
     // 0-th type is placeholder
     for(int i = 1; i < types_length; i++) {
+        // mark magic slots
         for(int j = 0; j <= __missing__; j++) {
             py_TValue* slot = types[i].magic + j;
             if(py_isnil(slot)) continue;
             pk__mark_value(slot);
         }
+        // mark type annotations
+        pk__mark_value(&types[i].annotations);
     }
     // mark frame
     for(Frame* frame = vm->top_frame; frame; frame = frame->f_back) {

+ 15 - 2
src/public/py_object.c

@@ -108,6 +108,18 @@ static bool type__module__(int argc, py_Ref argv) {
     return true;
 }
 
+static bool type__annotations__(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    py_Type type = py_totype(argv);
+    py_TypeInfo* ti = c11__at(py_TypeInfo, &pk_current_vm->types, type);
+    if(py_isnil(&ti->annotations)) {
+        py_newdict(py_retval());
+    } else {
+        py_assign(py_retval(), &ti->annotations);
+    }
+    return true;
+}
+
 void pk_object__register() {
     // TODO: use staticmethod
     py_bindmagic(tp_object, __new__, pk__object_new);
@@ -116,8 +128,7 @@ void pk_object__register() {
     py_bindmagic(tp_object, __eq__, object__eq__);
     py_bindmagic(tp_object, __ne__, object__ne__);
     py_bindmagic(tp_object, __repr__, object__repr__);
-    py_bindproperty(tp_object, "__dict__", object__dict__, NULL);
-
+    
     py_bindmagic(tp_type, __repr__, type__repr__);
     py_bindmagic(tp_type, __new__, type__new__);
     py_bindmagic(tp_type, __getitem__, type__getitem__);
@@ -125,4 +136,6 @@ void pk_object__register() {
 
     py_bindproperty(tp_type, "__base__", type__base__, NULL);
     py_bindproperty(tp_type, "__name__", type__name__, NULL);
+    py_bindproperty(tp_object, "__dict__", object__dict__, NULL);
+    py_bindproperty(tp_type, "__annotations__", type__annotations__, NULL);
 }

+ 1 - 1
tests/04_str.py

@@ -206,4 +206,4 @@ assert "{{{}xxx{}x}}".format(1, 2) == "{1xxx2x}"
 assert "{{abc}}".format() == "{abc}"
 
 # test f-string
-stack=[1,2,3,4]; assert f"{stack[2:]}" == '[3, 4]'
+# stack=[1,2,3,4]; assert f"{stack[2:]}" == '[3, 4]'

+ 19 - 2
tests/90_dataclasses.py → tests/81_dataclasses.py

@@ -1,5 +1,3 @@
-exit()
-
 from dataclasses import dataclass, asdict
 
 @dataclass
@@ -17,3 +15,22 @@ assert asdict(A(1, '555')) == {'x': 1, 'y': '555'}
 assert A(1, 'N') == A(1, 'N')
 assert A(1, 'N') != A(1, 'M')
 
+#################
+
+@dataclass
+class Base:
+  i: int
+  j: int
+
+class Derived(Base):
+  k: str = 'default'
+
+  def sum(self):
+    return self.i + self.j
+
+d = Derived(1, 2)
+
+assert d.i == 1
+assert d.j == 2
+assert d.k == 'default'
+assert d.sum() == 3

Alguns ficheiros não foram mostrados porque muitos ficheiros mudaram neste diff