blueloveTH 2 lat temu
rodzic
commit
46eadebfc0
5 zmienionych plików z 68 dodań i 47 usunięć
  1. 3 0
      docs/modules/linalg.md
  2. 24 28
      include/pocketpy/linalg.h
  3. 3 0
      include/typings/linalg.pyi
  4. 33 19
      src/linalg.cpp
  5. 5 0
      tests/80_linalg.py

+ 3 - 0
docs/modules/linalg.md

@@ -115,6 +115,9 @@ class mat3x3(_StructLike['mat3x3']):
     @overload
     def __matmul__(self, other: vec3) -> vec3: ...
 
+    def __imatmul__(self, other: mat3x3) -> None: ...
+    def invert_(self) -> None: ...
+
     @staticmethod
     def zeros() -> mat3x3: ...
     @staticmethod

+ 24 - 28
include/pocketpy/linalg.h

@@ -158,26 +158,22 @@ struct Mat3x3{
         return *this;
     }
 
-    Mat3x3 matmul(const Mat3x3& other) const{
-        Mat3x3 ret;
-        ret._11 = _11 * other._11 + _12 * other._21 + _13 * other._31;
-        ret._12 = _11 * other._12 + _12 * other._22 + _13 * other._32;
-        ret._13 = _11 * other._13 + _12 * other._23 + _13 * other._33;
-        ret._21 = _21 * other._11 + _22 * other._21 + _23 * other._31;
-        ret._22 = _21 * other._12 + _22 * other._22 + _23 * other._32;
-        ret._23 = _21 * other._13 + _22 * other._23 + _23 * other._33;
-        ret._31 = _31 * other._11 + _32 * other._21 + _33 * other._31;
-        ret._32 = _31 * other._12 + _32 * other._22 + _33 * other._32;
-        ret._33 = _31 * other._13 + _32 * other._23 + _33 * other._33;
-        return ret;
+    void matmul(const Mat3x3& other, Mat3x3& out) const{
+        out._11 = _11 * other._11 + _12 * other._21 + _13 * other._31;
+        out._12 = _11 * other._12 + _12 * other._22 + _13 * other._32;
+        out._13 = _11 * other._13 + _12 * other._23 + _13 * other._33;
+        out._21 = _21 * other._11 + _22 * other._21 + _23 * other._31;
+        out._22 = _21 * other._12 + _22 * other._22 + _23 * other._32;
+        out._23 = _21 * other._13 + _22 * other._23 + _23 * other._33;
+        out._31 = _31 * other._11 + _32 * other._21 + _33 * other._31;
+        out._32 = _31 * other._12 + _32 * other._22 + _33 * other._32;
+        out._33 = _31 * other._13 + _32 * other._23 + _33 * other._33;
     }
 
-    Vec3 matmul(const Vec3& other) const{
-        Vec3 ret;
-        ret.x = _11 * other.x + _12 * other.y + _13 * other.z;
-        ret.y = _21 * other.x + _22 * other.y + _23 * other.z;
-        ret.z = _31 * other.x + _32 * other.y + _33 * other.z;
-        return ret;
+    void matmul(const Vec3& other, Vec3& out) const{
+        out.x = _11 * other.x + _12 * other.y + _13 * other.z;
+        out.y = _21 * other.x + _22 * other.y + _23 * other.z;
+        out.z = _31 * other.x + _32 * other.y + _33 * other.z;
     }
 
     bool operator==(const Mat3x3& other) const{
@@ -207,19 +203,19 @@ struct Mat3x3{
         return ret;
     }
 
-    bool inverse(Mat3x3& ret) const{
+    bool inverse(Mat3x3& out) const{
         float det = determinant();
         if (isclose(det, 0)) return false;
         float inv_det = 1.0f / det;
-        ret._11 = (_22 * _33 - _23 * _32) * inv_det;
-        ret._12 = (_13 * _32 - _12 * _33) * inv_det;
-        ret._13 = (_12 * _23 - _13 * _22) * inv_det;
-        ret._21 = (_23 * _31 - _21 * _33) * inv_det;
-        ret._22 = (_11 * _33 - _13 * _31) * inv_det;
-        ret._23 = (_13 * _21 - _11 * _23) * inv_det;
-        ret._31 = (_21 * _32 - _22 * _31) * inv_det;
-        ret._32 = (_12 * _31 - _11 * _32) * inv_det;
-        ret._33 = (_11 * _22 - _12 * _21) * inv_det;
+        out._11 = (_22 * _33 - _23 * _32) * inv_det;
+        out._12 = (_13 * _32 - _12 * _33) * inv_det;
+        out._13 = (_12 * _23 - _13 * _22) * inv_det;
+        out._21 = (_23 * _31 - _21 * _33) * inv_det;
+        out._22 = (_11 * _33 - _13 * _31) * inv_det;
+        out._23 = (_13 * _21 - _11 * _23) * inv_det;
+        out._31 = (_21 * _32 - _22 * _31) * inv_det;
+        out._32 = (_12 * _31 - _11 * _32) * inv_det;
+        out._33 = (_11 * _22 - _12 * _21) * inv_det;
         return true;
     }
 

+ 3 - 0
include/typings/linalg.pyi

@@ -105,6 +105,9 @@ class mat3x3(_StructLike['mat3x3']):
     @overload
     def __matmul__(self, other: vec3) -> vec3: ...
 
+    def __imatmul__(self, other: mat3x3) -> None: ...
+    def invert_(self) -> None: ...
+
     @staticmethod
     def zeros() -> mat3x3: ...
     @staticmethod

+ 33 - 19
src/linalg.cpp

@@ -279,26 +279,17 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
             return VAR(std::move(t));
         });
 
-#define METHOD_PROXY_NONE(name)  \
-        vm->bind_method<0>(type, #name, [](VM* vm, ArgsView args){    \
-            PyMat3x3& self = _CAST(PyMat3x3&, args[0]);               \
-            self.name();                                              \
-            return vm->None;                                          \
-        });
-
-        METHOD_PROXY_NONE(set_zeros)
-        METHOD_PROXY_NONE(set_ones)
-        METHOD_PROXY_NONE(set_identity)
-
-#undef METHOD_PROXY_NONE
+        vm->bind_method<0>(type, "set_zeros", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_zeros()));
+        vm->bind_method<0>(type, "set_ones", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_ones()));
+        vm->bind_method<0>(type, "set_identity", PK_ACTION(PK_OBJ_GET(PyMat3x3, args[0]).set_identity()));
 
         vm->bind__repr__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* obj){
             PyMat3x3& self = _CAST(PyMat3x3&, obj);
             std::stringstream ss;
             ss << std::fixed << std::setprecision(3);
-            ss << "mat3x3([[" << self._11 << ", " << self._12 << ", " << self._13 << "],\n";
-            ss << "        [" << self._21 << ", " << self._22 << ", " << self._23 << "],\n";
-            ss << "        [" << self._31 << ", " << self._32 << ", " << self._33 << "]])";
+            ss << "mat3x3([" << self._11 << ", " << self._12 << ", " << self._13 << ",\n";
+            ss << "        " << self._21 << ", " << self._22 << ", " << self._23 << ",\n";
+            ss << "        " << self._31 << ", " << self._32 << ", " << self._33 << "])";
             return VAR(ss.str());
         });
 
