array2d.cpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #include "pocketpy/array2d.h"
  2. namespace pkpy{
  3. struct Array2d{
  4. PK_ALWAYS_PASS_BY_POINTER(Array2d)
  5. PY_CLASS(Array2d, array2d, array2d)
  6. PyObject** data;
  7. int n_cols;
  8. int n_rows;
  9. int numel;
  10. Array2d(){
  11. data = nullptr;
  12. n_cols = 0;
  13. n_rows = 0;
  14. numel = 0;
  15. }
  16. Array2d* _() { return this; }
  17. void init(int n_cols, int n_rows){
  18. this->n_cols = n_cols;
  19. this->n_rows = n_rows;
  20. this->numel = n_cols * n_rows;
  21. this->data = new PyObject*[numel];
  22. }
  23. bool is_valid(int col, int row) const{
  24. return 0 <= col && col < n_cols && 0 <= row && row < n_rows;
  25. }
  26. PyObject* _get(int col, int row){
  27. return data[row * n_cols + col];
  28. }
  29. void _set(int col, int row, PyObject* value){
  30. data[row * n_cols + col] = value;
  31. }
  32. static void _register(VM* vm, PyObject* mod, PyObject* type){
  33. vm->bind(type, "__new__(cls, *args, **kwargs)", [](VM* vm, ArgsView args){
  34. Type cls = PK_OBJ_GET(Type, args[0]);
  35. return vm->heap.gcnew<Array2d>(cls);
  36. });
  37. vm->bind(type, "__init__(self, n_cols: int, n_rows: int, default=None)", [](VM* vm, ArgsView args){
  38. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  39. int n_cols = CAST(int, args[1]);
  40. int n_rows = CAST(int, args[2]);
  41. if(n_cols <= 0 || n_rows <= 0){
  42. vm->ValueError("n_cols and n_rows must be positive integers");
  43. }
  44. self.init(n_cols, n_rows);
  45. if(vm->py_callable(args[3])){
  46. for(int i = 0; i < self.numel; i++) self.data[i] = vm->call(args[3]);
  47. }else{
  48. for(int i = 0; i < self.numel; i++) self.data[i] = args[3];
  49. }
  50. return vm->None;
  51. });
  52. PY_READONLY_FIELD(Array2d, "n_cols", _, n_cols);
  53. PY_READONLY_FIELD(Array2d, "n_rows", _, n_rows);
  54. PY_READONLY_FIELD(Array2d, "width", _, n_cols);
  55. PY_READONLY_FIELD(Array2d, "height", _, n_rows);
  56. PY_READONLY_FIELD(Array2d, "numel", _, numel);
  57. vm->bind(type, "is_valid(self, col: int, row: int)", [](VM* vm, ArgsView args){
  58. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  59. int col = CAST(int, args[1]);
  60. int row = CAST(int, args[2]);
  61. return VAR(self.is_valid(col, row));
  62. });
  63. vm->bind(type, "get(self, col: int, row: int, default=None)", [](VM* vm, ArgsView args){
  64. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  65. int col = CAST(int, args[1]);
  66. int row = CAST(int, args[2]);
  67. if(!self.is_valid(col, row)) return args[3];
  68. return self._get(col, row);
  69. });
  70. vm->bind__getitem__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){
  71. Array2d& self = PK_OBJ_GET(Array2d, _0);
  72. const Tuple& xy = CAST(Tuple&, _1);
  73. int col = CAST(int, xy[0]);
  74. int row = CAST(int, xy[1]);
  75. if(!self.is_valid(col, row)){
  76. vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", self.n_cols, ", ", self.n_rows, ')'));
  77. }
  78. return self._get(col, row);
  79. });
  80. vm->bind__setitem__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1, PyObject* _2){
  81. Array2d& self = PK_OBJ_GET(Array2d, _0);
  82. const Tuple& xy = CAST(Tuple&, _1);
  83. int col = CAST(int, xy[0]);
  84. int row = CAST(int, xy[1]);
  85. if(!self.is_valid(col, row)){
  86. vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", self.n_cols, ", ", self.n_rows, ')'));
  87. }
  88. self._set(col, row, _2);
  89. });
  90. vm->bind__iter__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
  91. Array2d& self = PK_OBJ_GET(Array2d, _0);
  92. List t(self.n_rows);
  93. List row(self.n_cols);
  94. for(int j = 0; j < self.n_rows; j++){
  95. for(int i = 0; i < self.n_cols; i++) row[i] = self._get(i, j);
  96. t[j] = VAR(row); // copy
  97. }
  98. return vm->py_iter(VAR(std::move(t)));
  99. });
  100. vm->bind__len__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
  101. Array2d& self = PK_OBJ_GET(Array2d, _0);
  102. return (i64)self.n_rows;
  103. });
  104. vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0){
  105. Array2d& self = PK_OBJ_GET(Array2d, _0);
  106. return VAR(_S("array2d(", self.n_cols, ", ", self.n_rows, ')'));
  107. });
  108. vm->bind(type, "map(self, f)", [](VM* vm, ArgsView args){
  109. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  110. PyObject* f = args[1];
  111. PyObject* new_array_obj = vm->heap.gcnew<Array2d>(Array2d::_type(vm));
  112. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  113. new_array.init(self.n_cols, self.n_rows);
  114. for(int i = 0; i < new_array.numel; i++){
  115. new_array.data[i] = vm->call(f, self.data[i]);
  116. }
  117. return new_array_obj;
  118. });
  119. vm->bind(type, "copy(self)", [](VM* vm, ArgsView args){
  120. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  121. PyObject* new_array_obj = vm->heap.gcnew<Array2d>(Array2d::_type(vm));
  122. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  123. new_array.init(self.n_cols, self.n_rows);
  124. for(int i = 0; i < new_array.numel; i++){
  125. new_array.data[i] = self.data[i];
  126. }
  127. return new_array_obj;
  128. });
  129. vm->bind(type, "fill_(self, value)", [](VM* vm, ArgsView args){
  130. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  131. for(int i = 0; i < self.numel; i++){
  132. self.data[i] = args[1];
  133. }
  134. return vm->None;
  135. });
  136. vm->bind(type, "apply_(self, f)", [](VM* vm, ArgsView args){
  137. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  138. PyObject* f = args[1];
  139. for(int i = 0; i < self.numel; i++){
  140. self.data[i] = vm->call(f, self.data[i]);
  141. }
  142. return vm->None;
  143. });
  144. vm->bind(type, "copy_(self, other)", [](VM* vm, ArgsView args){
  145. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  146. Array2d& other = CAST(Array2d&, args[1]);
  147. // if self and other have different sizes, re-initialize self
  148. if(self.n_cols != other.n_cols || self.n_rows != other.n_rows){
  149. delete self.data;
  150. self.init(other.n_cols, other.n_rows);
  151. }
  152. for(int i = 0; i < self.numel; i++){
  153. self.data[i] = other.data[i];
  154. }
  155. return vm->None;
  156. });
  157. vm->bind__eq__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){
  158. Array2d& self = PK_OBJ_GET(Array2d, _0);
  159. if(!is_non_tagged_type(_1, Array2d::_type(vm))) return vm->NotImplemented;
  160. Array2d& other = PK_OBJ_GET(Array2d, _1);
  161. if(self.n_cols != other.n_cols || self.n_rows != other.n_rows) return vm->False;
  162. for(int i = 0; i < self.numel; i++){
  163. if(vm->py_ne(self.data[i], other.data[i])) return vm->False;
  164. }
  165. return vm->True;
  166. });
  167. // for cellular automata
  168. vm->bind(type, "count_neighbors(self, value) -> array2d[int]", [](VM* vm, ArgsView args){
  169. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  170. PyObject* new_array_obj = vm->heap.gcnew<Array2d>(Array2d::_type(vm));
  171. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  172. new_array.init(self.n_cols, self.n_rows);
  173. PyObject* value = args[1];
  174. for(int j = 0; j < new_array.n_rows; j++){
  175. for(int i = 0; i < new_array.n_cols; i++){
  176. int count = 0;
  177. count += self.is_valid(i-1, j-1) && vm->py_eq(self._get(i-1, j-1), value);
  178. count += self.is_valid(i, j-1) && vm->py_eq(self._get(i, j-1), value);
  179. count += self.is_valid(i+1, j-1) && vm->py_eq(self._get(i+1, j-1), value);
  180. count += self.is_valid(i-1, j) && vm->py_eq(self._get(i-1, j), value);
  181. count += self.is_valid(i+1, j) && vm->py_eq(self._get(i+1, j), value);
  182. count += self.is_valid(i-1, j+1) && vm->py_eq(self._get(i-1, j+1), value);
  183. count += self.is_valid(i, j+1) && vm->py_eq(self._get(i, j+1), value);
  184. count += self.is_valid(i+1, j+1) && vm->py_eq(self._get(i+1, j+1), value);
  185. new_array._set(i, j, VAR(count));
  186. }
  187. }
  188. return new_array_obj;
  189. });
  190. }
  191. void _gc_mark() const{
  192. for(int i = 0; i < numel; i++) PK_OBJ_MARK(data[i]);
  193. }
  194. ~Array2d(){
  195. delete[] data;
  196. }
  197. };
  198. void add_module_array2d(VM* vm){
  199. PyObject* mod = vm->new_module("array2d");
  200. Array2d::register_class(vm, mod);
  201. }
  202. } // namespace pkpy