blueloveTH 1 год назад
Родитель
Сommit
bb653bd383
3 измененных файлов с 44 добавлено и 19 удалено
  1. 6 3
      include/typings/array2d.pyi
  2. 32 13
      src/modules/array2d.c
  3. 6 3
      tests/90_array2d.py

+ 6 - 3
include/typings/array2d.pyi

@@ -1,4 +1,5 @@
 from typing import Callable, Any, Generic, TypeVar, Literal, overload, Iterator
+from linalg import vec2i
 
 T = TypeVar('T')
 
@@ -16,10 +17,8 @@ class array2d(Generic[T]):
     @property
     def numel(self) -> int: ...
 
-    def __new__(self, n_cols: int, n_rows: int, default=None): ...
+    def __new__(cls, n_cols: int, n_rows: int, default=None): ...
     def __len__(self) -> int: ...
-    def __eq__(self, other: 'array2d') -> bool: ...
-    def __ne__(self, other: 'array2d') -> bool: ...
     def __repr__(self) -> str: ...
     def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
 
@@ -35,10 +34,14 @@ class array2d(Generic[T]):
     @overload
     def __getitem__(self, index: tuple[int, int]) -> T: ...
     @overload
+    def __getitem__(self, index: vec2i) -> T: ...
+    @overload
     def __getitem__(self, index: tuple[slice, slice]) -> 'array2d[T]': ...
     @overload
     def __setitem__(self, index: tuple[int, int], value: T): ...
     @overload
+    def __setitem__(self, index: vec2i, value: T): ...
+    @overload
     def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
 
     def map(self, f: Callable[[T], Any]) -> 'array2d': ...

+ 32 - 13
src/modules/array2d.c

@@ -426,13 +426,30 @@ static bool array2d_count_neighbors(int argc, py_Ref argv) {
     if(slice_width <= 0 || slice_height <= 0)                                                      \
         return ValueError("slice width and height must be positive");
 
+static bool _array2d_IndexError(c11_array2d* self, int col, int row) {
+    return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)",
+                      col,
+                      row,
+                      self->n_cols,
+                      self->n_rows);
+}
+
 static bool array2d__getitem__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(2);
+    c11_array2d* self = py_touserdata(argv);
+    if(argv[1].type == tp_vec2i) {
+        // fastpath for vec2i
+        c11_vec2i pos = py_tovec2i(&argv[1]);
+        if(py_array2d_is_valid(self, pos.x, pos.y)) {
+            py_assign(py_retval(), py_array2d__get(self, pos.x, pos.y));
+            return true;
+        }
+        return _array2d_IndexError(self, pos.x, pos.y);
+    }
     PY_CHECK_ARG_TYPE(1, tp_tuple);
     if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements");
     py_Ref x = py_tuple_getitem(py_arg(1), 0);
     py_Ref y = py_tuple_getitem(py_arg(1), 1);
-    c11_array2d* self = py_touserdata(argv);
     if(py_isint(x) && py_isint(y)) {
         int col = py_toint(x);
         int row = py_toint(y);
@@ -440,11 +457,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
             py_assign(py_retval(), py_array2d__get(self, col, row));
             return true;
         }
-        return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)",
-                          col,
-                          row,
-                          self->n_cols,
-                          self->n_rows);
+        return _array2d_IndexError(self, col, row);
     } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) {
         HANDLE_SLICE();
         c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height);
@@ -461,12 +474,22 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
 
 static bool array2d__setitem__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(3);
+    c11_array2d* self = py_touserdata(argv);
+    py_Ref value = py_arg(2);
+    if(argv[1].type == tp_vec2i) {
+        // fastpath for vec2i
+        c11_vec2i pos = py_tovec2i(&argv[1]);
+        if(py_array2d_is_valid(self, pos.x, pos.y)) {
+            py_array2d__set(self, pos.x, pos.y, value);
+            py_newnone(py_retval());
+            return true;
+        }
+        return _array2d_IndexError(self, pos.x, pos.y);
+    }
     PY_CHECK_ARG_TYPE(1, tp_tuple);
     if(py_tuple_len(py_arg(1)) != 2) return TypeError("expected a tuple of 2 elements");
     py_Ref x = py_tuple_getitem(py_arg(1), 0);
     py_Ref y = py_tuple_getitem(py_arg(1), 1);
-    c11_array2d* self = py_touserdata(argv);
-    py_Ref value = py_arg(2);
     if(py_isint(x) && py_isint(y)) {
         int col = py_toint(x);
         int row = py_toint(y);
@@ -475,11 +498,7 @@ static bool array2d__setitem__(int argc, py_Ref argv) {
             py_newnone(py_retval());
             return true;
         }
-        return IndexError("(%d, %d) is not a valid index of array2d(%d, %d)",
-                          col,
-                          row,
-                          self->n_cols,
-                          self->n_rows);
+        return _array2d_IndexError(self, col, row);
     } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) {
         HANDLE_SLICE();
         bool is_basic_type = false;

+ 6 - 3
tests/90_array2d.py

@@ -108,8 +108,11 @@ moore_result[1, 1] = 0
 
 von_neumann_result = array2d(3, 3, default=0)
 von_neumann_result[0, 1] = von_neumann_result[1, 0] = von_neumann_result[1, 2] = von_neumann_result[2, 1] = 1
-a.count_neighbors(0, 'Moore') == moore_result
-a.count_neighbors(0, 'von Neumann') == von_neumann_result
+
+_0 = a.count_neighbors(1, 'Moore')
+assert _0 == moore_result
+_1 = a.count_neighbors(1, 'von Neumann')
+assert _1 == von_neumann_result
 
 # test slice get
 a = array2d(5, 5, default=0)
@@ -152,7 +155,7 @@ except ValueError:
     pass
 
 try:
-    a[:, :] = []
+    a[:, :] = ...
     exit(1)
 except TypeError:
     pass