Explorar o código

improve `array2d.is_valid`

blueloveTH hai 1 ano
pai
achega
b939df167d
Modificáronse 3 ficheiros con 26 adicións e 15 borrados
  1. 3 0
      include/typings/array2d.pyi
  2. 16 9
      src/modules/array2d.c
  3. 7 6
      tests/90_array2d.py

+ 3 - 0
include/typings/array2d.pyi

@@ -22,7 +22,10 @@ class array2d(Generic[T]):
     def __repr__(self) -> str: ...
     def __iter__(self) -> Iterator[tuple[int, int, T]]: ...
 
+    @overload
     def is_valid(self, col: int, row: int) -> bool: ...
+    @overload
+    def is_valid(self, pos: vec2i) -> bool: ...
 
     def get(self, col: int, row: int, default=None) -> T | None:
         """Returns the value at the given position or the default value if out of bounds."""

+ 16 - 9
src/modules/array2d.c

@@ -89,12 +89,21 @@ static bool array2d_numel(int argc, py_Ref argv) {
 }
 
 static bool array2d_is_valid(int argc, py_Ref argv) {
-    PY_CHECK_ARGC(3);
     c11_array2d* self = py_touserdata(argv);
-    PY_CHECK_ARG_TYPE(1, tp_int);
-    PY_CHECK_ARG_TYPE(2, tp_int);
-    int col = py_toint(py_arg(1));
-    int row = py_toint(py_arg(2));
+    int col, row;
+    if(argc == 2) {
+        PY_CHECK_ARG_TYPE(1, tp_vec2i);
+        c11_vec2i pos = py_tovec2i(py_arg(1));
+        col = pos.x;
+        row = pos.y;
+    } else if(argc == 3) {
+        PY_CHECK_ARG_TYPE(1, tp_int);
+        PY_CHECK_ARG_TYPE(2, tp_int);
+        col = py_toint(py_arg(1));
+        row = py_toint(py_arg(2));
+    } else {
+        return TypeError("is_valid() expected 2 or 3 arguments");
+    }
     py_newbool(py_retval(), py_array2d_is_valid(self, col, row));
     return true;
 }
@@ -315,7 +324,7 @@ static bool array2d_tolist(int argc, py_Ref argv) {
     return true;
 }
 
-static bool array2d_render(int argc, py_Ref argv){
+static bool array2d_render(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     c11_sbuf buf;
     c11_sbuf__ctor(&buf);
@@ -326,9 +335,7 @@ static bool array2d_render(int argc, py_Ref argv){
             if(!py_str(item)) return false;
             c11_sbuf__write_sv(&buf, py_tosv(py_retval()));
         }
-        if(j < self->n_rows - 1){
-            c11_sbuf__write_char(&buf, '\n');
-        }
+        if(j < self->n_rows - 1) c11_sbuf__write_char(&buf, '\n');
     }
     c11_sbuf__py_submit(&buf, py_retval());
     return true;

+ 7 - 6
tests/90_array2d.py

@@ -1,4 +1,5 @@
 from array2d import array2d
+from linalg import vec2i
 
 # test error args for __init__
 try:
@@ -15,12 +16,12 @@ assert a.height == a.n_rows == 4
 assert a.numel == 8
 
 # test is_valid
-assert a.is_valid(0, 0)
-assert a.is_valid(1, 3)
-assert not a.is_valid(2, 0)
-assert not a.is_valid(0, 4)
-assert not a.is_valid(-1, 0)
-assert not a.is_valid(0, -1)
+assert a.is_valid(0, 0) and a.is_valid(vec2i(0, 0))
+assert a.is_valid(1, 3) and a.is_valid(vec2i(1, 3))
+assert not a.is_valid(2, 0) and not a.is_valid(vec2i(2, 0))
+assert not a.is_valid(0, 4) and not a.is_valid(vec2i(0, 4))
+assert not a.is_valid(-1, 0) and not a.is_valid(vec2i(-1, 0))
+assert not a.is_valid(0, -1) and not a.is_valid(vec2i(0, -1))
 
 # test get
 assert a.get(0, 0) == 0