blueloveTH 1 год назад
Родитель
Сommit
cfac4cbbca
3 измененных файлов с 119 добавлено и 2 удалено
  1. 9 0
      include/typings/array2d.pyi
  2. 71 0
      src/modules/array2d.c
  3. 39 2
      tests/90_array2d.py

+ 9 - 0
include/typings/array2d.pyi

@@ -69,3 +69,12 @@ class array2d(Generic[T]):
         
         Returns a tuple `(x, y, width, height)` or raise `ValueError` if the value is not found.
         """
+
+    def find_one(self, condition: Callable[[T], bool]) -> vec2i:
+        """Finds the position of the first cell that satisfies the condition.
+        
+        Returns a `vec2i` or raise `ValueError` if no cell satisfies the condition.
+        """
+
+    def convolve(self: array2d[int], kernel: 'array2d[int]', padding: int) -> 'array2d[int]':
+        """Convolves the array with the given kernel."""

+ 71 - 0
src/modules/array2d.c

@@ -568,6 +568,75 @@ static bool array2d__setitem__(int argc, py_Ref argv) {
     }
 }
 
+// find_one(self, condition: Callable[[T], bool]) -> vec2i
+static bool array2d_find_one(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(2);
+    c11_array2d* self = py_touserdata(argv);
+    py_Ref condition = py_arg(1);
+    for(int j = 0; j < self->n_rows; j++) {
+        for(int i = 0; i < self->n_cols; i++) {
+            bool ok = py_call(condition, 1, py_array2d__get(self, i, j));
+            if(!ok) return false;
+            if(!py_isbool(py_retval())) return TypeError("condition must return a bool");
+            if(py_tobool(py_retval())) {
+                py_newvec2i(py_retval(),
+                            (c11_vec2i){
+                                {i, j}
+                });
+                return true;
+            }
+        }
+    }
+    return ValueError("condition not met");
+}
+
+static bool _array2d_is_all_ints(c11_array2d* self) {
+    for(int i = 0; i < self->numel; i++) {
+        if(!py_isint(self->data + i)) return false;
+    }
+    return true;
+}
+
+// convolve(self: array2d[int], kernel: 'array2d[int]', padding: int = 0) -> 'array2d[int]'
+static bool array2d_convolve(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(3);
+    PY_CHECK_ARG_TYPE(1, tp_array2d);
+    PY_CHECK_ARG_TYPE(2, tp_int);
+    c11_array2d* self = py_touserdata(argv);
+    c11_array2d* kernel = py_touserdata(py_arg(1));
+    int padding = py_toint(py_arg(2));
+    if(kernel->n_cols != kernel->n_rows) { return ValueError("kernel must be square"); }
+    int ksize = kernel->n_cols;
+    if(ksize % 2 == 0) { return ValueError("kernel size must be odd"); }
+    int ksize_half = ksize / 2;
+    if(!_array2d_is_all_ints(self)) { return TypeError("self must be `array2d[int]`"); }
+    if(!_array2d_is_all_ints(kernel)) { return TypeError("kernel must be `array2d[int]`"); }
+    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    for(int j = 0; j < self->n_rows; j++) {
+        for(int i = 0; i < self->n_cols; i++) {
+            py_i64 sum = 0;
+            for(int jj = 0; jj < ksize; jj++) {
+                for(int ii = 0; ii < ksize; ii++) {
+                    int x = i + ii - ksize_half;
+                    int y = j + jj - ksize_half;
+                    py_i64 _0, _1;
+                    if(x < 0 || x >= self->n_cols || y < 0 || y >= self->n_rows) {
+                        _0 = padding;
+                    } else {
+                        _0 = py_toint(py_array2d__get(self, x, y));
+                    }
+                    _1 = py_toint(py_array2d__get(kernel, ii, jj));
+                    sum += _0 * _1;
+                }
+            }
+            py_newint(py_array2d__get(res, i, j), sum);
+        }
+    }
+    py_assign(py_retval(), py_peek(-1));
+    py_pop();
+    return true;
+}
+
 void pk__add_module_array2d() {
     py_GlobalRef mod = py_newmodule("array2d");
     py_Type array2d = pk_newtype("array2d", tp_object, mod, NULL, false, true);
@@ -615,6 +684,8 @@ void pk__add_module_array2d() {
     py_bindmethod(array2d, "count", array2d_count);
     py_bindmethod(array2d, "find_bounding_rect", array2d_find_bounding_rect);
     py_bindmethod(array2d, "count_neighbors", array2d_count_neighbors);
+    py_bindmethod(array2d, "find_one", array2d_find_one);
+    py_bindmethod(array2d, "convolve", array2d_convolve);
 }
 
 #undef INC_COUNT

+ 39 - 2
tests/90_array2d.py

@@ -96,9 +96,19 @@ assert x == d and x is not d
 x.copy_([1, 2, 3, 4, 5, 6, 7, 8])
 assert x.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8]]
 
-# test alive_neighbors
+# test find_one
 a = array2d(3, 3, default=0)
 a[1, 1] = 1
+assert a.find_one(lambda x: x == 1) == vec2i(1, 1)
+try:
+    a.find_one(lambda x: x == 2)
+    exit(1)
+except ValueError:
+    pass
+
+# test alive_neighbors
+a = array2d[int](3, 3, default=0)
+a[1, 1] = 1
 """     Moore    von Neumann
 0 0 0   1 1 1    0 1 0
 0 1 0   1 0 1    1 0 1
@@ -115,6 +125,15 @@ assert _0 == moore_result
 _1 = a.count_neighbors(1, 'von Neumann')
 assert _1 == von_neumann_result
 
+MOORE_KERNEL = array2d[int](3, 3, default=1)
+MOORE_KERNEL[1, 1] = 0
+VON_NEUMANN_KERNEL = array2d[int](3, 3, default=0)
+VON_NEUMANN_KERNEL[0, 1] = VON_NEUMANN_KERNEL[1, 0] = VON_NEUMANN_KERNEL[1, 2] = VON_NEUMANN_KERNEL[2, 1] = 1
+moore_conv_result = a.convolve(MOORE_KERNEL, 0)
+assert moore_conv_result == moore_result
+von_neumann_conv_result = a.convolve(VON_NEUMANN_KERNEL, 0)
+assert von_neumann_conv_result == von_neumann_result
+
 # test slice get
 a = array2d(5, 5, default=0)
 b = array2d(3, 2, default=1)
@@ -177,6 +196,24 @@ assert a.unsafe_get(0, 0) == 1
 a.unsafe_set(0, 0, 2)
 assert a.unsafe_get(0, 0) == 2
 
+# test convolve
+a = array2d[int](5, 2, default=0)
+"""
+1 0 2 4 0
+3 1 0 5 1
+"""
+a[0, 0] = 1; a[1, 0] = 0; a[2, 0] = 2; a[3, 0] = 4; a[4, 0] = 0
+a[0, 1] = 3; a[1, 1] = 1; a[2, 1] = 0; a[3, 1] = 5; a[4, 1] = 1
+assert a.tolist() == [[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]]
+
+kernel = array2d[int](3, 3, default=1)
+res = a.convolve(kernel, -1)
+"""
+0 4 9 9 5
+0 4 9 9 5
+"""
+assert res.tolist() == [[0, 4, 9, 9, 5], [0, 4, 9, 9, 5]]
+
 # stackoverflow bug due to recursive mark-and-sweep
 # class Cell:
 #     neighbors: list['Cell']
@@ -195,4 +232,4 @@ assert a.unsafe_get(0, 0) == 2
 #     ]
 
 # import gc
-# gc.collect()
+# gc.collect()