Parcourir la source

add `array2d.__iter__`

blueloveTH il y a 1 an
Parent
commit
665b95a162
3 fichiers modifiés avec 37 ajouts et 3 suppressions
  1. 2 1
      include/typings/array2d.pyi
  2. 29 1
      src/array2d.cpp
  3. 6 1
      tests/83_array2d.py

+ 2 - 1
include/typings/array2d.pyi

@@ -1,4 +1,4 @@
-from typing import Callable, Any, Generic, TypeVar, Literal, overload
+from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 
 
@@ -27,6 +27,7 @@ class array2d(Generic[T]):
     def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
     def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
 
 
     def __len__(self) -> int: ...
     def __len__(self) -> int: ...
+    def __iter__(self) -> Iterator[T]: ...
     def __eq__(self, other: 'array2d') -> bool: ...
     def __eq__(self, other: 'array2d') -> bool: ...
     def __ne__(self, other: 'array2d') -> bool: ...
     def __ne__(self, other: 'array2d') -> bool: ...
     def __repr__(self): ...
     def __repr__(self): ...

+ 29 - 1
src/array2d.cpp

@@ -181,7 +181,7 @@ struct Array2d{
 
 
         vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
         vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
             Array2d& self = PK_OBJ_GET(Array2d, _0);
             Array2d& self = PK_OBJ_GET(Array2d, _0);
-            return (i64)self.n_rows;
+            return (i64)self.numel;
         });
         });
 
 
         vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
         vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
@@ -354,10 +354,38 @@ struct Array2d{
     }
     }
 };
 };
 
 
+
+struct Array2dIter{
+    PY_CLASS(Array2dIter, array2d, _array2d_iterator)
+    PyObject* ref;
+    int i;
+    Array2dIter(PyObject* ref) : ref(ref), i(0) {}
+
+    void _gc_mark() const{ PK_OBJ_MARK(ref); }
+
+    static void _register(VM* vm, PyObject* mod, PyObject* type){
+        vm->_all_types[PK_OBJ_GET(Type, type)].subclass_enabled = false;
+        vm->bind_notimplemented_constructor<Array2dIter>(type);
+        vm->bind__iter__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){ return obj; });
+        vm->bind__next__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
+            Array2dIter& self = _CAST(Array2dIter&, obj);
+            Array2d& a = PK_OBJ_GET(Array2d, self.ref);
+            if(self.i == a.numel) return vm->StopIteration;
+            std::div_t res = std::div(self.i, a.n_cols);
+            return VAR(Tuple(VAR(res.rem), VAR(res.quot), a.data[self.i++]));
+        });
+    }
+};
+
 void add_module_array2d(VM* vm){
 void add_module_array2d(VM* vm){
     PyObject* mod = vm->new_module("array2d");
     PyObject* mod = vm->new_module("array2d");
 
 
     Array2d::register_class(vm, mod);
     Array2d::register_class(vm, mod);
+    Array2dIter::register_class(vm, mod);
+
+    vm->bind__iter__(Array2d::_type(vm), [](VM* vm, PyObject* obj){
+        return VAR_T(Array2dIter, obj);
+    });
 }
 }
 
 
 
 

+ 6 - 1
tests/83_array2d.py

@@ -53,7 +53,7 @@ a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
 assert a_list == a.tolist()
 assert a_list == a.tolist()
 
 
 # test __len__
 # test __len__
-assert len(a) == 4
+assert len(a) == 4*2
 
 
 # test __eq__
 # test __eq__
 x = array2d(2, 4, default=0)
 x = array2d(2, 4, default=0)
@@ -172,3 +172,8 @@ a.indexed_apply_(lambda x, y, val: x+y)
 assert a[0, 0] == 0
 assert a[0, 0] == 0
 assert a[1, 2] == 3
 assert a[1, 2] == 3
 assert a[2, 0] == 2
 assert a[2, 0] == 2
+
+for i, j, x in a:
+    assert a[i, j] == x
+
+assert len(a) == a.numel