blueloveTH 2 лет назад
Родитель
Сommit
8d9761c6b5
2 измененных файлов с 19 добавлено и 9 удалено
  1. 14 9
      src/linalg.h
  2. 5 0
      src/linalg.pyi

+ 14 - 9
src/linalg.h

@@ -562,17 +562,22 @@ struct PyMat3x3: Mat3x3{
             return VAR_T(PyMat3x3, self / other);
         });
 
-        vm->bind_method<1>(type, "__matmul__", [](VM* vm, ArgsView args){
+        auto f_mm = [](VM* vm, ArgsView args){
             PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
-            PyMat3x3& other = CAST(PyMat3x3&, args[1]);
-            return VAR_T(PyMat3x3, self.matmul(other));
-        });
+            if(is_non_tagged_type(args[1], PyMat3x3::_type(vm))){
+                PyMat3x3& other = _CAST(PyMat3x3&, args[1]);
+                return VAR_T(PyMat3x3, self.matmul(other));
+            }
+            if(is_non_tagged_type(args[1], PyVec3::_type(vm))){
+                PyVec3& other = _CAST(PyVec3&, args[1]);
+                return VAR_T(PyVec3, self.matmul(other));
+            }
+            vm->TypeError("unsupported operand type(s) for @");
+            return vm->None;
+        };
 
-        vm->bind_method<1>(type, "matmul", [](VM* vm, ArgsView args){
-            PyMat3x3& self = _CAST(PyMat3x3&, args[0]);
-            PyMat3x3& other = CAST(PyMat3x3&, args[1]);
-            return VAR_T(PyMat3x3, self.matmul(other));
-        });
+        vm->bind_method<1>(type, "__matmul__", f_mm);
+        vm->bind_method<1>(type, "matmul", f_mm);
 
         vm->bind_method<1>(type, "__eq__", [](VM* vm, ArgsView args){
             PyMat3x3& self = _CAST(PyMat3x3&, args[0]);

+ 5 - 0
src/linalg.pyi

@@ -64,8 +64,13 @@ class mat3x3:
     def __sub__(self, other: mat3x3) -> mat3x3: ...
     def __mul__(self, other: float) -> mat3x3: ...
     def __truediv__(self, other: float) -> mat3x3: ...
+    @overload
     def __matmul__(self, other: mat3x3) -> mat3x3: ...
+    @overload
+    def __matmul__(self, other: vec3) -> vec3: ...
+    @overload
     def matmul(self, other: mat3x3) -> mat3x3: ...
+    @overload
     def matmul(self, other: vec3) -> vec3: ...
     def determinant(self) -> float: ...
     def transpose(self) -> mat3x3: ...