linalg.h 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709
  1. #pragma once
  2. #include "common.h"
  3. #if PK_MODULE_LINALG
  4. #include "cffi.h"
  5. namespace pkpy{
  6. static constexpr float kEpsilon = 1e-4f;
  7. inline static bool isclose(float a, float b){ return fabsf(a - b) < kEpsilon; }
  8. struct Vec2{
  9. float x, y;
  10. Vec2() : x(0.0f), y(0.0f) {}
  11. Vec2(float x, float y) : x(x), y(y) {}
  12. Vec2(const Vec2& v) : x(v.x), y(v.y) {}
  13. Vec2 operator+(const Vec2& v) const { return Vec2(x + v.x, y + v.y); }
  14. Vec2& operator+=(const Vec2& v) { x += v.x; y += v.y; return *this; }
  15. Vec2 operator-(const Vec2& v) const { return Vec2(x - v.x, y - v.y); }
  16. Vec2& operator-=(const Vec2& v) { x -= v.x; y -= v.y; return *this; }
  17. Vec2 operator*(float s) const { return Vec2(x * s, y * s); }
  18. Vec2& operator*=(float s) { x *= s; y *= s; return *this; }
  19. Vec2 operator/(float s) const { return Vec2(x / s, y / s); }
  20. Vec2& operator/=(float s) { x /= s; y /= s; return *this; }
  21. Vec2 operator-() const { return Vec2(-x, -y); }
  22. bool operator==(const Vec2& v) const { return isclose(x, v.x) && isclose(y, v.y); }
  23. bool operator!=(const Vec2& v) const { return !isclose(x, v.x) || !isclose(y, v.y); }
  24. float dot(const Vec2& v) const { return x * v.x + y * v.y; }
  25. float cross(const Vec2& v) const { return x * v.y - y * v.x; }
  26. float length() const { return sqrtf(x * x + y * y); }
  27. float length_squared() const { return x * x + y * y; }
  28. Vec2 normalize() const { float l = length(); return Vec2(x / l, y / l); }
  29. };
  30. struct Vec3{
  31. float x, y, z;
  32. Vec3() : x(0.0f), y(0.0f), z(0.0f) {}
  33. Vec3(float x, float y, float z) : x(x), y(y), z(z) {}
  34. Vec3(const Vec3& v) : x(v.x), y(v.y), z(v.z) {}
  35. Vec3 operator+(const Vec3& v) const { return Vec3(x + v.x, y + v.y, z + v.z); }
  36. Vec3& operator+=(const Vec3& v) { x += v.x; y += v.y; z += v.z; return *this; }
  37. Vec3 operator-(const Vec3& v) const { return Vec3(x - v.x, y - v.y, z - v.z); }
  38. Vec3& operator-=(const Vec3& v) { x -= v.x; y -= v.y; z -= v.z; return *this; }
  39. Vec3 operator*(float s) const { return Vec3(x * s, y * s, z * s); }
  40. Vec3& operator*=(float s) { x *= s; y *= s; z *= s; return *this; }
  41. Vec3 operator/(float s) const { return Vec3(x / s, y / s, z / s); }
  42. Vec3& operator/=(float s) { x /= s; y /= s; z /= s; return *this; }
  43. Vec3 operator-() const { return Vec3(-x, -y, -z); }
  44. bool operator==(const Vec3& v) const { return isclose(x, v.x) && isclose(y, v.y) && isclose(z, v.z); }
  45. bool operator!=(const Vec3& v) const { return !isclose(x, v.x) || !isclose(y, v.y) || !isclose(z, v.z); }
  46. float dot(const Vec3& v) const { return x * v.x + y * v.y + z * v.z; }
  47. Vec3 cross(const Vec3& v) const { return Vec3(y * v.z - z * v.y, z * v.x - x * v.z, x * v.y - y * v.x); }
  48. float length() const { return sqrtf(x * x + y * y + z * z); }
  49. float length_squared() const { return x * x + y * y + z * z; }
  50. Vec3 normalize() const { float l = length(); return Vec3(x / l, y / l, z / l); }
  51. };
  52. struct Mat3x3{
  53. union {
  54. struct {
  55. float _11, _12, _13;
  56. float _21, _22, _23;
  57. float _31, _32, _33;
  58. };
  59. float m[3][3];
  60. float v[9];
  61. };
  62. Mat3x3() {}
  63. Mat3x3(float _11, float _12, float _13,
  64. float _21, float _22, float _23,
  65. float _31, float _32, float _33)
  66. : _11(_11), _12(_12), _13(_13)
  67. , _21(_21), _22(_22), _23(_23)
  68. , _31(_31), _32(_32), _33(_33) {}
  69. void set_zeros(){ for (int i=0; i<9; ++i) v[i] = 0.0f; }
  70. void set_ones(){ for (int i=0; i<9; ++i) v[i] = 1.0f; }
  71. void set_identity(){ set_zeros(); _11 = _22 = _33 = 1.0f; }
  72. static Mat3x3 zeros(){
  73. static Mat3x3 ret(0, 0, 0, 0, 0, 0, 0, 0, 0);
  74. return ret;
  75. }
  76. static Mat3x3 ones(){
  77. static Mat3x3 ret(1, 1, 1, 1, 1, 1, 1, 1, 1);
  78. return ret;
  79. }
  80. static Mat3x3 identity(){
  81. static Mat3x3 ret(1, 0, 0, 0, 1, 0, 0, 0, 1);
  82. return ret;
  83. }
  84. Mat3x3 operator+(const Mat3x3& other) const{
  85. Mat3x3 ret;
  86. for (int i=0; i<9; ++i) ret.v[i] = v[i] + other.v[i];
  87. return ret;
  88. }
  89. Mat3x3 operator-(const Mat3x3& other) const{
  90. Mat3x3 ret;
  91. for (int i=0; i<9; ++i) ret.v[i] = v[i] - other.v[i];
  92. return ret;
  93. }
  94. Mat3x3 operator*(float scalar) const{
  95. Mat3x3 ret;
  96. for (int i=0; i<9; ++i) ret.v[i] = v[i] * scalar;
  97. return ret;
  98. }
  99. Mat3x3 operator/(float scalar) const{
  100. Mat3x3 ret;
  101. for (int i=0; i<9; ++i) ret.v[i] = v[i] / scalar;
  102. return ret;
  103. }
  104. Mat3x3& operator+=(const Mat3x3& other){
  105. for (int i=0; i<9; ++i) v[i] += other.v[i];
  106. return *this;
  107. }
  108. Mat3x3& operator-=(const Mat3x3& other){
  109. for (int i=0; i<9; ++i) v[i] -= other.v[i];
  110. return *this;
  111. }
  112. Mat3x3& operator*=(float scalar){
  113. for (int i=0; i<9; ++i) v[i] *= scalar;
  114. return *this;
  115. }
  116. Mat3x3& operator/=(float scalar){
  117. for (int i=0; i<9; ++i) v[i] /= scalar;
  118. return *this;
  119. }
  120. Mat3x3 matmul(const Mat3x3& other) const{
  121. Mat3x3 ret;
  122. ret._11 = _11 * other._11 + _12 * other._21 + _13 * other._31;
  123. ret._12 = _11 * other._12 + _12 * other._22 + _13 * other._32;
  124. ret._13 = _11 * other._13 + _12 * other._23 + _13 * other._33;
  125. ret._21 = _21 * other._11 + _22 * other._21 + _23 * other._31;
  126. ret._22 = _21 * other._12 + _22 * other._22 + _23 * other._32;
  127. ret._23 = _21 * other._13 + _22 * other._23 + _23 * other._33;
  128. ret._31 = _31 * other._11 + _32 * other._21 + _33 * other._31;
  129. ret._32 = _31 * other._12 + _32 * other._22 + _33 * other._32;
  130. ret._33 = _31 * other._13 + _32 * other._23 + _33 * other._33;
  131. return ret;
  132. }
  133. Vec3 matmul(const Vec3& other) const{
  134. Vec3 ret;
  135. ret.x = _11 * other.x + _12 * other.y + _13 * other.z;
  136. ret.y = _21 * other.x + _22 * other.y + _23 * other.z;
  137. ret.z = _31 * other.x + _32 * other.y + _33 * other.z;
  138. return ret;
  139. }
  140. bool operator==(const Mat3x3& other) const{
  141. for (int i=0; i<9; ++i){
  142. if (!isclose(v[i], other.v[i])) return false;
  143. }
  144. return true;
  145. }
  146. bool operator!=(const Mat3x3& other) const{
  147. for (int i=0; i<9; ++i){
  148. if (!isclose(v[i], other.v[i])) return true;
  149. }
  150. return false;
  151. }
  152. float determinant() const{
  153. return _11 * _22 * _33 + _12 * _23 * _31 + _13 * _21 * _32
  154. - _11 * _23 * _32 - _12 * _21 * _33 - _13 * _22 * _31;
  155. }
  156. Mat3x3 transpose() const{
  157. Mat3x3 ret;
  158. ret._11 = _11; ret._12 = _21; ret._13 = _31;
  159. ret._21 = _12; ret._22 = _22; ret._23 = _32;
  160. ret._31 = _13; ret._32 = _23; ret._33 = _33;
  161. return ret;
  162. }
  163. bool inverse(Mat3x3& ret) const{
  164. float det = determinant();
  165. if (fabsf(det) < kEpsilon) return false;
  166. float inv_det = 1.0f / det;
  167. ret._11 = (_22 * _33 - _23 * _32) * inv_det;
  168. ret._12 = (_13 * _32 - _12 * _33) * inv_det;
  169. ret._13 = (_12 * _23 - _13 * _22) * inv_det;
  170. ret._21 = (_23 * _31 - _21 * _33) * inv_det;
  171. ret._22 = (_11 * _33 - _13 * _31) * inv_det;
  172. ret._23 = (_13 * _21 - _11 * _23) * inv_det;
  173. ret._31 = (_21 * _32 - _22 * _31) * inv_det;
  174. ret._32 = (_12 * _31 - _11 * _32) * inv_det;
  175. ret._33 = (_11 * _22 - _12 * _21) * inv_det;
  176. return true;
  177. }
  178. /*************** affine transformations ***************/
  179. static Mat3x3 trs(Vec2 t, float radian, Vec2 s){
  180. float cr = cosf(radian);
  181. float sr = sinf(radian);
  182. return Mat3x3(s.x * cr, -s.y * sr, t.x,
  183. s.x * sr, s.y * cr, t.y,
  184. 0.0f, 0.0f, 1.0f);
  185. }
  186. bool is_affine() const{
  187. float det = _11 * _22 - _12 * _21;
  188. if(fabsf(det) < kEpsilon) return false;
  189. return _31 == 0.0f && _32 == 0.0f && _33 == 1.0f;
  190. }
  191. Mat3x3 inverse_affine() const{
  192. Mat3x3 ret;
  193. float det = _11 * _22 - _12 * _21;
  194. float inv_det = 1.0f / det;
  195. ret._11 = _22 * inv_det;
  196. ret._12 = -_12 * inv_det;
  197. ret._13 = (_12 * _23 - _13 * _22) * inv_det;
  198. ret._21 = -_21 * inv_det;
  199. ret._22 = _11 * inv_det;
  200. ret._23 = (_13 * _21 - _11 * _23) * inv_det;
  201. ret._31 = 0.0f;
  202. ret._32 = 0.0f;
  203. ret._33 = 1.0f;
  204. return ret;
  205. }
  206. Mat3x3 matmul_affine(const Mat3x3& other) const{
  207. Mat3x3 ret;
  208. ret._11 = _11 * other._11 + _12 * other._21;
  209. ret._12 = _11 * other._12 + _12 * other._22;
  210. ret._13 = _11 * other._13 + _12 * other._23 + _13;
  211. ret._21 = _21 * other._11 + _22 * other._21;
  212. ret._22 = _21 * other._12 + _22 * other._22;
  213. ret._23 = _21 * other._13 + _22 * other._23 + _23;
  214. ret._31 = 0.0f;
  215. ret._32 = 0.0f;
  216. ret._33 = 1.0f;
  217. return ret;
  218. }
  219. Vec2 translation() const { return Vec2(_13, _23); }
  220. float rotation() const { return atan2f(_21, _11); }
  221. Vec2 scale() const {
  222. return Vec2(
  223. sqrtf(_11 * _11 + _21 * _21),
  224. sqrtf(_12 * _12 + _22 * _22)
  225. );
  226. }
  227. Vec2 transform_point(Vec2 v) const {
  228. return Vec2(_11 * v.x + _12 * v.y + _13, _21 * v.x + _22 * v.y + _23);
  229. }
  230. Vec2 transform_vector(Vec2 v) const {
  231. return Vec2(_11 * v.x + _12 * v.y, _21 * v.x + _22 * v.y);
  232. }
  233. };
  234. struct PyVec2;
  235. struct PyVec3;
  236. struct PyMat3x3;
  237. PyObject* py_var(VM*, Vec2);
  238. PyObject* py_var(VM*, const PyVec2&);
  239. PyObject* py_var(VM*, Vec3);
  240. PyObject* py_var(VM*, const PyVec3&);
  241. PyObject* py_var(VM*, const Mat3x3&);
  242. PyObject* py_var(VM*, const PyMat3x3&);
  243. #define BIND_VEC_VEC_OP(D, name, op) \
  244. vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){ \
  245. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  246. PyVec##D& other = CAST(PyVec##D&, args[1]); \
  247. return VAR(self op other); \
  248. });
  249. #define BIND_VEC_FLOAT_OP(D, name, op) \
  250. vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){ \
  251. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  252. f64 other = vm->num_to_float(args[1]); \
  253. return VAR(self op other); \
  254. });
  255. #define BIND_VEC_FUNCTION_0(D, name) \
  256. vm->bind_method<0>(type, #name, [](VM* vm, ArgsView args){ \
  257. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  258. return VAR(self.name()); \
  259. });
  260. #define BIND_VEC_FUNCTION_1(D, name) \
  261. vm->bind_method<0>(type, #name, [](VM* vm, ArgsView args){ \
  262. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  263. PyVec##D& other = CAST(PyVec##D&, args[1]); \
  264. return VAR(self.name(other)); \
  265. });
  266. #define BIND_VEC_FIELD(D, name) \
  267. type->attr().set(#name, vm->property([](VM* vm, ArgsView args){ \
  268. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  269. return VAR(self.name); \
  270. }, [](VM* vm, ArgsView args){ \
  271. PyVec##D& self = _CAST(PyVec##D&, args[0]); \
  272. self.name = vm->num_to_float(args[1]); \
  273. return vm->None; \
  274. }));
  275. struct PyVec2: Vec2 {
  276. PY_CLASS(PyVec2, linalg, vec2)
  277. PyVec2() : Vec2() {}
  278. PyVec2(const Vec2& v) : Vec2(v) {}
  279. PyVec2(const PyVec2& v) : Vec2(v) {}
  280. static void _register(VM* vm, PyObject* mod, PyObject* type){
  281. vm->bind_constructor<3>(type, [](VM* vm, ArgsView args){
  282. float x = CAST_F(args[1]);
  283. float y = CAST_F(args[2]);
  284. return VAR(Vec2(x, y));
  285. });
  286. vm->bind_method<0>(type, "__getnewargs__", [](VM* vm, ArgsView args){
  287. PyVec2& self = _CAST(PyVec2&, args[0]);
  288. return VAR(Tuple({ VAR(self.x), VAR(self.y) }));
  289. });
  290. vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
  291. PyVec2& self = _CAST(PyVec2&, obj);
  292. std::stringstream ss;
  293. ss << "vec2(" << self.x << ", " << self.y << ")";
  294. return VAR(ss.str());
  295. });
  296. vm->bind_method<0>(type, "copy", [](VM* vm, ArgsView args){
  297. PyVec2& self = _CAST(PyVec2&, args[0]);
  298. return VAR_T(PyVec2, self);
  299. });
  300. vm->bind_method<1>(type, "rotate", [](VM* vm, ArgsView args){
  301. Vec2 self = _CAST(PyVec2&, args[0]);
  302. float radian = vm->num_to_float(args[1]);
  303. float cr = cosf(radian);
  304. float sr = sinf(radian);
  305. Mat3x3 rotate(cr, -sr, 0.0f,
  306. sr, cr, 0.0f,
  307. 0.0f, 0.0f, 1.0f);
  308. self = rotate.transform_vector(self);
  309. return VAR(self);
  310. });
  311. BIND_VEC_VEC_OP(2, __add__, +)
  312. BIND_VEC_VEC_OP(2, __sub__, -)
  313. BIND_VEC_FLOAT_OP(2, __mul__, *)
  314. BIND_VEC_FLOAT_OP(2, __rmul__, *)
  315. BIND_VEC_FLOAT_OP(2, __truediv__, /)
  316. BIND_VEC_VEC_OP(2, __eq__, ==)
  317. BIND_VEC_FIELD(2, x)
  318. BIND_VEC_FIELD(2, y)
  319. BIND_VEC_FUNCTION_1(2, dot)
  320. BIND_VEC_FUNCTION_1(2, cross)
  321. BIND_VEC_FUNCTION_0(2, length)
  322. BIND_VEC_FUNCTION_0(2, length_squared)
  323. BIND_VEC_FUNCTION_0(2, normalize)
  324. }
  325. };
  326. struct PyVec3: Vec3 {
  327. PY_CLASS(PyVec3, linalg, vec3)
  328. PyVec3() : Vec3() {}
  329. PyVec3(const Vec3& v) : Vec3(v) {}
  330. PyVec3(const PyVec3& v) : Vec3(v) {}
  331. static void _register(VM* vm, PyObject* mod, PyObject* type){
  332. vm->bind_constructor<4>(type, [](VM* vm, ArgsView args){
  333. float x = CAST_F(args[1]);
  334. float y = CAST_F(args[2]);
  335. float z = CAST_F(args[3]);
  336. return VAR(Vec3(x, y, z));
  337. });
  338. vm->bind_method<0>(type, "__getnewargs__", [](VM* vm, ArgsView args){
  339. PyVec3& self = _CAST(PyVec3&, args[0]);
  340. return VAR(Tuple({ VAR(self.x), VAR(self.y), VAR(self.z) }));
  341. });
  342. vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
  343. PyVec3& self = _CAST(PyVec3&, obj);
  344. std::stringstream ss;
  345. ss << "vec3(" << self.x << ", " << self.y << ", " << self.z << ")";
  346. return VAR(ss.str());
  347. });
  348. vm->bind_method<0>(type, "copy", [](VM* vm, ArgsView args){
  349. PyVec3& self = _CAST(PyVec3&, args[0]);
  350. return VAR_T(PyVec3, self);
  351. });
  352. BIND_VEC_VEC_OP(3, __add__, +)
  353. BIND_VEC_VEC_OP(3, __sub__, -)
  354. BIND_VEC_FLOAT_OP(3, __mul__, *)
  355. BIND_VEC_FLOAT_OP(3, __rmul__, *)
  356. BIND_VEC_FLOAT_OP(3, __truediv__, /)
  357. BIND_VEC_VEC_OP(3, __eq__, ==)
  358. BIND_VEC_FIELD(3, x)
  359. BIND_VEC_FIELD(3, y)
  360. BIND_VEC_FIELD(3, z)
  361. BIND_VEC_FUNCTION_1(3, dot)
  362. BIND_VEC_FUNCTION_1(3, cross)
  363. BIND_VEC_FUNCTION_0(3, length)
  364. BIND_VEC_FUNCTION_0(3, length_squared)
  365. BIND_VEC_FUNCTION_0(3, normalize)
  366. }
  367. };
  368. struct PyMat3x3: Mat3x3{
  369. PY_CLASS(PyMat3x3, linalg, mat3x3)
  370. PyMat3x3(): Mat3x3(){}
  371. PyMat3x3(const Mat3x3& other): Mat3x3(other){}
  372. PyMat3x3(const PyMat3x3& other): Mat3x3(other){}
  373. static void _register(VM* vm, PyObject* mod, PyObject* type){
  374. vm->bind_constructor<-1>(type, [](VM* vm, ArgsView args){
  375. if(args.size() == 1+0) return VAR_T(PyMat3x3, Mat3x3::zeros());
  376. if(args.size() == 1+9){
  377. Mat3x3 mat;
  378. for(int i=0; i<9; i++) mat.v[i] = CAST_F(args[1+i]);
  379. return VAR_T(PyMat3x3, mat);
  380. }
  381. if(args.size() == 1+1){
  382. List& a = CAST(List&, args[1]);
  383. if(a.size() != 3) vm->ValueError("Mat3x3.__new__ takes 3x3 list");
  384. Mat3x3 mat;
  385. for(int i=0; i<3; i++){
  386. List& b = CAST(List&, a[i]);
  387. if(b.size() != 3) vm->ValueError("Mat3x3.__new__ takes 3x3 list");
  388. for(int j=0; j<3; j++){
  389. mat.m[i][j] = CAST_F(b[j]);
  390. }
  391. }
  392. return VAR_T(PyMat3x3, mat);
  393. }
  394. vm->TypeError("Mat3x3.__new__ takes 0 or 1 arguments");
  395. return vm->None;
  396. });
  397. vm->bind_method<0>(type, "__getnewargs__", [](VM* vm, ArgsView args){
  398. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  399. Tuple t(9);
  400. for(int i=0; i<9; i++) t[i] = VAR(self.v[i]);
  401. return VAR(std::move(t));
  402. });
  403. #define METHOD_PROXY_NONE(name) \
  404. vm->bind_method<0>(type, #name, [](VM* vm, ArgsView args){ \
  405. PyMat3x3& self = _CAST(PyMat3x3&, args[0]); \
  406. self.name(); \
  407. return vm->None; \
  408. });
  409. METHOD_PROXY_NONE(set_zeros)
  410. METHOD_PROXY_NONE(set_ones)
  411. METHOD_PROXY_NONE(set_identity)
  412. #undef METHOD_PROXY_NONE
  413. vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
  414. PyMat3x3& self = _CAST(PyMat3x3&, obj);
  415. std::stringstream ss;
  416. ss << std::fixed << std::setprecision(4);
  417. ss << "mat3x3([[" << self._11 << ", " << self._12 << ", " << self._13 << "],\n";
  418. ss << " [" << self._21 << ", " << self._22 << ", " << self._23 << "],\n";
  419. ss << " [" << self._31 << ", " << self._32 << ", " << self._33 << "]])";
  420. return VAR(ss.str());
  421. });
  422. vm->bind_method<0>(type, "copy", [](VM* vm, ArgsView args){
  423. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  424. return VAR_T(PyMat3x3, self);
  425. });
  426. vm->bind__getitem__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj, PyObject* index){
  427. PyMat3x3& self = _CAST(PyMat3x3&, obj);
  428. Tuple& t = CAST(Tuple&, index);
  429. if(t.size() != 2){
  430. vm->TypeError("Mat3x3.__getitem__ takes a tuple of 2 integers");
  431. return vm->None;
  432. }
  433. i64 i = CAST(i64, t[0]);
  434. i64 j = CAST(i64, t[1]);
  435. if(i < 0 || i >= 3 || j < 0 || j >= 3){
  436. vm->IndexError("index out of range");
  437. return vm->None;
  438. }
  439. return VAR(self.m[i][j]);
  440. });
  441. vm->bind_method<2>(type, "__setitem__", [](VM* vm, ArgsView args){
  442. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  443. Tuple& t = CAST(Tuple&, args[1]);
  444. if(t.size() != 2){
  445. vm->TypeError("Mat3x3.__setitem__ takes a tuple of 2 integers");
  446. return vm->None;
  447. }
  448. i64 i = CAST(i64, t[0]);
  449. i64 j = CAST(i64, t[1]);
  450. if(i < 0 || i >= 3 || j < 0 || j >= 3){
  451. vm->IndexError("index out of range");
  452. return vm->None;
  453. }
  454. self.m[i][j] = CAST_F(args[2]);
  455. return vm->None;
  456. });
  457. #define PROPERTY_FIELD(field) \
  458. type->attr().set(#field, vm->property([](VM* vm, ArgsView args){ \
  459. PyMat3x3& self = _CAST(PyMat3x3&, args[0]); \
  460. return VAR(self.field); \
  461. }, [](VM* vm, ArgsView args){ \
  462. PyMat3x3& self = _CAST(PyMat3x3&, args[0]); \
  463. self.field = vm->num_to_float(args[1]); \
  464. return vm->None; \
  465. }));
  466. PROPERTY_FIELD(_11)
  467. PROPERTY_FIELD(_12)
  468. PROPERTY_FIELD(_13)
  469. PROPERTY_FIELD(_21)
  470. PROPERTY_FIELD(_22)
  471. PROPERTY_FIELD(_23)
  472. PROPERTY_FIELD(_31)
  473. PROPERTY_FIELD(_32)
  474. PROPERTY_FIELD(_33)
  475. #undef PROPERTY_FIELD
  476. vm->bind_method<1>(type, "__add__", [](VM* vm, ArgsView args){
  477. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  478. PyMat3x3& other = CAST(PyMat3x3&, args[1]);
  479. return VAR_T(PyMat3x3, self + other);
  480. });
  481. vm->bind_method<1>(type, "__sub__", [](VM* vm, ArgsView args){
  482. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  483. PyMat3x3& other = CAST(PyMat3x3&, args[1]);
  484. return VAR_T(PyMat3x3, self - other);
  485. });
  486. vm->bind_method<1>(type, "__mul__", [](VM* vm, ArgsView args){
  487. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  488. f64 other = CAST_F(args[1]);
  489. return VAR_T(PyMat3x3, self * other);
  490. });
  491. vm->bind_method<1>(type, "__rmul__", [](VM* vm, ArgsView args){
  492. PyMat3x3& self = _CAST(PyMat3x3&, args[1]);
  493. f64 other = CAST_F(args[0]);
  494. return VAR_T(PyMat3x3, self * other);
  495. });
  496. vm->bind_method<1>(type, "__truediv__", [](VM* vm, ArgsView args){
  497. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  498. f64 other = CAST_F(args[1]);
  499. return VAR_T(PyMat3x3, self / other);
  500. });
  501. auto f_mm = [](VM* vm, ArgsView args){
  502. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  503. if(is_non_tagged_type(args[1], PyMat3x3::_type(vm))){
  504. PyMat3x3& other = _CAST(PyMat3x3&, args[1]);
  505. return VAR_T(PyMat3x3, self.matmul(other));
  506. }
  507. if(is_non_tagged_type(args[1], PyVec3::_type(vm))){
  508. PyVec3& other = _CAST(PyVec3&, args[1]);
  509. return VAR_T(PyVec3, self.matmul(other));
  510. }
  511. vm->BinaryOptError("@");
  512. return vm->None;
  513. };
  514. vm->bind_method<1>(type, "__matmul__", f_mm);
  515. vm->bind_method<1>(type, "matmul", f_mm);
  516. vm->bind_method<1>(type, "__eq__", [](VM* vm, ArgsView args){
  517. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  518. PyMat3x3& other = CAST(PyMat3x3&, args[1]);
  519. return VAR(self == other);
  520. });
  521. vm->bind_method<0>(type, "determinant", [](VM* vm, ArgsView args){
  522. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  523. return VAR(self.determinant());
  524. });
  525. vm->bind_method<0>(type, "transpose", [](VM* vm, ArgsView args){
  526. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  527. return VAR_T(PyMat3x3, self.transpose());
  528. });
  529. vm->bind_method<0>(type, "inverse", [](VM* vm, ArgsView args){
  530. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  531. Mat3x3 ret;
  532. bool ok = self.inverse(ret);
  533. if(!ok) vm->ValueError("matrix is not invertible");
  534. return VAR_T(PyMat3x3, ret);
  535. });
  536. vm->bind_func<0>(type, "zeros", [](VM* vm, ArgsView args){
  537. return VAR_T(PyMat3x3, Mat3x3::zeros());
  538. });
  539. vm->bind_func<0>(type, "ones", [](VM* vm, ArgsView args){
  540. return VAR_T(PyMat3x3, Mat3x3::ones());
  541. });
  542. vm->bind_func<0>(type, "identity", [](VM* vm, ArgsView args){
  543. return VAR_T(PyMat3x3, Mat3x3::identity());
  544. });
  545. /*************** affine transformations ***************/
  546. vm->bind_func<3>(type, "trs", [](VM* vm, ArgsView args){
  547. PyVec2& t = CAST(PyVec2&, args[0]);
  548. f64 r = CAST_F(args[1]);
  549. PyVec2& s = CAST(PyVec2&, args[2]);
  550. return VAR_T(PyMat3x3, Mat3x3::trs(t, r, s));
  551. });
  552. vm->bind_method<0>(type, "is_affine", [](VM* vm, ArgsView args){
  553. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  554. return VAR(self.is_affine());
  555. });
  556. vm->bind_method<0>(type, "inverse_affine", [](VM* vm, ArgsView args){
  557. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  558. return VAR_T(PyMat3x3, self.inverse_affine());
  559. });
  560. vm->bind_method<1>(type, "matmul_affine", [](VM* vm, ArgsView args){
  561. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  562. PyMat3x3& other = CAST(PyMat3x3&, args[1]);
  563. return VAR_T(PyMat3x3, self.matmul_affine(other));
  564. });
  565. vm->bind_method<0>(type, "translation", [](VM* vm, ArgsView args){
  566. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  567. return VAR_T(PyVec2, self.translation());
  568. });
  569. vm->bind_method<0>(type, "rotation", [](VM* vm, ArgsView args){
  570. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  571. return VAR(self.rotation());
  572. });
  573. vm->bind_method<0>(type, "scale", [](VM* vm, ArgsView args){
  574. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  575. return VAR_T(PyVec2, self.scale());
  576. });
  577. vm->bind_method<1>(type, "transform_point", [](VM* vm, ArgsView args){
  578. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  579. PyVec2& v = CAST(PyVec2&, args[1]);
  580. return VAR_T(PyVec2, self.transform_point(v));
  581. });
  582. vm->bind_method<1>(type, "transform_vector", [](VM* vm, ArgsView args){
  583. PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
  584. PyVec2& v = CAST(PyVec2&, args[1]);
  585. return VAR_T(PyVec2, self.transform_vector(v));
  586. });
  587. }
  588. };
  589. inline PyObject* py_var(VM* vm, Vec2 obj){ return VAR_T(PyVec2, obj); }
  590. inline PyObject* py_var(VM* vm, const PyVec2& obj){ return VAR_T(PyVec2, obj);}
  591. inline PyObject* py_var(VM* vm, Vec3 obj){ return VAR_T(PyVec3, obj); }
  592. inline PyObject* py_var(VM* vm, const PyVec3& obj){ return VAR_T(PyVec3, obj);}
  593. inline PyObject* py_var(VM* vm, const Mat3x3& obj){ return VAR_T(PyMat3x3, obj); }
  594. inline PyObject* py_var(VM* vm, const PyMat3x3& obj){ return VAR_T(PyMat3x3, obj); }
  595. inline void add_module_linalg(VM* vm){
  596. PyObject* linalg = vm->new_module("linalg");
  597. PyVec2::register_class(vm, linalg);
  598. PyVec3::register_class(vm, linalg);
  599. PyMat3x3::register_class(vm, linalg);
  600. }
  601. static_assert(sizeof(Py_<PyMat3x3>) <= 64);
  602. } // namespace pkpy
  603. #else
  604. ADD_MODULE_PLACEHOLDER(linalg)
  605. #endif