소스 검색

update array2d

blueloveTH 1 년 전
부모
커밋
f4597ed01a
3개의 변경된 파일52개의 추가작업 그리고 79개의 파일을 삭제
  1. 7 12
      include/typings/array2d.pyi
  2. 30 47
      src/modules/array2d.c
  3. 15 20
      tests/90_array2d.py

+ 7 - 12
include/typings/array2d.pyi

@@ -1,11 +1,9 @@
 from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
 from linalg import vec2i
 
-T = TypeVar('T')
-
 Neighborhood = Literal['Moore', 'von Neumann']
 
-class array2d(Generic[T]):
+class array2d[T]:
     @property
     def n_cols(self) -> int: ...
     @property
@@ -17,24 +15,21 @@ class array2d(Generic[T]):
     @property
     def numel(self) -> int: ...
 
-    def __new__(cls, n_cols: int, n_rows: int, default=None): ...
-    def __len__(self) -> int: ...
+    def __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None): ...
     def __eq__(self, other: object) -> array2d[bool]: ... # type: ignore
     def __ne__(self, other: object) -> array2d[bool]: ... # type: ignore
     def __repr__(self) -> str: ...
-    def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
+    def __iter__(self) -> Iterator[tuple[vec2i, T]]: ...
 
     @overload
     def is_valid(self, col: int, row: int) -> bool: ...
     @overload
     def is_valid(self, pos: vec2i) -> bool: ...
 
-    def get(self, col: int, row: int, default=None) -> T | None:
-        """Returns the value at the given position or the default value if out of bounds."""
-    def unsafe_get(self, col: int, row: int) -> T:
-        """Returns the value at the given position without bounds checking."""
-    def unsafe_set(self, col: int, row: int, value: T):
-        """Sets the value at the given position without bounds checking."""
+    @overload
+    def get[R](self, col: int, row: int, default: R) -> T | R: ...
+    @overload
+    def get[R](self, pos: vec2i, default: R) -> T | R: ...
 
     @overload
     def __getitem__(self, index: tuple[int, int]) -> T: ...

+ 30 - 47
src/modules/array2d.c

