浏览代码

fix lower_bound

blueloveTH 1 年之前
父节点
当前提交
c98eb31a5e
共有 6 个文件被更改,包括 33 次插入62 次删除
  1. 19 6
      include/pocketpy/common/algorithm.h
  2. 7 1
      include/pocketpy/common/str.hpp
  3. 0 37
      src/common/algorithm.c
  4. 3 1
      src/common/str.c
  5. 0 14
      src/common/str.cpp
  6. 4 3
      src/modules/random.cpp

+ 19 - 6
include/pocketpy/common/algorithm.h

@@ -1,16 +1,29 @@
 #pragma once
 
+#include <stdbool.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-#include <stdbool.h>
-
-void *c11__lower_bound(const void *key, const void *ptr, int count, int size,
-                       bool (*less)(const void *, const void *));
+#define c11__less(a, b) ((a) < (b))
 
-int *c11__lower_bound_int(int key, const int *ptr, int count);
-double *c11__lower_bound_double(double key, const double *ptr, int count);
+#define c11__lower_bound(T, ptr, count, key, less, out)                        \
+  do {                                                                         \
+    const T *__first = ptr;                                                    \
+    int __len = count;                                                         \
+    while (__len != 0) {                                                       \
+      int __l2 = (int)((unsigned int)__len / 2);                               \
+      const T *__m = __first + __l2;                                           \
+      if (less((*__m), (key))) {                                               \
+        __first = ++__m;                                                       \
+        __len -= __l2 + 1;                                                     \
+      } else {                                                                 \
+        __len = __l2;                                                          \
+      }                                                                        \
+    }                                                                          \
+    *(out) = __first;                                                          \
+  } while (0)
 
 #ifdef __cplusplus
 }

+ 7 - 1
include/pocketpy/common/str.hpp

@@ -41,7 +41,13 @@ struct Str: pkpy_Str {
         pkpy_Str__ctor2(this, s, len);
     }
 
-    Str(pair<char*, int>);      // take ownership
+    Str(pair<char*, int> detached) {
+        this->size = detached.second;
+        this->is_ascii = c11__isascii(detached.first, detached.second);
+        this->is_sso = false;
+        this->_ptr = detached.first;
+        assert(_ptr[size] == '\0');
+    }
 
     Str(const Str& other){
         pkpy_Str__ctor2(this, pkpy_Str__data(&other), other.size);

+ 0 - 37
src/common/algorithm.c

@@ -1,39 +1,2 @@
 #include "pocketpy/common/algorithm.h"
 
-void *c11__lower_bound(const void *key, const void *ptr, int count, int size,
-                       bool (*less)(const void *, const void *)) {
-    char* __first = (char*)ptr;
-    int __len = count;
-
-    while(__len != 0){
-        int __l2 = (int)((unsigned int)__len / 2);
-        char* __m = __first + __l2 * size;
-        if(less(__m, key)){
-            __m += size;
-            __first = __m;
-            __len -= __l2 + 1;
-        }else{
-            __len = __l2;
-        }
-    }
-    return __first;
-}
-
-static bool c11__less_int(const void* a, const void* b){
-    return *(int*)a < *(int*)b;
-}
-
-static bool c11__less_double(const void* a, const void* b){
-    return *(double*)a < *(double*)b;
-}
-
-int *c11__lower_bound_int(int key, const int *ptr, int count) {
-    void* res = c11__lower_bound(&key, ptr, count, sizeof(int), c11__less_int);
-    return (int*)res;
-}
-
-double *c11__lower_bound_double(double key, const double *ptr, int count) {
-    void* res = c11__lower_bound(&key, ptr, count, sizeof(double), c11__less_double);
-    return (double*)res;
-}
-

+ 3 - 1
src/common/str.c

@@ -386,7 +386,9 @@ static const int kLoRangeB[] = {170,186,443,451,660,1514,1522,1599,1610,1647,174
 
 bool c11__is_unicode_Lo_char(int c){
     if(c == 0x1f955) return true;
-    int index = c11__lower_bound_int(c, kLoRangeA, 476) - kLoRangeA;
+    const int* p;
+    c11__lower_bound(int, kLoRangeA, 476, c, c11__less, &p);
+    int index = p - kLoRangeA;
     if(c == kLoRangeA[index]) return true;
     index -= 1;
     if(index < 0) return false;

+ 0 - 14
src/common/str.cpp

@@ -8,20 +8,6 @@
 
 namespace pkpy {
 
-Str::Str(pair<char*, int> detached) {
-    this->size = detached.second;
-    this->is_ascii = true;
-    this->is_sso = false;
-    this->_ptr = detached.first;
-    for(int i = 0; i < size; i++) {
-        if(!isascii(_ptr[i])) {
-            is_ascii = false;
-            break;
-        }
-    }
-    assert(_ptr[size] == '\0');
-}
-
 static std::map<std::string_view, uint16_t>& _interned() {
     static std::map<std::string_view, uint16_t> interned;
     return interned;

+ 4 - 3
src/modules/random.cpp

@@ -198,9 +198,10 @@ struct Random {
             int k = CAST(int, args[3]);
             List result(k);
             for(int i = 0; i < k; i++) {
-                f64 r = self.gen.uniform(0.0, cum_weights[size - 1]);
-                int idx = c11__lower_bound_double(r, cum_weights.begin(), cum_weights.size()) - cum_weights.begin();
-                result[i] = data[idx];
+                f64 key = self.gen.uniform(0.0, cum_weights[size - 1]);
+                const f64* p;
+                c11__lower_bound(f64, cum_weights.begin(), cum_weights.size(), key, c11__less, &p);
+                result[i] = data[p - cum_weights.begin()];
             }
             return VAR(std::move(result));
         });