فهرست منبع

allows `vec` * `vec`

blueloveTH 2 سال پیش
والد
کامیت
a0a5753283
4فایلهای تغییر یافته به همراه58 افزوده شده و 31 حذف شده
  1. 3 12
      include/pocketpy/linalg.h
  2. 15 0
      include/typings/linalg.pyi
  3. 34 18
      src/linalg.cpp
  4. 6 1
      tests/80_linalg.py

+ 3 - 12
include/pocketpy/linalg.h

@@ -13,13 +13,10 @@ struct Vec2{
     Vec2(const Vec2& v) = default;
 
     Vec2 operator+(const Vec2& v) const { return Vec2(x + v.x, y + v.y); }
-    Vec2& operator+=(const Vec2& v) { x += v.x; y += v.y; return *this; }
     Vec2 operator-(const Vec2& v) const { return Vec2(x - v.x, y - v.y); }
-    Vec2& operator-=(const Vec2& v) { x -= v.x; y -= v.y; return *this; }
     Vec2 operator*(float s) const { return Vec2(x * s, y * s); }
-    Vec2& operator*=(float s) { x *= s; y *= s; return *this; }
+    Vec2 operator*(const Vec2& v) const { return Vec2(x * v.x, y * v.y); }
     Vec2 operator/(float s) const { return Vec2(x / s, y / s); }
-    Vec2& operator/=(float s) { x /= s; y /= s; return *this; }
     Vec2 operator-() const { return Vec2(-x, -y); }
     bool operator==(const Vec2& v) const { return isclose(x, v.x) && isclose(y, v.y); }
     bool operator!=(const Vec2& v) const { return !isclose(x, v.x) || !isclose(y, v.y); }
@@ -40,13 +37,10 @@ struct Vec3{
     Vec3(const Vec3& v) = default;
 
     Vec3 operator+(const Vec3& v) const { return Vec3(x + v.x, y + v.y, z + v.z); }
-    Vec3& operator+=(const Vec3& v) { x += v.x; y += v.y; z += v.z; return *this; }
     Vec3 operator-(const Vec3& v) const { return Vec3(x - v.x, y - v.y, z - v.z); }
-    Vec3& operator-=(const Vec3& v) { x -= v.x; y -= v.y; z -= v.z; return *this; }
     Vec3 operator*(float s) const { return Vec3(x * s, y * s, z * s); }
-    Vec3& operator*=(float s) { x *= s; y *= s; z *= s; return *this; }
+    Vec3 operator*(const Vec3& v) const { return Vec3(x * v.x, y * v.y, z * v.z); }
     Vec3 operator/(float s) const { return Vec3(x / s, y / s, z / s); }
-    Vec3& operator/=(float s) { x /= s; y /= s; z /= s; return *this; }
     Vec3 operator-() const { return Vec3(-x, -y, -z); }
     bool operator==(const Vec3& v) const { return isclose(x, v.x) && isclose(y, v.y) && isclose(z, v.z); }
     bool operator!=(const Vec3& v) const { return !isclose(x, v.x) || !isclose(y, v.y) || !isclose(z, v.z); }
@@ -66,13 +60,10 @@ struct Vec4{
     Vec4(const Vec4& v) = default;
 
     Vec4 operator+(const Vec4& v) const { return Vec4(x + v.x, y + v.y, z + v.z, w + v.w); }
-    Vec4& operator+=(const Vec4& v) { x += v.x; y += v.y; z += v.z; w += v.w; return *this; }
     Vec4 operator-(const Vec4& v) const { return Vec4(x - v.x, y - v.y, z - v.z, w - v.w); }
-    Vec4& operator-=(const Vec4& v) { x -= v.x; y -= v.y; z -= v.z; w -= v.w; return *this; }
     Vec4 operator*(float s) const { return Vec4(x * s, y * s, z * s, w * s); }
-    Vec4& operator*=(float s) { x *= s; y *= s; z *= s; w *= s; return *this; }
+    Vec4 operator*(const Vec4& v) const { return Vec4(x * v.x, y * v.y, z * v.z, w * v.w); }
     Vec4 operator/(float s) const { return Vec4(x / s, y / s, z / s, w / s); }
-    Vec4& operator/=(float s) { x /= s; y /= s; z /= s; w /= s; return *this; }
     Vec4 operator-() const { return Vec4(-x, -y, -z, -w); }
     bool operator==(const Vec4& v) const { return isclose(x, v.x) && isclose(y, v.y) && isclose(z, v.z) && isclose(w, v.w); }
     bool operator!=(const Vec4& v) const { return !isclose(x, v.x) || !isclose(y, v.y) || !isclose(z, v.z) || !isclose(w, v.w); }

+ 15 - 0
include/typings/linalg.pyi

@@ -8,7 +8,12 @@ class vec2(_StructLike['vec2']):
     def __init__(self, x: float, y: float) -> None: ...
     def __add__(self, other: vec2) -> vec2: ...
     def __sub__(self, other: vec2) -> vec2: ...
+
+    @overload
     def __mul__(self, other: float) -> vec2: ...
+    @overload
+    def __mul__(self, other: vec2) -> vec2: ...
+
     def __rmul__(self, other: float) -> vec2: ...
     def __truediv__(self, other: float) -> vec2: ...
     def dot(self, other: vec2) -> float: ...
@@ -44,7 +49,12 @@ class vec3(_StructLike['vec3']):
     def __init__(self, x: float, y: float, z: float) -> None: ...
     def __add__(self, other: vec3) -> vec3: ...
     def __sub__(self, other: vec3) -> vec3: ...
+
+    @overload
     def __mul__(self, other: float) -> vec3: ...
+    @overload
+    def __mul__(self, other: vec3) -> vec3: ...
+
     def __rmul__(self, other: float) -> vec3: ...
     def __truediv__(self, other: float) -> vec3: ...
     def dot(self, other: vec3) -> float: ...
@@ -65,7 +75,12 @@ class vec4(_StructLike['vec4']):
     def __init__(self, x: float, y: float, z: float, w: float) -> None: ...
     def __add__(self, other: vec4) -> vec4: ...
     def __sub__(self, other: vec4) -> vec4: ...
+
+    @overload
     def __mul__(self, other: float) -> vec4: ...
+    @overload
+    def __mul__(self, other: vec4) -> vec4: ...
+
     def __rmul__(self, other: float) -> vec4: ...
     def __truediv__(self, other: float) -> vec4: ...
     def dot(self, other: vec4) -> float: ...

+ 34 - 18
src/linalg.cpp

@@ -2,18 +2,18 @@
 
 namespace pkpy{
 
-#define BIND_VEC_VEC_OP(D, name, op)                                        \
-        vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){          \
-            PyVec##D& self = _CAST(PyVec##D&, args[0]);                     \
-            PyVec##D& other = CAST(PyVec##D&, args[1]);                     \
-            return VAR(self op other);                                      \
+#define BIND_VEC_VEC_OP(D, name, op)                                                    \
+        vm->bind##name(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){  \
+            PyVec##D& self = _CAST(PyVec##D&, _0);                                      \
+            PyVec##D& other = CAST(PyVec##D&, _1);                                      \
+            return VAR(self op other);                                                  \
         });
 
 #define BIND_VEC_FLOAT_OP(D, name, op)  \
-        vm->bind_method<1>(type, #name, [](VM* vm, ArgsView args){          \
-            PyVec##D& self = _CAST(PyVec##D&, args[0]);                     \
-            f64 other = CAST(f64, args[1]);                                 \
-            return VAR(self op other);                                      \
+        vm->bind##name(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){  \
+            PyVec##D& self = _CAST(PyVec##D&, _0);                                      \
+            f64 other = CAST(f64, _1);                                                  \
+            return VAR(self op other);                                                  \
         });
 
 #define BIND_VEC_FUNCTION_0(D, name)        \
@@ -29,6 +29,27 @@ namespace pkpy{
             return VAR(self.name(other));                                   \
         });
 
