Explorar el Código

update `array2d.copy_`

blueloveTH hace 2 años
padre
commit
0bbf6af3ff
Se han modificado 3 ficheros con 19 adiciones y 4 borrados
  1. 7 4
      include/typings/array2d.pyi
  2. 10 0
      src/array2d.cpp
  3. 2 0
      tests/83_array2d.py

+ 7 - 4
include/typings/array2d.pyi

@@ -78,14 +78,18 @@ class array2d(Generic[T]):
         return new_a
 
     def fill_(self, value: T) -> None:
-        for i in range(self.n_cols * self.n_rows):
+        for i in range(self.numel):
             self.data[i] = value
 
     def apply_(self, f: Callable[[T], T]) -> None:
-        for i in range(self.n_cols * self.n_rows):
+        for i in range(self.numel):
             self.data[i] = f(self.data[i])
 
-    def copy_(self, other: 'array2d[T]') -> None:
+    def copy_(self, other: 'array2d[T]' | list['T']) -> None:
+        if isinstance(other, list):
+            assert len(other) == self.numel
+            self.data = other.copy()
+            return
         self.n_cols = other.n_cols
         self.n_rows = other.n_rows
         self.data = other.data.copy()
@@ -106,4 +110,3 @@ class array2d(Generic[T]):
                 count += int(self.is_valid(i+1, j+1) and self[i+1, j+1] == value)
                 new_a[i, j] = count
         return new_a
-

+ 10 - 0
src/array2d.cpp

@@ -167,6 +167,16 @@ struct Array2d{
 
         vm->bind(type, "copy_(self, other)", [](VM* vm, ArgsView args){
             Array2d& self = PK_OBJ_GET(Array2d, args[0]);
+            if(is_non_tagged_type(args[1], VM::tp_list)){
+                const List& list = PK_OBJ_GET(List, args[1]);
+                if(list.size() != self.numel){
+                    vm->ValueError("list size must be equal to the number of elements in the array2d");
+                }
+                for(int i = 0; i < self.numel; i++){
+                    self.data[i] = list[i];
+                }
+                return vm->None;
+            }
             Array2d& other = CAST(Array2d&, args[1]);
             // if self and other have different sizes, re-initialize self
             if(self.n_cols != other.n_cols || self.n_rows != other.n_rows){

+ 2 - 0
tests/83_array2d.py

@@ -92,6 +92,8 @@ assert a == d and a is not d
 x = array2d(4, 4, default=0)
 x.copy_(d)
 assert x == d and x is not d
+x.copy_(['a']*d.numel)
+assert x == array2d(d.width, d.height, default='a')
 
 # test subclass array2d
 class A(array2d):