@@ -40,7 +40,7 @@ static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
 
 /* bindings */
 static bool array2d__new__(int argc, py_Ref argv) {
-    // __new__(cls, n_cols: int, n_rows: int, default=None)
+    // __new__(cls, n_cols: int, n_rows: int, default: Callable[[vec2i], T] = None)
     py_Ref default_ = py_arg(3);
     PY_CHECK_ARG_TYPE(0, tp_type);
     PY_CHECK_ARG_TYPE(1, tp_int);
@@ -52,10 +52,17 @@ static bool array2d__new__(int argc, py_Ref argv) {
     c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
     // setup initial values
     if(py_callable(default_)) {
-        for(int i = 0; i < numel; i++) {
-            bool ok = py_call(default_, 0, NULL);
-            if(!ok) return false;
-            ud->data[i] = *py_retval();
+        for(int j = 0; j < n_rows; j++) {
+            for(int i = 0; i < n_cols; i++) {
+                py_TValue tmp;
+                py_newvec2i(&tmp,
+                            (c11_vec2i){
+                                {i, j}
+                });
+                bool ok = py_call(default_, 1, &tmp);
+                if(!ok) return false;
+                ud->data[j * n_cols + i] = *py_retval();
+            }
         }
     } else {
         for(int i = 0; i < numel; i++) {
@@ -111,17 +118,24 @@ static bool array2d_is_valid(int argc, py_Ref argv) {
 static bool array2d_get(int argc, py_Ref argv) {
     py_Ref default_;
     c11_array2d* self = py_touserdata(argv);
-    PY_CHECK_ARG_TYPE(1, tp_int);
-    PY_CHECK_ARG_TYPE(2, tp_int);
+    int col, row;
     if(argc == 3) {
-        default_ = py_None();
+        // get[R](self, pos: vec2i, default: R) -> T | R
+        PY_CHECK_ARG_TYPE(1, tp_vec2i);
+        c11_vec2i pos = py_tovec2i(py_arg(1));
+        col = pos.x;
+        row = pos.y;
+        default_ = py_arg(2);
     } else if(argc == 4) {
+        // get(self, col: int, row: int, default: T) -> T
+        PY_CHECK_ARG_TYPE(1, tp_int);
+        PY_CHECK_ARG_TYPE(2, tp_int);
+        col = py_toint(py_arg(1));
+        row = py_toint(py_arg(2));
         default_ = py_arg(3);
     } else {
         return TypeError("get() expected 3 or 4 arguments");
     }
-    int col = py_toint(py_arg(1));
-    int row = py_toint(py_arg(2));
     if(py_array2d_is_valid(self, col, row)) {
         py_assign(py_retval(), py_array2d__get(self, col, row));
     } else {
@@ -130,36 +144,6 @@ static bool array2d_get(int argc, py_Ref argv) {
     return true;
 }
 
-static bool array2d_unsafe_get(int argc, py_Ref argv) {
-    PY_CHECK_ARGC(3);
-    c11_array2d* self = py_touserdata(argv);
-    PY_CHECK_ARG_TYPE(1, tp_int);
-    PY_CHECK_ARG_TYPE(2, tp_int);
-    int col = py_toint(py_arg(1));
-    int row = py_toint(py_arg(2));
-    py_assign(py_retval(), py_array2d__get(self, col, row));
-    return true;
-}
-
-static bool array2d_unsafe_set(int argc, py_Ref argv) {
-    PY_CHECK_ARGC(4);
-    c11_array2d* self = py_touserdata(argv);
-    PY_CHECK_ARG_TYPE(1, tp_int);
-    PY_CHECK_ARG_TYPE(2, tp_int);
-    int col = py_toint(py_arg(1));
-    int row = py_toint(py_arg(2));
-    py_array2d__set(self, col, row, py_arg(3));
-    py_newnone(py_retval());
-    return true;
-}
-
-static bool array2d__len__(int argc, py_Ref argv) {
-    PY_CHECK_ARGC(1);
-    c11_array2d* self = py_touserdata(argv);
-    py_newint(py_retval(), self->numel);
-    return true;
-}
-
 static bool _array2d_check_all_type(c11_array2d* self, py_Type type) {
     for(int i = 0; i < self->numel; i++) {
         py_Type item_type = self->data[i].type;
@@ -273,11 +257,13 @@ static bool array2d_iterator__next__(int argc, py_Ref argv) {
     c11_array2d_iterator* self = py_touserdata(argv);
     if(self->index < self->array->numel) {
         div_t res = div(self->index, self->array->n_cols);
-        py_newtuple(py_retval(), 3);
+        py_newtuple(py_retval(), 2);
         py_TValue* data = py_tuple_data(py_retval());
-        py_newint(&data[0], res.rem);
-        py_newint(&data[1], res.quot);
-        py_assign(&data[2], self->array->data + self->index);
+        py_newvec2i(&data[0],
+                    (c11_vec2i){
+                        {res.rem, res.quot}
+        });
+        py_assign(&data[1], self->array->data + self->index);
         self->index++;
         return true;
     }
@@ -725,7 +711,6 @@ void pk__add_module_array2d() {
             "__new__(cls, n_cols: int, n_rows: int, default=None)",
             array2d__new__);
 
-    py_bindmagic(array2d, __len__, array2d__len__);
     py_bindmagic(array2d, __eq__, array2d__eq__);
     py_bindmagic(array2d, __ne__, array2d__ne__);
     py_bindmagic(array2d, __repr__, array2d__repr__);
@@ -742,8 +727,6 @@ void pk__add_module_array2d() {
 
     py_bindmethod(array2d, "is_valid", array2d_is_valid);
     py_bindmethod(array2d, "get", array2d_get);
-    py_bindmethod(array2d, "unsafe_get", array2d_unsafe_get);
-    py_bindmethod(array2d, "unsafe_set", array2d_unsafe_set);
 
     py_bindmethod(array2d, "map", array2d_map);
     py_bindmethod(array2d, "copy", array2d_copy);

+ 15 - 20
tests/90_array2d.py

@@ -9,11 +9,16 @@ except ValueError:
     pass
 
 # test callable constructor
-a = array2d[int](2, 4, lambda: 0)
+a = array2d[int](2, 4, lambda pos: (pos.x, pos.y))
 
 assert a.width == a.n_cols == 2
 assert a.height == a.n_rows == 4
 assert a.numel == 8
+assert a.tolist() == [
+    [(0, 0), (1, 0)],
+    [(0, 1), (1, 1)],
+    [(0, 2), (1, 2)],
+    [(0, 3), (1, 3)]]
 
 # test is_valid
 assert a.is_valid(0, 0) and a.is_valid(vec2i(0, 0))
@@ -24,14 +29,14 @@ assert not a.is_valid(-1, 0) and not a.is_valid(vec2i(-1, 0))
 assert not a.is_valid(0, -1) and not a.is_valid(vec2i(0, -1))
 
 # test get
-assert a.get(0, 0) == 0
-assert a.get(1, 3) == 0
-assert a.get(2, 0) is None
-assert a.get(0, 4, 'S') == 'S'
+assert a.get(0, 0, -1) == (0, 0)
+assert a.get(vec2i(1, 3), -1) == (1, 3)
+assert a.get(2, 0, None) is None
+assert a.get(vec2i(0, 4), 'S') == 'S'
 
 # test __getitem__
-assert a[0, 0] == 0
-assert a[1, 3] == 0
+assert a[0, 0] == (0, 0)
+assert a[1, 3] == (1, 3)
 try:
     a[2, 0]
     exit(1)
@@ -39,6 +44,7 @@ except IndexError:
     pass
 
 # test __setitem__
+a = array2d[int](2, 4, default=0)
 a[0, 0] = 5
 assert a[0, 0] == 5
 a[1, 3] = 6
@@ -53,9 +59,6 @@ except IndexError:
 a_list = [[5, 0], [0, 0], [0, 0], [0, 6]]
 assert a_list == a.tolist()
 
-# test __len__
-assert len(a) == 4*2
-
 # test __eq__
 x = array2d(2, 4, default=0)
 b = array2d(2, 4, default=0)
@@ -174,16 +177,8 @@ except TypeError:
 
 # test __iter__
 a = array2d(3, 4, default=1)
-for i, j, x in a:
-    assert a[i, j] == x
-
-assert len(a) == a.numel
-
-# test _get and _set
-a = array2d(3, 4, default=1)
-assert a.unsafe_get(0, 0) == 1
-a.unsafe_set(0, 0, 2)
-assert a.unsafe_get(0, 0) == 2
+for xy, val in a:
+    assert a[xy] == x
 
 # test convolve
 a = array2d[int].fromlist([[1, 0, 2, 4, 0], [3, 1, 0, 5, 1]])