array2d.pyi 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
from typing import Any, Callable, Generic, Iterator, TypeVar
  # Generic element type and the array2d grid class.
T = TypeVar('T')


class array2d(Generic[T]):
    """A fixed-size 2D grid of cells addressed by ``(col, row)``.

    Cells are kept in the flat list ``data`` in row-major order: the cell at
    ``(col, row)`` lives at ``data[row * n_cols + col]``.
    """

    data: list[T]  # not available in native module

    def __init__(self, n_cols: int, n_rows: int, default=None):
        """Create a grid with ``n_cols`` columns and ``n_rows`` rows.

        Args:
            n_cols: Number of columns (the grid's width).
            n_rows: Number of rows (the grid's height).
            default: Initial cell value. If it is callable, it is called once
                per cell so every cell gets its own result (e.g. ``list``
                gives each cell a distinct empty list); otherwise the very
                same object is stored in every cell.
        """
        self.n_cols = n_cols
        self.n_rows = n_rows
        if callable(default):
            self.data = [default() for _ in range(n_cols * n_rows)]
        else:
            self.data = [default] * n_cols * n_rows

    @property
    def width(self) -> int:
        """Number of columns."""
        return self.n_cols

    @property
    def height(self) -> int:
        """Number of rows."""
        return self.n_rows

    @property
    def numel(self) -> int:
        """Total number of cells (``n_cols * n_rows``)."""
        return self.n_cols * self.n_rows

    def is_valid(self, col: int, row: int) -> bool:
        """Return True if ``(col, row)`` lies inside the grid."""
        return 0 <= col < self.n_cols and 0 <= row < self.n_rows

    def get(self, col: int, row: int, default=None):
        """Return the cell at ``(col, row)``, or ``default`` if out of bounds."""
        if not self.is_valid(col, row):
            return default
        return self.data[row * self.n_cols + col]

    def __getitem__(self, index: tuple[int, int]) -> T:
        """Return the cell at ``index``, given as a ``(col, row)`` pair.

        Raises:
            IndexError: If ``(col, row)`` is outside the grid.
        """
        col, row = index
        if not self.is_valid(col, row):
            raise IndexError(f'({col}, {row}) is not a valid index for {self!r}')
        return self.data[row * self.n_cols + col]

    def __setitem__(self, index: tuple[int, int], value: T) -> None:
        """Store ``value`` at ``index``, given as a ``(col, row)`` pair.

        Raises:
            IndexError: If ``(col, row)`` is outside the grid.
        """
        col, row = index
        if not self.is_valid(col, row):
            raise IndexError(f'({col}, {row}) is not a valid index for {self!r}')
        self.data[row * self.n_cols + col] = value

    def __iter__(self) -> 'Iterator[list[T]]':
        """Yield each row, starting from row 0, as a new list of its cells."""
        for row in range(self.n_rows):
            start = row * self.n_cols
            # Slicing the flat storage copies a whole row at once instead of
            # bounds-checking every cell through __getitem__.
            yield self.data[start:start + self.n_cols]

    def __len__(self) -> int:
        """Return the number of rows (the number of items ``iter`` yields)."""
        return self.n_rows

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` has the same shape and equal cells."""
        if not isinstance(other, array2d):
            return NotImplemented
        # Grids of different shape are never equal, even when their flattened
        # data happens to match (e.g. 2x3 vs 3x2). Checking the shape first
        # also avoids indexing past the end of a smaller ``other``.
        if self.n_cols != other.n_cols or self.n_rows != other.n_rows:
            return False
        for mine, theirs in zip(self.data, other.data):
            if mine != theirs:
                return False
        return True

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__``, propagating NotImplemented."""
        eq = self.__eq__(other)
        # ``not NotImplemented`` would silently yield False; returning
        # NotImplemented lets Python try the reflected operation and then
        # fall back to identity comparison.
        if eq is NotImplemented:
            return NotImplemented
        return not eq

    def __repr__(self):
        return f'array2d({self.n_cols}, {self.n_rows})'

    def map(self, f: Callable[[T], Any]) -> 'array2d':
        """Return a new grid of the same shape holding ``f(cell)`` for each cell.

        The original grid is left unchanged.
        """
        new_a: array2d = array2d(self.n_cols, self.n_rows)
        new_a.data = [f(cell) for cell in self.data]
        return new_a

    def copy(self) -> 'array2d[T]':
        """Return a shallow copy: a new grid and list, sharing cell objects."""
        new_a: array2d[T] = array2d(self.n_cols, self.n_rows)
        new_a.data = self.data.copy()
        return new_a

    def fill_(self, value: T) -> None:
        """Set every cell to ``value`` in place (the same object everywhere)."""
        for i in range(self.numel):
            self.data[i] = value

    def apply_(self, f: Callable[[T], T]) -> None:
        """Replace every cell with ``f(cell)`` in place."""
        for i in range(self.numel):
            self.data[i] = f(self.data[i])

    def copy_(self, other: 'array2d[T]') -> None:
        """Make this grid a shallow copy of ``other``, adopting its shape."""
        self.n_cols = other.n_cols
        self.n_rows = other.n_rows
        self.data = other.data.copy()