Pārlūkot izejas kodu

improve `array2d.__setitem__`

blueloveTH 2 gadi atpakaļ
vecāks
revīzija
3534492bb6
3 mainītis faili ar 47 papildinājumiem un 6 dzēšanām
  1. 1 1
      include/typings/array2d.pyi
  2. 24 5
      src/array2d.cpp
  3. 22 0
      tests/83_array2d.py

+ 1 - 1
include/typings/array2d.pyi

@@ -24,7 +24,7 @@ class array2d(Generic[T]):
     @overload
     def __setitem__(self, index: tuple[int, int], value: T): ...
     @overload
-    def __setitem__(self, index: tuple[slice, slice], value: 'array2d[T]'): ...
+    def __setitem__(self, index: tuple[slice, slice], value: int | float | str | bool | None | 'array2d[T]'): ...
 
     def __len__(self) -> int: ...
     def __eq__(self, other: 'array2d') -> bool: ...

+ 24 - 5
src/array2d.cpp

@@ -133,15 +133,34 @@ struct Array2d{
 
             if(is_non_tagged_type(xy[0], VM::tp_slice) && is_non_tagged_type(xy[1], VM::tp_slice)){
                 HANDLE_SLICE();
-                Array2d& other = CAST(Array2d&, _2);        // _2 must be an array2d
+
+                bool is_basic_type = false;
+                switch(vm->_tp(_2).index){
+                    case VM::tp_int.index: is_basic_type = true; break;
+                    case VM::tp_float.index: is_basic_type = true; break;
+                    case VM::tp_str.index: is_basic_type = true; break;
+                    case VM::tp_bool.index: is_basic_type = true; break;
+                    default: is_basic_type = _2 == vm->None;
+                }
+
+                if(is_basic_type){
+                    for(int j = 0; j < slice_height; j++)
+                        for(int i = 0; i < slice_width; i++)
+                            self._set(i + start_col, j + start_row, _2);
+                    return;
+                }
+
+                if(!is_non_tagged_type(_2, Array2d::_type(vm))){
+                    vm->TypeError(_S("expected int/float/str/bool/None or an array2d instance"));
+                }
+
+                Array2d& other = PK_OBJ_GET(Array2d, _2);
                 if(slice_width != other.n_cols || slice_height != other.n_rows){
                     vm->ValueError("array2d size does not match the slice size");
                 }
-                for(int j = 0; j < slice_height; j++){
-                    for(int i = 0; i < slice_width; i++){
+                for(int j = 0; j < slice_height; j++)
+                    for(int i = 0; i < slice_width; i++)
                         self._set(i + start_col, j + start_row, other._get(i, j));
-                    }
-                }
                 return;
             }
             vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index");

+ 22 - 0
tests/83_array2d.py

@@ -144,3 +144,25 @@ assert a.find_bounding_rect(0) == (0, 0, 5, 5)
 assert a.find_bounding_rect(2) == None
 
 
+a = array2d(3, 2, default='?')
+# int/float/str/bool/None
+
+for value in [0, 0.0, '0', False, None]:
+    a[0:2, 0:1] = value
+    assert a[2, 1] == '?'
+    assert a[0, 0] == value
+
+a[:, :] = 3
+assert a == array2d(3, 2, default=3)
+
+try:
+    a[:, :] = array2d(1, 1)
+    exit(1)
+except ValueError:
+    pass
+
+try:
+    a[:, :] = []
+    exit(1)
+except TypeError:
+    pass