فهرست منبع

improve `array2d`

blueloveTH 1 سال پیش
والد
کامیت
9165ca228a
4فایلهای تغییر یافته به همراه199 افزوده شده و 72 حذف شده
  1. 2 2
      include/pocketpy/config.h
  2. 18 6
      include/typings/array2d.pyi
  3. 145 44
      src/modules/array2d.c
  4. 34 20
      tests/90_array2d.py

+ 2 - 2
include/pocketpy/config.h

@@ -1,10 +1,10 @@
 #pragma once
 // clang-format off
 
-#define PK_VERSION				"2.0.2"
+#define PK_VERSION				"2.0.3"
 #define PK_VERSION_MAJOR            2
 #define PK_VERSION_MINOR            0
-#define PK_VERSION_PATCH            2
+#define PK_VERSION_PATCH            3
 
 /*************** feature settings ***************/
 

+ 18 - 6
include/typings/array2d.pyi

@@ -19,6 +19,8 @@ class array2d(Generic[T]):
 
     def __new__(cls, n_cols: int, n_rows: int, default=None): ...
     def __len__(self) -> int: ...
+    def __eq__(self, other: object) -> array2d[bool]: ... # type: ignore
+    def __ne__(self, other: object) -> array2d[bool]: ... # type: ignore
     def __repr__(self) -> str: ...
     def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
 
@@ -39,29 +41,39 @@ class array2d(Generic[T]):
     @overload
     def __getitem__(self, index: vec2i) -> T: ...
     @overload
-    def __getitem__(self, index: tuple[slice, slice]) -> 'array2d[T]': ...
+    def __getitem__(self, index: tuple[slice, slice]) -> array2d[T]: ...
+    @overload
+    def __getitem__(self, mask: array2d[bool]) -> list[T]: ...
     @overload
     def __setitem__(self, index: tuple[int, int], value: T): ...
     @overload
     def __setitem__(self, index: vec2i, value: T): ...
     @overload
     def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
+    @overload
+    def __setitem__(self, mask: array2d[bool], value: T): ...
 
-    def map(self, f: Callable[[T], Any]) -> 'array2d': ...
+    def map[R](self, f: Callable[[T], R]) -> array2d[R]: ...
     def copy(self) -> 'array2d[T]': ...
 
     def fill_(self, value: T) -> None: ...
     def apply_(self, f: Callable[[T], T]) -> None: ...
-    def copy_(self, other: 'array2d[T] | list[T]') -> None: ...
+    def copy_(self, other: array2d[T] | list[T]) -> None: ...
 
-    def tolist(self) -> list[list[T]]: ...
     def render(self) -> str: ...
 
+    def all(self: array2d[bool]) -> bool: ...
+    def any(self: array2d[bool]) -> bool: ...
+    
+    @staticmethod
+    def fromlist(data: list[list[T]]) -> array2d[T]: ...
+    def tolist(self) -> list[list[T]]: ...
+
     # algorithms
     def count(self, value: T) -> int:
         """Counts the number of cells with the given value."""
 
-    def count_neighbors(self, value: T, neighborhood: Neighborhood) -> 'array2d[int]':
+    def count_neighbors(self, value: T, neighborhood: Neighborhood) -> array2d[int]:
         """Counts the number of neighbors with the given value for each cell."""
 
     def find_bounding_rect(self, value: T) -> tuple[int, int, int, int]:
