瀏覽代碼

fix `list.sort`

blueloveTH 1 年之前
父節點
當前提交
22ae57fc9b
共有 4 個文件被更改,包括 43 次插入14 次删除
  1. 2 2
      include/pocketpy/common/algorithm.h
  2. 20 12
      src/common/algorithm.c
  3. 9 0
      src/public/py_list.c
  4. 12 0
      tests/05_list.py

+ 2 - 2
include/pocketpy/common/algorithm.h

@@ -37,10 +37,10 @@ extern "C" {
  * @param cmp Comparison function that takes two elements and returns an integer similar to
  * @param cmp Comparison function that takes two elements and returns an integer similar to
  * `strcmp`.
  * `strcmp`.
  */
  */
-void c11__stable_sort(void* ptr,
+bool c11__stable_sort(void* ptr,
                       int count,
                       int count,
                       int elem_size,
                       int elem_size,
-                      int (*cmp)(const void* a, const void* b));
+                      int (*f_le)(const void* a, const void* b));
 
 
 #ifdef __cplusplus
 #ifdef __cplusplus
 }
 }

+ 20 - 12
src/common/algorithm.c

@@ -2,16 +2,19 @@
 #include <string.h>
 #include <string.h>
 #include <stdlib.h>
 #include <stdlib.h>
 
 
-static void merge(char* a_begin,
+static bool merge(char* a_begin,
                   char* a_end,
                   char* a_end,
                   char* b_begin,
                   char* b_begin,
                   char* b_end,
                   char* b_end,
                   char* res,
                   char* res,
                   int elem_size,
                   int elem_size,
-                  int (*cmp)(const void* a, const void* b)) {
+                  int (*f_le)(const void* a, const void* b)) {
     char *a = a_begin, *b = b_begin, *r = res;
     char *a = a_begin, *b = b_begin, *r = res;
     while(a < a_end && b < b_end) {
     while(a < a_end && b < b_end) {
-        if(cmp(a, b) <= 0) {
+        int res = f_le(a, b);
+        // check error
+        if(res == -1) return false;
+        if(res) {
             memcpy(r, a, elem_size);
             memcpy(r, a, elem_size);
             a += elem_size;
             a += elem_size;
         } else {
         } else {
@@ -26,22 +29,27 @@ static void merge(char* a_begin,
         memcpy(r, a, elem_size);
         memcpy(r, a, elem_size);
     for(; b < b_end; b += elem_size, r += elem_size)
     for(; b < b_end; b += elem_size, r += elem_size)
         memcpy(r, b, elem_size);
         memcpy(r, b, elem_size);
+    return true;
 }
 }
 
 
-void c11__stable_sort(void* ptr_,
+bool c11__stable_sort(void* ptr_,
                       int count,
                       int count,
                       int elem_size,
                       int elem_size,
-                      int (*cmp)(const void* a, const void* b)) {
+                      int (*f_le)(const void* a, const void* b)) {
     // merge sort
     // merge sort
-    char* ptr = ptr_, *tmp = malloc(count * elem_size);
+    char *ptr = ptr_, *tmp = malloc(count * elem_size);
     for(int seg = 1; seg < count; seg *= 2) {
     for(int seg = 1; seg < count; seg *= 2) {
         for(char* a = ptr; a < ptr + (count - seg) * elem_size; a += 2 * seg * elem_size) {
         for(char* a = ptr; a < ptr + (count - seg) * elem_size; a += 2 * seg * elem_size) {
-            char* b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size;
-			if (b_end > ptr + count * elem_size)
-				b_end = ptr + count * elem_size;
-			merge(a, a_end, b, b_end, tmp, elem_size, cmp);
-			memcpy(a, tmp, b_end - a);
+            char *b = a + seg * elem_size, *a_end = b, *b_end = b + seg * elem_size;
+            if(b_end > ptr + count * elem_size) b_end = ptr + count * elem_size;
+            bool ok = merge(a, a_end, b, b_end, tmp, elem_size, f_le);
+            if(!ok) {
+                free(tmp);
+                return false;
+            }
+            memcpy(a, tmp, b_end - a);
         }
         }
     }
     }
-	free(tmp);
+    free(tmp);
+    return true;
 }
 }

+ 9 - 0
src/public/py_list.c

@@ -319,6 +319,14 @@ static bool _py_list__insert(int argc, py_Ref argv) {
     return true;
     return true;
 }
 }
 
 
+static bool _py_list__sort(int argc, py_Ref argv) {
+    PY_CHECK_ARGC(1);
+    List* self = py_touserdata(py_arg(0));
+    c11__stable_sort(self->data, self->count, sizeof(py_TValue), (int (*)(const void*, const void*))py_le);
+    py_newnone(py_retval());
+    return true;
+}
+
 py_Type pk_list__register() {
 py_Type pk_list__register() {
     pk_VM* vm = pk_current_vm;
     pk_VM* vm = pk_current_vm;
     py_Type type = pk_VM__new_type(vm, "list", tp_object, NULL, false);
     py_Type type = pk_VM__new_type(vm, "list", tp_object, NULL, false);
@@ -346,5 +354,6 @@ py_Type pk_list__register() {
     py_bindmethod(type, "remove", _py_list__remove);
     py_bindmethod(type, "remove", _py_list__remove);
     py_bindmethod(type, "pop", _py_list__pop);
     py_bindmethod(type, "pop", _py_list__pop);
     py_bindmethod(type, "insert", _py_list__insert);
     py_bindmethod(type, "insert", _py_list__insert);
+    py_bindmethod(type, "sort", _py_list__sort);
     return type;
     return type;
 }
 }

+ 12 - 0
tests/05_list.py

@@ -87,6 +87,18 @@ assert list(range(5, 1, -2)) == [5, 3]
 
 
 # test sort
 # test sort
 a = [8, 2, 4, 2, 9]
 a = [8, 2, 4, 2, 9]
+assert a.sort() == None
+assert a == [2, 2, 4, 8, 9]
+
+a = []
+assert a.sort() == None
+assert a == []
+
+a = [0, 0, 0, 0, 1, 1, 3, -1]
+assert a.sort() == None
+assert a == [-1, 0, 0, 0, 0, 1, 1, 3]
+
+# test sorted
 assert sorted(a) == [2, 2, 4, 8, 9]
 assert sorted(a) == [2, 2, 4, 8, 9]
 assert sorted(a, reverse=True) == [9, 8, 4, 2, 2]
 assert sorted(a, reverse=True) == [9, 8, 4, 2, 2]