blueloveTH 1 год назад
Родитель
Сommit
6b332dbfbb
4 измененных файлов с 70 добавлено и 27 удалено
  1. 21 0
      include/pocketpy/interpreter/array2d.h
  2. 10 26
      src/modules/array2d.c
  3. 25 1
      src/modules/pickle.c
  4. 14 0
      tests/90_pickle.py

+ 21 - 0
include/pocketpy/interpreter/array2d.h

@@ -0,0 +1,21 @@
+#pragma once
+
+#include "pocketpy/pocketpy.h"
+
+#include "pocketpy/common/utils.h"
+#include "pocketpy/common/sstream.h"
+#include "pocketpy/interpreter/vm.h"
+
+typedef struct c11_array2d {
+    py_TValue* data;  // slots
+    int n_cols;
+    int n_rows;
+    int numel;
+} c11_array2d;
+
+typedef struct c11_array2d_iterator {
+    c11_array2d* array;
+    int index;
+} c11_array2d_iterator;
+
+c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows);

+ 10 - 26
src/modules/array2d.c

@@ -1,20 +1,4 @@
-#include "pocketpy/pocketpy.h"
-
-#include "pocketpy/common/utils.h"
-#include "pocketpy/common/sstream.h"
-#include "pocketpy/interpreter/vm.h"
-
-typedef struct c11_array2d {
-    py_TValue* data;  // slots
-    int n_cols;
-    int n_rows;
-    int numel;
-} c11_array2d;
-
-typedef struct c11_array2d_iterator {
-    c11_array2d* array;
-    int index;
-} c11_array2d_iterator;
+#include "pocketpy/interpreter/array2d.h"
 
 static bool py_array2d_is_valid(c11_array2d* self, int col, int row) {
     return col >= 0 && col < self->n_cols && row >= 0 && row < self->n_rows;
@@ -28,7 +12,7 @@ static void py_array2d__set(c11_array2d* self, int col, int row, py_Ref value) {
     self->data[row * self->n_cols + col] = *value;
 }
 
-static c11_array2d* py_array2d(py_OutRef out, int n_cols, int n_rows) {
+c11_array2d* py_newarray2d(py_OutRef out, int n_cols, int n_rows) {
     int numel = n_cols * n_rows;
     c11_array2d* ud = py_newobject(out, tp_array2d, numel, sizeof(c11_array2d));
     ud->data = py_getslot(out, 0);
@@ -49,7 +33,7 @@ static bool array2d__new__(int argc, py_Ref argv) {
     int n_rows = argv[2]._i64;
     int numel = n_cols * n_rows;
     if(n_cols <= 0 || n_rows <= 0) return ValueError("array2d() expected positive dimensions");
-    c11_array2d* ud = py_array2d(py_pushtmp(), n_cols, n_rows);
+    c11_array2d* ud = py_newarray2d(py_pushtmp(), n_cols, n_rows);
     // setup initial values
     if(py_callable(default_)) {
         for(int j = 0; j < n_rows; j++) {
@@ -191,7 +175,7 @@ static bool array2d_any(int argc, py_Ref argv) {
 static bool array2d__eq__(int argc, py_Ref argv) {
     PY_CHECK_ARGC(2);
     c11_array2d* self = py_touserdata(argv);
-    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
     if(py_istype(py_arg(1), tp_array2d)) {
         c11_array2d* other = py_touserdata(py_arg(1));
         if(!_array2d_check_same_shape(self, other)) return false;
@@ -268,7 +252,7 @@ static bool array2d_map(int argc, py_Ref argv) {
     PY_CHECK_ARGC(2);
     c11_array2d* self = py_touserdata(argv);
     py_Ref f = py_arg(1);
-    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
     for(int i = 0; i < self->numel; i++) {
         bool ok = py_call(f, 1, self->data + i);
         if(!ok) return false;
@@ -283,7 +267,7 @@ static bool array2d_copy(int argc, py_Ref argv) {
     // def copy(self) -> 'array2d': ...
     PY_CHECK_ARGC(1);
     c11_array2d* self = py_touserdata(argv);
-    c11_array2d* res = py_array2d(py_retval(), self->n_cols, self->n_rows);
+    c11_array2d* res = py_newarray2d(py_retval(), self->n_cols, self->n_rows);
     memcpy(res->data, self->data, self->numel * sizeof(py_TValue));
     return true;
 }
@@ -356,7 +340,7 @@ static bool array2d_fromlist_STATIC(int argc, py_Ref argv) {
             return ValueError("fromlist() expected a list of lists with the same length");
         }
     }
-    c11_array2d* res = py_array2d(py_retval(), n_cols, n_rows);
+    c11_array2d* res = py_newarray2d(py_retval(), n_cols, n_rows);
     for(int j = 0; j < n_rows; j++) {
         py_Ref row_j = py_list_getitem(argv, j);
         for(int i = 0; i < n_cols; i++) {
@@ -452,7 +436,7 @@ static bool array2d_get_bounding_rect(int argc, py_Ref argv) {
 static bool array2d_count_neighbors(int argc, py_Ref argv) {
     PY_CHECK_ARGC(3);
     c11_array2d* self = py_touserdata(argv);
-    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
     py_Ref value = py_arg(1);
     const char* neighborhood = py_tostr(py_arg(2));
 
@@ -556,7 +540,7 @@ static bool array2d__getitem__(int argc, py_Ref argv) {
         return _array2d_IndexError(self, col, row);
     } else if(py_istype(x, tp_slice) && py_istype(y, tp_slice)) {
         HANDLE_SLICE();
-        c11_array2d* res = py_array2d(py_retval(), slice_width, slice_height);
+        c11_array2d* res = py_newarray2d(py_retval(), slice_width, slice_height);
         for(int j = start_row; j < stop_row; j++) {
             for(int i = start_col; i < stop_col; i++) {
                 py_array2d__set(res, i - start_col, j - start_row, py_array2d__get(self, i, j));
@@ -660,7 +644,7 @@ static bool array2d_convolve(int argc, py_Ref argv) {
     int ksize_half = ksize / 2;
     if(!_array2d_check_all_type(self, tp_int)) return false;
     if(!_array2d_check_all_type(kernel, tp_int)) return false;
-    c11_array2d* res = py_array2d(py_pushtmp(), self->n_cols, self->n_rows);
+    c11_array2d* res = py_newarray2d(py_pushtmp(), self->n_cols, self->n_rows);
     for(int j = 0; j < self->n_rows; j++) {
         for(int i = 0; i < self->n_cols; i++) {
             py_i64 sum = 0;

+ 25 - 1
src/modules/pickle.c

@@ -3,7 +3,7 @@
 
 #include "pocketpy/common/utils.h"
 #include "pocketpy/common/sstream.h"
-#include "pocketpy/interpreter/vm.h"
+#include "pocketpy/interpreter/array2d.h"
 #include <stdint.h>
 
 typedef enum {
@@ -23,6 +23,7 @@ typedef enum {
     PKL_VEC2, PKL_VEC3,
     PKL_VEC2I, PKL_VEC3I,
     PKL_TYPE,
+    PKL_ARRAY2D,
     PKL_EOF,
     // clang-format on
 } PickleOp;
@@ -299,6 +300,20 @@ static bool pickle__write_object(PickleObject* buf, py_TValue* obj) {
             c11_string__delete(path);
             break;
         }
+        case tp_array2d: {
+            c11_array2d* arr = py_touserdata(obj);
+            for(int i = 0; i < arr->numel; i++) {
+                if(arr->data[i].is_ptr)
+                    return TypeError(
+                        "'array2d' object is not picklable because it contains heap-allocated objects");
+            }
+            pkl__emit_op(buf, PKL_ARRAY2D);
+            pkl__emit_int(buf, arr->n_cols);
+            pkl__emit_int(buf, arr->n_rows);
+            // TODO: fix type index which is not stable
+            PickleObject__write_bytes(buf, arr->data, arr->numel * sizeof(py_TValue));
+            break;
+        }
         default: return TypeError("'%t' object is not picklable", obj->type);
     }
     if(obj->is_ptr) {
@@ -503,6 +518,15 @@ bool py_pickle_loads(const unsigned char* data, int size) {
                 py_push(py_tpobject(t));
                 break;
             }
+            case PKL_ARRAY2D: {
+                int n_cols = pkl__read_int(&p);
+                int n_rows = pkl__read_int(&p);
+                c11_array2d* arr = py_newarray2d(py_pushtmp(), n_cols, n_rows);
+                int total_size = arr->numel * sizeof(py_TValue);
+                memcpy(arr->data, p, total_size);
+                p += total_size;
+                break;
+            }
             case PKL_EOF: {
                 // [memo, obj]
                 if(py_peek(0) - p0 != 2) return ValueError("invalid pickle data");

+ 14 - 0
tests/90_pickle.py

@@ -31,6 +31,20 @@ test(vec3i(1, 2, 3))            # PKL_VEC3I
 
 test(vec3i)                     # PKL_TYPE
 
+print('-'*50)
+from array2d import array2d
+a = array2d[int].fromlist([
+    [1, 2, 3],
+    [4, 5, 6]
+])
+a_encoded = pkl.dumps(a)
+print(a_encoded)
+a_decoded = pkl.loads(a_encoded)
+assert isinstance(a_decoded, array2d)
+assert a_decoded.width == 3 and a_decoded.height == 2
+assert (a == a_decoded).all()
+print(a_decoded)
+
 test([1, 2, 3])                 # PKL_LIST
 test((1, 2, 3))                 # PKL_TUPLE
 test({1: 2, 3: 4})              # PKL_DICT