+#define BIND_VEC_MUL_OP(D)                                                                \
+        vm->bind__mul__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){     \
+            PyVec##D& self = _CAST(PyVec##D&, _0);                                          \
+            if(is_non_tagged_type(_1, PyVec##D::_type(vm))){                                \
+                PyVec##D& other = _CAST(PyVec##D&, _1);                                     \
+                return VAR(self * other);                                                   \
+            }                                                                               \
+            f64 other = CAST(f64, _1);                                                      \
+            return VAR(self * other);                                                       \
+        });                                                                                 \
+        vm->bind_method<1>(type, "__rmul__", [](VM* vm, ArgsView args){                     \
+            PyVec##D& self = _CAST(PyVec##D&, args[0]);                                     \
+            f64 other = CAST(f64, args[1]);                                                 \
+            return VAR(self * other);                                                       \
+        });                                                                                 \
+        vm->bind__truediv__(PK_OBJ_GET(Type, type), [](VM* vm, PyObject* _0, PyObject* _1){ \
+            PyVec##D& self = _CAST(PyVec##D&, _0);                                          \
+            f64 other = CAST(f64, _1);                                                      \
+            return VAR(self / other);                                                       \
+        });
+
 // https://github.com/Unity-Technologies/UnityCsReference/blob/master/Runtime/Export/Math/Vector2.cs#L289
 static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float smoothTime, float maxSpeed, float deltaTime)
 {
@@ -142,8 +163,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
 
         BIND_VEC_VEC_OP(2, __add__, +)
         BIND_VEC_VEC_OP(2, __sub__, -)
-        BIND_VEC_FLOAT_OP(2, __mul__, *)
-        BIND_VEC_FLOAT_OP(2, __rmul__, *)
+        BIND_VEC_MUL_OP(2)
         BIND_VEC_FLOAT_OP(2, __truediv__, /)
         BIND_VEC_FUNCTION_1(2, dot)
         BIND_VEC_FUNCTION_1(2, cross)
@@ -178,9 +198,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
 
         BIND_VEC_VEC_OP(3, __add__, +)
         BIND_VEC_VEC_OP(3, __sub__, -)
-        BIND_VEC_FLOAT_OP(3, __mul__, *)
-        BIND_VEC_FLOAT_OP(3, __rmul__, *)
-        BIND_VEC_FLOAT_OP(3, __truediv__, /)
+        BIND_VEC_MUL_OP(3)
         BIND_VEC_FUNCTION_1(3, dot)
         BIND_VEC_FUNCTION_1(3, cross)
         BIND_VEC_FUNCTION_1(3, copy_)
@@ -216,9 +234,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
 
         BIND_VEC_VEC_OP(4, __add__, +)
         BIND_VEC_VEC_OP(4, __sub__, -)
-        BIND_VEC_FLOAT_OP(4, __mul__, *)
-        BIND_VEC_FLOAT_OP(4, __rmul__, *)
-        BIND_VEC_FLOAT_OP(4, __truediv__, /)
+        BIND_VEC_MUL_OP(4)
         BIND_VEC_FUNCTION_1(4, dot)
         BIND_VEC_FUNCTION_1(4, copy_)
         BIND_VEC_FUNCTION_0(4, length)
@@ -228,7 +244,7 @@ static Vec2 SmoothDamp(Vec2 current, Vec2 target, PyVec2& currentVelocity, float
     }
 
 #undef BIND_VEC_VEC_OP
-#undef BIND_VEC_FLOAT_OP
+#undef BIND_VEC_MUL_OP
 #undef BIND_VEC_FUNCTION_0
 #undef BIND_VEC_FUNCTION_1
 

+ 6 - 1
tests/80_linalg.py

@@ -481,4 +481,9 @@ try:
     assert d[6, 6]
     exit(1)
 except IndexError:
-    pass
+    pass
+
+# test vec * vec
+assert vec2(1, 2) * vec2(3, 4) == vec2(3, 8)
+assert vec3(1, 2, 3) * vec3(4, 5, 6) == vec3(4, 10, 18)
+assert vec4(1, 2, 3, 4) * vec4(5, 6, 7, 8) == vec4(5, 12, 21, 32)