@@ -70,5 +82,5 @@ class array2d(Generic[T]):
         Returns a tuple `(x, y, width, height)` or raise `ValueError` if the value is not found.
         """
 
-    def convolve(self: array2d[int], kernel: 'array2d[int]', padding: int) -> 'array2d[int]':
+    def convolve(self: array2d[int], kernel: array2d[int], padding: int) -> array2d[int]:
         """Convolves the array with the given kernel."""

+ 145 - 44
src/modules/array2d.c

@@ -160,22 +160,35 @@ static bool array2d__len__(int argc, py_Ref argv) {
     return true;
 }
 
-static bool array2d__eq__(int argc, py_Ref argv) {
-    PY_CHECK_ARGC(2);
-    c11_array2d* self = py_touserdata(argv);
-    if(!py_istype(py_arg(1), tp_array2d)) {
-        py_newnotimplemented(py_retval());
-        return true;
+static bool _array2d_check_all_type(c11_array2d* self, py_Type type) {
+    for(int i = 0; i < self->numel; i++) {
+        py_Type item_type = self->data[i].type;
+        if(item_type != type) {
+            const char* fmt = "expected array2d[%t], got %t";
+            return TypeError(fmt, type, item_type);
+        }
     }
-    c11_array2d* other = py_touserdata(py_arg(1));
-    if(self->n_cols != other->n_cols || self->n_rows != other->n_rows) {
-        py_newbool(py_retval(), false);
-        return true;
+    return true;
+}
+
+static bool _check_same_shape(int colA, int rowA, int colB, int rowB) {
+    if(colA != colB || rowA != rowB) {
+        const char* fmt = "expected the same shape: (%d, %d) != (%d, %d)";
+        return ValueError(fmt, colA, rowA, colB, rowB);
     }
+    return true;
+}
+
+static bool _array2d_check_same_shape(c11_array2d* self, c11_array2d* other) {
+    return _check_same_shape(self->n_cols, self->n_rows, other->n_cols, other->n_rows);
+}
+
+static bool array2d_all(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    c11_array2d* self = py_touserdata(argv);
+    if(!_array2d_check_all_type(self, tp_bool)) return false;
     for(int i = 0; i < self->numel; i++) {
-        int res = py_equal(self->data + i, other->data + i);
-        if(res == -1) return false;
-        if(res == 0) {
+        if(!py_tobool(self->data + i)) {
             py_newbool(py_retval(), false);
             return true;
         }
@@ -184,10 +197,53 @@ static bool array2d__eq__(int argc, py_Ref argv) {
     return true;
 }
 
+static bool array2d_any(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    c11_array2d* self = py_touserdata(argv);
+    if(!_array2d_check_all_type(self, tp_bool)) return false;
+    for(int i = 0; i < self->numel; i++) {
+        if(py_tobool(self->data + i)) {
+            py_newbool(py_retval(), true);
+            return true;
+        }
+    }
+    py_newbool(py_retval(), false);
+    return true;
+}
+
+static bool array2d__eq__(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(2);
+    c11_array2d* self = py_touserdata(argv);
+    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    if(py_istype(py_arg(1), tp_array2d)) {
+        c11_array2d* other = py_touserdata(py_arg(1));
+        if(!_array2d_check_same_shape(self, other)) return false;
+        for(int i = 0; i < self->numel; i++) {
+            int code = py_equal(self->data + i, other->data + i);
+            if(code == -1) return false;
+            py_newbool(res->data + i, (bool)code);
+        }
+    } else {
+        // broadcast
+        for(int i = 0; i < self->numel; i++) {
+            int code = py_equal(self->data + i, py_arg(1));
+            if(code == -1) return false;
+            py_newbool(res->data + i, (bool)code);
+        }
+    }
+    py_assign(py_retval(), py_peek(-1));
+    py_pop();
+    return true;
+}
+
 static bool array2d__ne__(int argc, py_Ref argv) {
     bool ok = array2d__eq__(argc, argv);
     if(!ok) return false;
-    if(py_isbool(py_retval())) { py_newbool(py_retval(), !py_tobool(py_retval())); }
+    c11_array2d* res = py_touserdata(py_retval());
+    py_TValue* data = res->data;
+    for(int i = 0; i < res->numel; i++) {
+        py_newbool(&data[i], !py_tobool(&data[i]));
+    }
     return true;
 }
 
@@ -211,8 +267,8 @@ static bool array2d__iter__(int argc, py_Ref argv) {
     return true;
 }
 
+// __iter__(self) -> Iterator[tuple[int, int, T]]
 static bool array2d_iterator__next__(int argc, py_Ref argv) {
-    // def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
     PY_CHECK_ARGC(1);
     c11_array2d_iterator* self = py_touserdata(argv);
     if(self->index < self->array->numel) {
@@ -285,13 +341,7 @@ static bool array2d_copy_(int argc, py_Ref argv) {
     py_Type src_type = py_typeof(py_arg(1));
     if(src_type == tp_array2d) {
         c11_array2d* src = py_touserdata(py_arg(1));
-        if(self->n_cols != src->n_cols || self->n_rows != src->n_rows) {
-            return ValueError("copy_() expected the same shape: (%d, %d) != (%d, %d)",
-                              self->n_cols,
-                              self->n_rows,
-                              src->n_cols,
-                              src->n_rows);
-        }
+        if(!_array2d_check_same_shape(self, src)) return false;
         memcpy(self->data, src->data, self->numel * sizeof(py_TValue));
     } else {
         py_TValue* data;
@@ -309,8 +359,36 @@ static bool array2d_copy_(int argc, py_Ref argv) {
     return true;
 }
 
+// fromlist(data: list[list[T]]) -> array2d[T]
+static bool array2d_fromlist_STATIC(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    if(!py_checktype(argv, tp_list)) return false;
+    int n_rows = py_list_len(argv);
+    if(n_rows == 0) return ValueError("fromlist() expected a non-empty list");
+    int n_cols = -1;
+    for(int j = 0; j < n_rows; j++) {
+        py_Ref row_j = py_list_getitem(argv, j);
+        if(!py_checktype(row_j, tp_list)) return false;
+        int n_cols_j = py_list_len(row_j);
+        if(n_cols == -1) {
+            if(n_cols_j == 0) return ValueError("fromlist() expected a non-empty list");
+            n_cols = n_cols_j;
+        } else if(n_cols != n_cols_j) {
+            return ValueError("fromlist() expected a list of lists with the same length");
+        }
+    }
+    c11_array2d* res = py_array2d(py_retval(), n_cols, n_rows);
+    for(int j = 0; j < n_rows; j++) {
+        py_Ref row_j = py_list_getitem(argv, j);
+        for(int i = 0; i < n_cols; i++) {
+            py_array2d__set(res, i, j, py_list_getitem(row_j, i));
+        }
+    }
+    return true;
+}
+
+// tolist(self) -> list[list[T]]
 static bool array2d_tolist(int argc, py_Ref argv) {
-    // def tolist(self) -> list[list[T]]: ...
     PY_CHECK_ARGC(1);
     c11_array2d* self = py_touserdata(argv);
     py_newlistn(py_retval(), self->n_rows);
@@ -341,8 +419,8 @@ static bool array2d_render(int argc, py_Ref argv) {
     return true;
 }
 
+// count(self, value: T) -> int
 static bool array2d_count(int argc, py_Ref argv) {
-    // def count(self, value: T) -> int: ...
     PY_CHECK_ARGC(2);
     c11_array2d* self = py_touserdata(argv);
     int count = 0;
@@ -355,7 +433,9 @@ static bool array2d_count(int argc, py_Ref argv) {
     return true;
 }
 
+// find_bounding_rect(self, value: T) -> tuple[int, int, int, int]
 static bool array2d_find_bounding_rect(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(2);
     c11_array2d* self = py_touserdata(argv);
     py_Ref value = py_arg(1);
     int left = self->n_cols;
@@ -389,8 +469,8 @@ static bool array2d_find_bounding_rect(int argc, py_Ref argv) {
     return true;
 }
 
+// count_neighbors(self, value: T, neighborhood: Neighborhood) -> array2d[int]
 static bool array2d_count_neighbors(int argc, py_Ref argv) {
-    // def count_neighbors(self, value: T, neighborhood: Neighborhood) -> 'array2d[int]': ...
     PY_CHECK_ARGC(3);
     c11_array2d* self = py_touserdata(argv);
     c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
@@ -471,6 +551,18 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
         }
         return _array2d_IndexError(self, pos.x, pos.y);
     }
+
+    if(argv[1].type == tp_array2d) {
+        c11_array2d* mask = py_touserdata(&argv[1]);
+        if(!_array2d_check_same_shape(self, mask)) return false;
+        if(!_array2d_check_all_type(mask, tp_bool)) return false;
+        py_newlist(py_retval());
+        for(int i = 0; i < self->numel; i++) {
+            if(py_tobool(mask->data + i)) py_list_append(py_retval(), self->data + i);
+        }
+        return true;
+    }
+
     PY_CHECK_ARG_TYPE(1, tp_tuple);
     if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements");
     py_Ref x = py_tuple_getitem(py_arg(1), 0);
@@ -511,6 +603,18 @@ static bool array2d__setitem__(int argc, py_Ref argv) {
         }
         return _array2d_IndexError(self, pos.x, pos.y);
     }
+
+    if(argv[1].type == tp_array2d) {
+        c11_array2d* mask = py_touserdata(&argv[1]);
+        if(!_array2d_check_same_shape(self, mask)) return false;
+        if(!_array2d_check_all_type(mask, tp_bool)) return false;
+        for(int i = 0; i < self->numel; i++) {
+            if(py_tobool(mask->data + i)) self->data[i] = *value;
+        }
+        py_newnone(py_retval());
+        return true;
+    }
+
     PY_CHECK_ARG_TYPE(1, tp_tuple);
     if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements");
     py_Ref x = py_tuple_getitem(py_arg(1), 0);
@@ -548,13 +652,8 @@ static bool array2d__setitem__(int argc, py_Ref argv) {
             }
         } else {
             c11_array2d* src = py_touserdata(value);
-            if(slice_width != src->n_cols || slice_height != src->n_rows) {
-                return ValueError("expected the same shape: (%d, %d) != (%d, %d)",
-                                  slice_width,
-                                  slice_height,
-                                  src->n_cols,
-                                  src->n_rows);
-            }
+            if(!_check_same_shape(slice_width, slice_height, src->n_cols, src->n_rows))
+                return false;
             for(int j = 0; j < slice_height; j++) {
                 for(int i = 0; i < slice_width; i++) {
                     py_array2d__set(self, i + start_col, j + start_row, py_array2d__get(src, i, j));
@@ -568,14 +667,7 @@ static bool array2d__setitem__(int argc, py_Ref argv) {
     }
 }
 
-static bool _array2d_is_all_ints(c11_array2d* self) {
-    for(int i = 0; i < self->numel; i++) {
-        if(!py_isint(self->data + i)) return false;
-    }
-    return true;
-}
-
-// convolve(self: array2d[int], kernel: 'array2d[int]', padding: int = 0) -> 'array2d[int]'
+// convolve(self: array2d[int], kernel: array2d[int], padding: int) -> array2d[int]
 static bool array2d_convolve(int argc, py_Ref argv) {
     PY_CHECK_ARGC(3);
     PY_CHECK_ARG_TYPE(1, tp_array2d);
@@ -585,10 +677,10 @@ static bool array2d_convolve(int argc, py_Ref argv) {
     int padding = py_toint(py_arg(2));
     if(kernel->n_cols != kernel->n_rows) { return ValueError("kernel must be square"); }
     int ksize = kernel->n_cols;
-    if(ksize % 2 == 0) { return ValueError("kernel size must be odd"); }
+    if(ksize % 2 == 0) return ValueError("kernel size must be odd");
     int ksize_half = ksize / 2;
-    if(!_array2d_is_all_ints(self)) { return TypeError("self must be `array2d[int]`"); }
-    if(!_array2d_is_all_ints(kernel)) { return TypeError("kernel must be `array2d[int]`"); }
+    if(!_array2d_check_all_type(self, tp_int)) return false;
+    if(!_array2d_check_all_type(kernel, tp_int)) return false;
     c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
     for(int j = 0; j < self->n_rows; j++) {
         for(int i = 0; i < self->n_cols; i++) {
@@ -624,6 +716,9 @@ void pk__add_module_array2d() {
 
     py_setdict(mod, py_name("array2d"), py_tpobject(array2d));
 
+    // array2d is unhashable
+    py_setdict(py_tpobject(array2d), __hash__, py_None());
+
     py_bindmagic(array2d_iterator, __iter__, pk_wrapper__self);
     py_bindmagic(array2d_iterator, __next__, array2d_iterator__next__);
     py_bind(py_tpobject(array2d),
@@ -657,8 +752,14 @@ void pk__add_module_array2d() {
     py_bindmethod(array2d, "apply_", array2d_apply_);
     py_bindmethod(array2d, "copy_", array2d_copy_);
 
-    py_bindmethod(array2d, "tolist", array2d_tolist);
     py_bindmethod(array2d, "render", array2d_render);
+
+    py_bindmethod(array2d, "all", array2d_all);
+    py_bindmethod(array2d, "any", array2d_any);
+
+    py_bindstaticmethod(array2d, "fromlist", array2d_fromlist_STATIC);
+    py_bindmethod(array2d, "tolist", array2d_tolist);
+
     py_bindmethod(array2d, "count", array2d_count);
     py_bindmethod(array2d, "find_bounding_rect", array2d_find_bounding_rect);
     py_bindmethod(array2d, "count_neighbors", array2d_count_neighbors);

+ 34 - 20
tests/90_array2d.py

@@ -9,7 +9,7 @@ except ValueError:
     pass
 
 # test callable constructor
-a = array2d(2, 4, lambda: 0)
+a = array2d[int](2, 4, lambda: 0)
 
 assert a.width == a.n_cols == 2
 assert a.height == a.n_rows == 4
@@ -49,7 +49,7 @@ try:
 except IndexError:
     pass
 
-# test __iter__
+# test tolist
 a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
 assert a_list == a.tolist()
 
@@ -59,10 +59,10 @@ assert len(a) == 4*2
 # test __eq__
 x = array2d(2, 4, default=0)
 b = array2d(2, 4, default=0)
-assert x == b
+assert (x == b).all()
 
 b[0, 0] = 1
-assert x != b
+assert (x != b).any()
 
 # test __repr__
 assert repr(a) == f'array2d(2, 4)'
@@ -77,22 +77,22 @@ assert c.numel == 8
 
 # test copy
 d = c.copy()
-assert d == c and d is not c
+assert (d == c).all() and d is not c
 
 # test fill_
 d.fill_(-3)
-assert d == array2d(2, 4, default=-3)
+assert (d == array2d(2, 4, default=-3)).all()
 
 # test apply_
 d.apply_(lambda x: x + 3)
-assert d == array2d(2, 4, default=0)
+assert (d == array2d(2, 4, default=0)).all()
 
 # test copy_
 a.copy_(d)
-assert a == d and a is not d
+assert (a == d).all() and a is not d
 x = array2d(2, 4, default=0)
 x.copy_(d)
-assert x == d and x is not d
+assert (x == d).all() and x is not d
 x.copy_([1, 2, 3, 4, 5, 6, 7, 8])
 assert x.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8]]
 
@@ -115,14 +115,12 @@ assert _0 == moore_result
 _1 = a.count_neighbors(1, 'von Neumann')
 assert _1 == von_neumann_result
 
-MOORE_KERNEL = array2d[int](3, 3, default=1)
-MOORE_KERNEL[1, 1] = 0
-VON_NEUMANN_KERNEL = array2d[int](3, 3, default=0)
-VON_NEUMANN_KERNEL[0, 1] = VON_NEUMANN_KERNEL[1, 0] = VON_NEUMANN_KERNEL[1, 2] = VON_NEUMANN_KERNEL[2, 1] = 1
+MOORE_KERNEL = array2d[int].fromlist([[1, 1, 1], [1, 0, 1], [1, 1, 1]])
+VON_NEUMANN_KERNEL = array2d.fromlist([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
 moore_conv_result = a.convolve(MOORE_KERNEL, 0)
-assert moore_conv_result == moore_result
+assert (moore_conv_result == moore_result).all()
 von_neumann_conv_result = a.convolve(VON_NEUMANN_KERNEL, 0)
-assert von_neumann_conv_result == von_neumann_result
+assert (von_neumann_conv_result == von_neumann_result).all()
 
 # test slice get
 a = array2d(5, 5, default=0)
@@ -130,9 +128,9 @@ b = array2d(3, 2, default=1)
 
 assert a[1:4, 1:4] == array2d(3, 3, default=0)
 assert a[1:4, 1:3] == array2d(3, 2, default=0)
-assert a[1:4, 1:3] != b
+assert (a[1:4, 1:3] != b).any()
 a[1:4, 1:3] = b
-assert a[1:4, 1:3] == b
+assert (a[1:4, 1:3] == b).all()
 """
 0 0 0 0 0
 0 1 1 1 0
