array2d.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. #include "pocketpy/modules/array2d.hpp"
  2. #include "pocketpy/interpreter/bindings.hpp"
  3. namespace pkpy{
  4. struct Array2d{
  5. PK_ALWAYS_PASS_BY_POINTER(Array2d)
  6. PyVar* data;
  7. int n_cols;
  8. int n_rows;
  9. int numel;
  10. Array2d(){
  11. data = nullptr;
  12. n_cols = 0;
  13. n_rows = 0;
  14. numel = 0;
  15. }
  16. void init(int n_cols, int n_rows){
  17. this->n_cols = n_cols;
  18. this->n_rows = n_rows;
  19. this->numel = n_cols * n_rows;
  20. this->data = new PyVar[numel];
  21. }
  22. bool is_valid(int col, int row) const{
  23. return 0 <= col && col < n_cols && 0 <= row && row < n_rows;
  24. }
  25. void check_valid(VM* vm, int col, int row) const{
  26. if(is_valid(col, row)) return;
  27. vm->IndexError(_S('(', col, ", ", row, ')', " is not a valid index for array2d(", n_cols, ", ", n_rows, ')'));
  28. }
  29. PyVar _get(int col, int row){
  30. return data[row * n_cols + col];
  31. }
  32. void _set(int col, int row, PyVar value){
  33. data[row * n_cols + col] = value;
  34. }
  35. static void _register(VM* vm, PyObject* mod, PyObject* type){
  36. vm->bind(type, "__new__(cls, *args, **kwargs)", [](VM* vm, ArgsView args){
  37. Type cls = PK_OBJ_GET(Type, args[0]);
  38. return vm->new_object<Array2d>(cls);
  39. });
  40. vm->bind(type, "__init__(self, n_cols: int, n_rows: int, default=None)", [](VM* vm, ArgsView args){
  41. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  42. int n_cols = CAST(int, args[1]);
  43. int n_rows = CAST(int, args[2]);
  44. if(n_cols <= 0 || n_rows <= 0){
  45. vm->ValueError("n_cols and n_rows must be positive integers");
  46. }
  47. self.init(n_cols, n_rows);
  48. if(vm->py_callable(args[3])){
  49. for(int i = 0; i < self.numel; i++) self.data[i] = vm->call(args[3]);
  50. }else{
  51. for(int i = 0; i < self.numel; i++) self.data[i] = args[3];
  52. }
  53. return vm->None;
  54. });
  55. PY_READONLY_FIELD(Array2d, "n_cols", n_cols);
  56. PY_READONLY_FIELD(Array2d, "n_rows", n_rows);
  57. PY_READONLY_FIELD(Array2d, "width", n_cols);
  58. PY_READONLY_FIELD(Array2d, "height", n_rows);
  59. PY_READONLY_FIELD(Array2d, "numel", numel);
  60. // _get
  61. vm->bind_func(type, "_get", 3, [](VM* vm, ArgsView args){
  62. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  63. int col = CAST(int, args[1]);
  64. int row = CAST(int, args[2]);
  65. self.check_valid(vm, col, row);
  66. return self._get(col, row);
  67. });
  68. // _set
  69. vm->bind_func(type, "_set", 4, [](VM* vm, ArgsView args){
  70. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  71. int col = CAST(int, args[1]);
  72. int row = CAST(int, args[2]);
  73. self.check_valid(vm, col, row);
  74. self._set(col, row, args[3]);
  75. return vm->None;
  76. });
  77. vm->bind_func(type, "is_valid", 3, [](VM* vm, ArgsView args){
  78. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  79. int col = CAST(int, args[1]);
  80. int row = CAST(int, args[2]);
  81. return VAR(self.is_valid(col, row));
  82. });
  83. vm->bind(type, "get(self, col: int, row: int, default=None)", [](VM* vm, ArgsView args){
  84. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  85. int col = CAST(int, args[1]);
  86. int row = CAST(int, args[2]);
  87. if(!self.is_valid(col, row)) return args[3];
  88. return self._get(col, row);
  89. });
  90. #define HANDLE_SLICE() \
  91. int start_col, stop_col, step_col; \
  92. int start_row, stop_row, step_row; \
  93. vm->parse_int_slice(PK_OBJ_GET(Slice, xy[0]), self.n_cols, start_col, stop_col, step_col); \
  94. vm->parse_int_slice(PK_OBJ_GET(Slice, xy[1]), self.n_rows, start_row, stop_row, step_row); \
  95. if(step_col != 1 || step_row != 1) vm->ValueError("slice step must be 1"); \
  96. int slice_width = stop_col - start_col; \
  97. int slice_height = stop_row - start_row; \
  98. if(slice_width <= 0 || slice_height <= 0) vm->ValueError("slice width and height must be positive");
  99. vm->bind__getitem__(type->as<Type>(), [](VM* vm, PyVar _0, PyVar _1){
  100. Array2d& self = PK_OBJ_GET(Array2d, _0);
  101. const Tuple& xy = CAST(Tuple&, _1);
  102. if(is_int(xy[0]) && is_int(xy[1])){
  103. i64 col = xy[0].as<i64>();
  104. i64 row = xy[1].as<i64>();
  105. self.check_valid(vm, col, row);
  106. return self._get(col, row);
  107. }
  108. if(is_type(xy[0], VM::tp_slice) && is_type(xy[1], VM::tp_slice)){
  109. HANDLE_SLICE();
  110. PyVar new_array_obj = vm->new_user_object<Array2d>();
  111. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  112. new_array.init(stop_col - start_col, stop_row - start_row);
  113. for(int j = start_row; j < stop_row; j++){
  114. for(int i = start_col; i < stop_col; i++){
  115. new_array._set(i - start_col, j - start_row, self._get(i, j));
  116. }
  117. }
  118. return new_array_obj;
  119. }
  120. vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index");
  121. });
  122. vm->bind__setitem__(type->as<Type>(), [](VM* vm, PyVar _0, PyVar _1, PyVar _2){
  123. Array2d& self = PK_OBJ_GET(Array2d, _0);
  124. const Tuple& xy = CAST(Tuple&, _1);
  125. if(is_int(xy[0]) && is_int(xy[1])){
  126. i64 col = xy[0].as<i64>();
  127. i64 row = xy[1].as<i64>();
  128. self.check_valid(vm, col, row);
  129. self._set(col, row, _2);
  130. return;
  131. }
  132. if(is_type(xy[0], VM::tp_slice) && is_type(xy[1], VM::tp_slice)){
  133. HANDLE_SLICE();
  134. bool is_basic_type = false;
  135. switch(vm->_tp(_2).index){
  136. case VM::tp_int.index: is_basic_type = true; break;
  137. case VM::tp_float.index: is_basic_type = true; break;
  138. case VM::tp_str.index: is_basic_type = true; break;
  139. case VM::tp_bool.index: is_basic_type = true; break;
  140. default: is_basic_type = _2 == vm->None;
  141. }
  142. if(is_basic_type){
  143. for(int j = 0; j < slice_height; j++)
  144. for(int i = 0; i < slice_width; i++)
  145. self._set(i + start_col, j + start_row, _2);
  146. return;
  147. }
  148. if(!vm->is_user_type<Array2d>(_2)){
  149. vm->TypeError(_S("expected int/float/str/bool/None or an array2d instance"));
  150. }
  151. Array2d& other = PK_OBJ_GET(Array2d, _2);
  152. if(slice_width != other.n_cols || slice_height != other.n_rows){
  153. vm->ValueError("array2d size does not match the slice size");
  154. }
  155. for(int j = 0; j < slice_height; j++)
  156. for(int i = 0; i < slice_width; i++)
  157. self._set(i + start_col, j + start_row, other._get(i, j));
  158. return;
  159. }
  160. vm->TypeError("expected `tuple[int, int]` or `tuple[slice, slice]` as index");
  161. });
  162. #undef HANDLE_SLICE
  163. vm->bind_func(type, "tolist", 1, [](VM* vm, ArgsView args){
  164. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  165. List t(self.n_rows);
  166. for(int j = 0; j < self.n_rows; j++){
  167. List row(self.n_cols);
  168. for(int i = 0; i < self.n_cols; i++) row[i] = self._get(i, j);
  169. t[j] = VAR(std::move(row));
  170. }
  171. return VAR(std::move(t));
  172. });
  173. vm->bind__len__(type->as<Type>(), [](VM* vm, PyVar _0){
  174. Array2d& self = PK_OBJ_GET(Array2d, _0);
  175. return (i64)self.numel;
  176. });
  177. vm->bind__repr__(type->as<Type>(), [](VM* vm, PyVar _0) -> Str{
  178. Array2d& self = PK_OBJ_GET(Array2d, _0);
  179. return _S("array2d(", self.n_cols, ", ", self.n_rows, ')');
  180. });
  181. vm->bind_func(type, "map", 2, [](VM* vm, ArgsView args){
  182. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  183. PyVar f = args[1];
  184. PyVar new_array_obj = vm->new_user_object<Array2d>();
  185. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  186. new_array.init(self.n_cols, self.n_rows);
  187. for(int i = 0; i < new_array.numel; i++){
  188. new_array.data[i] = vm->call(f, self.data[i]);
  189. }
  190. return new_array_obj;
  191. });
  192. vm->bind_func(type, "copy", 1, [](VM* vm, ArgsView args){
  193. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  194. PyVar new_array_obj = vm->new_user_object<Array2d>();
  195. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  196. new_array.init(self.n_cols, self.n_rows);
  197. for(int i = 0; i < new_array.numel; i++){
  198. new_array.data[i] = self.data[i];
  199. }
  200. return new_array_obj;
  201. });
  202. vm->bind_func(type, "fill_", 2, [](VM* vm, ArgsView args){
  203. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  204. for(int i = 0; i < self.numel; i++){
  205. self.data[i] = args[1];
  206. }
  207. return vm->None;
  208. });
  209. vm->bind_func(type, "apply_", 2, [](VM* vm, ArgsView args){
  210. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  211. PyVar f = args[1];
  212. for(int i = 0; i < self.numel; i++){
  213. self.data[i] = vm->call(f, self.data[i]);
  214. }
  215. return vm->None;
  216. });
  217. vm->bind_func(type, "copy_", 2, [](VM* vm, ArgsView args){
  218. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  219. if(is_type(args[1], VM::tp_list)){
  220. const List& list = PK_OBJ_GET(List, args[1]);
  221. if(list.size() != self.numel){
  222. vm->ValueError("list size must be equal to the number of elements in the array2d");
  223. }
  224. for(int i = 0; i < self.numel; i++){
  225. self.data[i] = list[i];
  226. }
  227. return vm->None;
  228. }
  229. Array2d& other = CAST(Array2d&, args[1]);
  230. // if self and other have different sizes, re-initialize self
  231. if(self.n_cols != other.n_cols || self.n_rows != other.n_rows){
  232. delete self.data;
  233. self.init(other.n_cols, other.n_rows);
  234. }
  235. for(int i = 0; i < self.numel; i++){
  236. self.data[i] = other.data[i];
  237. }
  238. return vm->None;
  239. });
  240. vm->bind__eq__(type->as<Type>(), [](VM* vm, PyVar _0, PyVar _1){
  241. Array2d& self = PK_OBJ_GET(Array2d, _0);
  242. if(!vm->is_user_type<Array2d>(_1)) return vm->NotImplemented;
  243. Array2d& other = PK_OBJ_GET(Array2d, _1);
  244. if(self.n_cols != other.n_cols || self.n_rows != other.n_rows) return vm->False;
  245. for(int i = 0; i < self.numel; i++){
  246. if(vm->py_ne(self.data[i], other.data[i])) return vm->False;
  247. }
  248. return vm->True;
  249. });
  250. vm->bind(type, "count_neighbors(self, value, neighborhood='Moore') -> array2d[int]", [](VM* vm, ArgsView args){
  251. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  252. PyVar new_array_obj = vm->new_user_object<Array2d>();
  253. Array2d& new_array = PK_OBJ_GET(Array2d, new_array_obj);
  254. new_array.init(self.n_cols, self.n_rows);
  255. PyVar value = args[1];
  256. const Str& neighborhood = CAST(Str&, args[2]);
  257. if(neighborhood == "Moore"){
  258. for(int j = 0; j < new_array.n_rows; j++){
  259. for(int i = 0; i < new_array.n_cols; i++){
  260. int count = 0;
  261. count += self.is_valid(i-1, j-1) && vm->py_eq(self._get(i-1, j-1), value);
  262. count += self.is_valid(i, j-1) && vm->py_eq(self._get(i, j-1), value);
  263. count += self.is_valid(i+1, j-1) && vm->py_eq(self._get(i+1, j-1), value);
  264. count += self.is_valid(i-1, j) && vm->py_eq(self._get(i-1, j), value);
  265. count += self.is_valid(i+1, j) && vm->py_eq(self._get(i+1, j), value);
  266. count += self.is_valid(i-1, j+1) && vm->py_eq(self._get(i-1, j+1), value);
  267. count += self.is_valid(i, j+1) && vm->py_eq(self._get(i, j+1), value);
  268. count += self.is_valid(i+1, j+1) && vm->py_eq(self._get(i+1, j+1), value);
  269. new_array._set(i, j, VAR(count));
  270. }
  271. }
  272. }else if(neighborhood == "von Neumann"){
  273. for(int j = 0; j < new_array.n_rows; j++){
  274. for(int i = 0; i < new_array.n_cols; i++){
  275. int count = 0;
  276. count += self.is_valid(i, j-1) && vm->py_eq(self._get(i, j-1), value);
  277. count += self.is_valid(i-1, j) && vm->py_eq(self._get(i-1, j), value);
  278. count += self.is_valid(i+1, j) && vm->py_eq(self._get(i+1, j), value);
  279. count += self.is_valid(i, j+1) && vm->py_eq(self._get(i, j+1), value);
  280. new_array._set(i, j, VAR(count));
  281. }
  282. }
  283. }else{
  284. vm->ValueError("neighborhood must be 'Moore' or 'von Neumann'");
  285. }
  286. return new_array_obj;
  287. });
  288. vm->bind_func(type, "count", 2, [](VM* vm, ArgsView args){
  289. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  290. PyVar value = args[1];
  291. int count = 0;
  292. for(int i = 0; i < self.numel; i++) count += vm->py_eq(self.data[i], value);
  293. return VAR(count);
  294. });
  295. vm->bind_func(type, "find_bounding_rect", 2, [](VM* vm, ArgsView args){
  296. Array2d& self = PK_OBJ_GET(Array2d, args[0]);
  297. PyVar value = args[1];
  298. int left = self.n_cols;
  299. int top = self.n_rows;
  300. int right = 0;
  301. int bottom = 0;
  302. for(int j = 0; j < self.n_rows; j++){
  303. for(int i = 0; i < self.n_cols; i++){
  304. if(vm->py_eq(self._get(i, j), value)){
  305. left = (std::min)(left, i);
  306. top = (std::min)(top, j);
  307. right = (std::max)(right, i);
  308. bottom = (std::max)(bottom, j);
  309. }
  310. }
  311. }
  312. int width = right - left + 1;
  313. int height = bottom - top + 1;
  314. if(width <= 0 || height <= 0) return vm->None;
  315. Tuple t(4);
  316. t[0] = VAR(left);
  317. t[1] = VAR(top);
  318. t[2] = VAR(width);
  319. t[3] = VAR(height);
  320. return VAR(std::move(t));
  321. });
  322. }
  323. void _gc_mark(VM* vm) const{
  324. for(int i = 0; i < numel; i++) vm->obj_gc_mark(data[i]);
  325. }
  326. ~Array2d(){
  327. delete[] data;
  328. }
  329. };
  330. struct Array2dIter{
  331. PK_ALWAYS_PASS_BY_POINTER(Array2dIter)
  332. PyVar ref;
  333. Array2d* a;
  334. int i;
  335. Array2dIter(PyVar ref, Array2d* a): ref(ref), a(a), i(0){}
  336. void _gc_mark(VM* vm) const{ vm->obj_gc_mark(ref); }
  337. static void _register(VM* vm, PyObject* mod, PyObject* type){
  338. vm->bind__iter__(type->as<Type>(), [](VM* vm, PyVar _0) { return _0; });
  339. vm->bind__next__(type->as<Type>(), [](VM* vm, PyVar _0) -> unsigned{
  340. Array2dIter& self = PK_OBJ_GET(Array2dIter, _0);
  341. if(self.i == self.a->numel) return 0;
  342. std::div_t res = std::div(self.i, self.a->n_cols);
  343. vm->s_data.emplace(VM::tp_int, res.rem);
  344. vm->s_data.emplace(VM::tp_int, res.quot);
  345. vm->s_data.push(self.a->data[self.i++]);
  346. return 3;
  347. });
  348. }
  349. };
  350. void add_module_array2d(VM* vm){
  351. PyObject* mod = vm->new_module("array2d");
  352. vm->register_user_class<Array2d>(mod, "array2d", VM::tp_object, true);
  353. vm->register_user_class<Array2dIter>(mod, "_array2d_iter");
  354. Type array2d_iter_t = vm->_tp_user<Array2d>();
  355. vm->bind__iter__(array2d_iter_t, [](VM* vm, PyVar _0){
  356. return vm->new_user_object<Array2dIter>(_0, &_0.obj_get<Array2d>());
  357. });
  358. vm->_all_types[array2d_iter_t].op__iter__ = [](VM* vm, PyVar _0){
  359. vm->new_stack_object<Array2dIter>(vm->_tp_user<Array2dIter>(), _0, &_0.obj_get<Array2d>());
  360. };
  361. }
  362. } // namespace pkpy