array2d.pyi 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
from typing import Any, Callable, Generic, Iterator, TypeVar
  2. T = TypeVar('T')
  3. class array2d(Generic[T]):
  4. data: list[T] # not available in native module
  5. def __init__(self, n_cols: int, n_rows: int, default=None):
  6. self.n_cols = n_cols
  7. self.n_rows = n_rows
  8. if callable(default):
  9. self.data = [default() for _ in range(n_cols * n_rows)]
  10. else:
  11. self.data = [default] * n_cols * n_rows
  12. @property
  13. def width(self) -> int:
  14. return self.n_cols
  15. @property
  16. def height(self) -> int:
  17. return self.n_rows
  18. @property
  19. def numel(self) -> int:
  20. return self.n_cols * self.n_rows
  21. def is_valid(self, col: int, row: int) -> bool:
  22. return 0 <= col < self.n_cols and 0 <= row < self.n_rows
  23. def get(self, col: int, row: int, default=None):
  24. if not self.is_valid(col, row):
  25. return default
  26. return self.data[row * self.n_cols + col]
  27. def __getitem__(self, index: tuple[int, int]):
  28. col, row = index
  29. if not self.is_valid(col, row):
  30. raise IndexError(f'({col}, {row}) is not a valid index for {self!r}')
  31. return self.data[row * self.n_cols + col]
  32. def __setitem__(self, index: tuple[int, int], value: T):
  33. col, row = index
  34. if not self.is_valid(col, row):
  35. raise IndexError(f'({col}, {row}) is not a valid index for {self!r}')
  36. self.data[row * self.n_cols + col] = value
  37. def __iter__(self) -> list[list['T']]:
  38. for row in range(self.n_rows):
  39. yield [self[col, row] for col in range(self.n_cols)]
  40. def __len__(self):
  41. return self.n_rows
  42. def __eq__(self, other: 'array2d') -> bool:
  43. if not isinstance(other, array2d):
  44. return NotImplemented
  45. for i in range(self.numel):
  46. if self.data[i] != other.data[i]:
  47. return False
  48. return True
  49. def __ne__(self, other: 'array2d') -> bool:
  50. return not self.__eq__(other)
  51. def __repr__(self):
  52. return f'array2d({self.n_cols}, {self.n_rows})'
  53. def map(self, f: Callable[[T], Any]) -> 'array2d':
  54. new_a: array2d = array2d(self.n_cols, self.n_rows)
  55. for i in range(self.n_cols * self.n_rows):
  56. new_a.data[i] = f(self.data[i])
  57. return new_a
  58. def copy(self) -> 'array2d[T]':
  59. new_a: array2d[T] = array2d(self.n_cols, self.n_rows)
  60. new_a.data = self.data.copy()
  61. return new_a
  62. def fill_(self, value: T) -> None:
  63. for i in range(self.numel):
  64. self.data[i] = value
  65. def apply_(self, f: Callable[[T], T]) -> None:
  66. for i in range(self.numel):
  67. self.data[i] = f(self.data[i])
  68. def copy_(self, other: 'array2d[T] | list[T]') -> None:
  69. if isinstance(other, list):
  70. assert len(other) == self.numel
  71. self.data = other.copy()
  72. return
  73. self.n_cols = other.n_cols
  74. self.n_rows = other.n_rows
  75. self.data = other.data.copy()
  76. # for cellular automata
  77. def count_neighbors(self, value) -> 'array2d[int]':
  78. new_a = array2d(self.n_cols, self.n_rows)
  79. for j in range(self.n_rows):
  80. for i in range(self.n_cols):
  81. count = 0
  82. count += int(self.is_valid(i-1, j-1) and self[i-1, j-1] == value)
  83. count += int(self.is_valid(i, j-1) and self[i, j-1] == value)
  84. count += int(self.is_valid(i+1, j-1) and self[i+1, j-1] == value)
  85. count += int(self.is_valid(i-1, j) and self[i-1, j] == value)
  86. count += int(self.is_valid(i+1, j) and self[i+1, j] == value)
  87. count += int(self.is_valid(i-1, j+1) and self[i-1, j+1] == value)
  88. count += int(self.is_valid(i, j+1) and self[i, j+1] == value)
  89. count += int(self.is_valid(i+1, j+1) and self[i+1, j+1] == value)
  90. new_a[i, j] = count
  91. return new_a