@@ -390,16 +381,30 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
         vm->bind__matmul__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){
             PyMat3x3& self = _CAST(PyMat3x3&, _0);
             if(is_non_tagged_type(_1, PyMat3x3::_type(vm))){
-                PyMat3x3& other = _CAST(PyMat3x3&, _1);
-                return VAR_T(PyMat3x3, self.matmul(other));
+                const PyMat3x3& other = _CAST(PyMat3x3&, _1);
+                Mat3x3 out;
+                self.matmul(other, out);
+                return VAR_T(PyMat3x3, out);
             }
             if(is_non_tagged_type(_1, PyVec3::_type(vm))){
-                PyVec3& other = _CAST(PyVec3&, _1);
-                return VAR_T(PyVec3, self.matmul(other));
+                const PyVec3& other = _CAST(PyVec3&, _1);
+                Vec3 out;
+                self.matmul(other, out);
+                return VAR_T(PyVec3, out);
             }
             return vm->NotImplemented;
         });
 
+        vm->bind_method<1>(type, "__imatmul__", [](VM* vm, ArgsView args){
+            PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
+            vm->check_non_tagged_type(args[1], PyMat3x3::_type(vm));
+            const PyMat3x3& other = _CAST(PyMat3x3&, args[1]);
+            Mat3x3 out;
+            self.matmul(other, out);
+            self = out;
+            return vm->None;
+        });
+
         vm->bind_method<0>(type, "determinant", [](VM* vm, ArgsView args){
             PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
             return VAR(self.determinant());
@@ -418,6 +423,15 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
             return VAR_T(PyMat3x3, ret);
         });
 
+        vm->bind_method<0>(type, "invert_", [](VM* vm, ArgsView args){
+            PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
+            Mat3x3 ret;
+            bool ok = self.inverse(ret);
+            if(!ok) vm->ValueError("matrix is not invertible");
+            self = ret;
+            return vm->None;
+        });
+
         // @staticmethod
         vm->bind(type, "zeros()", [](VM* vm, ArgsView args){
             PK_UNUSED(args);

+ 5 - 0
tests/80_linalg.py

@@ -330,6 +330,9 @@ for i in range(3):
         correct_result_mat[i, j] = sum([e1*e2 for e1, e2 in zip(get_row(test_mat_copy, i), get_col(test_mat_copy_2, j))])
 assert result_mat == correct_result_mat
 
+test_mat_copy.__imatmul__(test_mat_copy_2)
+assert test_mat_copy == correct_result_mat
+
 # test determinant
 test_mat_copy = test_mat.copy()
 test_mat_copy.determinant()
@@ -382,6 +385,8 @@ assert test_mat_copy.transpose() == test_mat_copy.transpose().transpose().transp
 
 # test inverse
 assert ~static_test_mat_float == static_test_mat_float_inv
+assert static_test_mat_float.invert_() is None
+assert static_test_mat_float == static_test_mat_float_inv
 
 try:
     ~mat3x3([1, 2, 3, 2, 4, 6, 3, 6, 9])