blueloveTH 1 yıl önce
ebeveyn
işleme
63f08f7b50
2 değiştirilmiş dosya ile 25 ekleme ve 9 silme
  1. 23 9
      src/modules/array2d.c
  2. 2 0
      tests/90_array2d.py

+ 23 - 9
src/modules/array2d.c

@@ -272,17 +272,31 @@ static bool array2d_apply_(int argc, py_Ref argv) {
 static bool array2d_copy_(int argc, py_Ref argv) {
     // def copy_(self, src: 'array2d') -> None: ...
     PY_CHECK_ARGC(2);
-    PY_CHECK_ARG_TYPE(1, tp_array2d);
     c11_array2d* self = py_touserdata(argv);
-    c11_array2d* src = py_touserdata(py_arg(1));
-    if(self->n_cols != src->n_cols || self->n_rows != src->n_rows) {
-        return ValueError("copy_() expected the same shape: (%d, %d) != (%d, %d)",
-                          self->n_cols,
-                          self->n_rows,
-                          src->n_cols,
-                          src->n_rows);
+
+    py_Type src_type = py_typeof(py_arg(1));
+    if(src_type == tp_array2d) {
+        c11_array2d* src = py_touserdata(py_arg(1));
+        if(self->n_cols != src->n_cols || self->n_rows != src->n_rows) {
+            return ValueError("copy_() expected the same shape: (%d, %d) != (%d, %d)",
+                              self->n_cols,
+                              self->n_rows,
+                              src->n_cols,
+                              src->n_rows);
+        }
+        memcpy(self->data, src->data, self->numel * sizeof(py_TValue));
+    } else {
+        py_TValue* data;
+        int length = pk_arrayview(py_arg(1), &data);
+        if(length != -1) {
+            if(self->numel != length) {
+                return ValueError("copy_() expected the same numel: %d != %d", self->numel, length);
+            }
+            memcpy(self->data, data, self->numel * sizeof(py_TValue));
+        } else {
+            return TypeError("copy_() expected `array2d`, `list` or `tuple`, got '%t", src_type);
+        }
     }
-    memcpy(self->data, src->data, self->numel * sizeof(py_TValue));
     py_newnone(py_retval());
     return true;
 }

+ 2 - 0
tests/90_array2d.py

@@ -92,6 +92,8 @@ assert a == d and a is not d
 x = array2d(2, 4, default=0)
 x.copy_(d)
 assert x == d and x is not d
+x.copy_([1, 2, 3, 4, 5, 6, 7, 8])
+assert x.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8]]
 
 # test alive_neighbors
 a = array2d(3, 3, default=0)