@@ -174,6 +172,7 @@ try:
 except TypeError:
     pass
 
+# test __iter__
 a = array2d(3, 4, default=1)
 for i, j, x in a:
     assert a[i, j] == x
@@ -187,13 +186,11 @@ a.unsafe_set(0, 0, 2)
 assert a.unsafe_get(0, 0) == 2
 
 # test convolve
-a = array2d[int](5, 2, default=0)
+a = array2d[int].fromlist([[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]])
 """
 1 0 2 4 0
 3 1 0 5 1
 """
-a[0, 0] = 1; a[1, 0] = 0; a[2, 0] = 2; a[3, 0] = 4; a[4, 0] = 0
-a[0, 1] = 3; a[1, 1] = 1; a[2, 1] = 0; a[3, 1] = 5; a[4, 1] = 1
 assert a.tolist() == [[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]]
 
 kernel = array2d[int](3, 3, default=1)
@@ -204,6 +201,23 @@ res = a.convolve(kernel, -1)
 """
 assert res.tolist() == [[0, 4, 9, 9, 5], [0, 4, 9, 9, 5]]
 
+mask = res == 9
+assert mask.tolist() == [
+    [False, False, True, True, False],
+    [False, False, True, True, False]
+    ]
+assert res[mask] == [9, 9, 9, 9]
+
+mask = res != 9
+assert mask.tolist() == [
+    [True, True, False, False, True],
+    [True, True, False, False, True]
+    ]
+assert res[mask] == [0, 4, 5, 0, 4, 5]
+res[mask] = -1
+assert res.tolist() == [[-1, -1, 9, 9, -1], [-1, -1, 9, 9, -1]]
+
+
 # stackoverflow bug due to recursive mark-and-sweep
 # class Cell:
 #     neighbors: list['Cell']