Prechádzať zdrojové kódy

add `bytes.__len__` and improve `ord()`

blueloveTH 1 rok pred
rodič
commit
51b14d3526

+ 1 - 0
include/pocketpy/common/str.h

@@ -65,6 +65,7 @@ int c11__byte_index_to_unicode(const char* data, int n);
 
 bool c11__is_unicode_Lo_char(int c);
 int c11__u8_header(unsigned char c, bool suppress);
+int c11__u8_value(int u8bytes, const char* data);
 
 typedef enum IntParsingResult {
     IntParsing_SUCCESS,

+ 20 - 0
src/common/str.c

@@ -297,6 +297,26 @@ int c11__u8_header(unsigned char c, bool suppress) {
     return 0;
 }
 
+int c11__u8_value(int u8bytes, const char* data) {  // Builds one integer from the first `u8bytes` bytes at `data`, laid out as a UTF-8 sequence.
+    assert(u8bytes != 0);  // a length of 0 is not a valid input; callers must reject it first
+    if(u8bytes == 1) return (int)data[0];  // one-byte sequence: the byte itself is the value
+    uint32_t value = 0;  // the lead byte contributes nothing if u8bytes is not 2, 3 or 4
+    for(int k = 0; k < u8bytes; k++) {
+        uint8_t b = data[k];  // work on the unsigned byte so the masks below see 0..255
+        if(k == 0) {  // lead byte: how many payload bits it carries depends on the sequence length
+            if(u8bytes == 2)
+                value = (b & 0b00011111) << 6;  // low 5 bits
+            else if(u8bytes == 3)
+                value = (b & 0b00001111) << 12;  // low 4 bits
+            else if(u8bytes == 4)
+                value = (b & 0b00000111) << 18;  // low 3 bits
+        } else {  // following bytes: no check is made here that they have the 10xxxxxx form
+            value |= (b & 0b00111111) << (6 * (u8bytes - k - 1));  // low 6 bits; the shift drops by 6 per byte so the last byte lands at bit 0
+        }
+    }
+    return (int)value;
+}
+
 IntParsingResult c11__parse_uint(c11_sv text, int64_t* out, int base) {
     *out = 0;
 

+ 1 - 15
src/compiler/lexer.c

@@ -225,21 +225,7 @@ static Error* eat_name(Lexer* self) {
                 break;
             }
         }
-        // handle multibyte char
-        uint32_t value = 0;
-        for(int k = 0; k < u8bytes; k++) {
-            uint8_t b = self->curr_char[k];
-            if(k == 0) {
-                if(u8bytes == 2)
-                    value = (b & 0b00011111) << 6;
-                else if(u8bytes == 3)
-                    value = (b & 0b00001111) << 12;
-                else if(u8bytes == 4)
-                    value = (b & 0b00000111) << 18;
-            } else {
-                value |= (b & 0b00111111) << (6 * (u8bytes - k - 1));
-            }
-        }
+        int value = c11__u8_value(u8bytes, self->curr_char);
         if(c11__is_unicode_Lo_char(value)) {
             self->curr_char += u8bytes;
         } else {

+ 8 - 3
src/public/modules.c

@@ -441,10 +441,15 @@ static bool builtins_ord(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     PY_CHECK_ARG_TYPE(0, tp_str);
     c11_sv sv = py_tosv(py_arg(0));
-    if(sv.size != 1) {
-        return TypeError("ord() expected a character, but string of length %d found", sv.size);
+    if(c11_sv__u8_length(sv) != 1) {  // presumably counts UTF-8 characters rather than bytes (the ord('测') test relies on that) - confirm
+        return TypeError("ord() expected a character, but string of length %d found", c11_sv__u8_length(sv));
     }
-    py_newint(py_retval(), sv.data[0]);
+    int u8bytes = c11__u8_header(sv.data[0], true);  // sequence length in bytes, derived from the first byte only
+    if (u8bytes == 0) {  // 0 is treated as "not a valid leading byte"
+        return ValueError("invalid char: %c", sv.data[0]);
+    }
+    int value = c11__u8_value(u8bytes, sv.data);  // decode the whole sequence into a single integer
+    py_newint(py_retval(), value);
     return true;
 }
 

+ 8 - 0
src/public/py_str.c

@@ -639,6 +639,13 @@ static bool bytes_decode(int argc, py_Ref argv) {
     return true;
 }
 
+static bool bytes__len__(int argc, py_Ref argv) {  // `__len__` handler bound to tp_bytes: returns the stored size of the bytes object.
+    PY_CHECK_ARGC(1);  // only `self` is expected
+    c11_bytes* self = py_touserdata(&argv[0]);  // argv[0] is the bytes object; no type check is done in this function
+    py_newint(py_retval(), self->size);  // result goes into the return-value slot as an int
+    return true;
+}
+
 py_Type pk_bytes__register() {
     py_Type type = pk_newtype("bytes", tp_object, NULL, NULL, false, true);
     // no need to dtor because the memory is controlled by the object
@@ -650,6 +657,7 @@ py_Type pk_bytes__register() {
     py_bindmagic(tp_bytes, __ne__, bytes__ne__);
     py_bindmagic(tp_bytes, __add__, bytes__add__);
     py_bindmagic(tp_bytes, __hash__, bytes__hash__);
+    py_bindmagic(tp_bytes, __len__, bytes__len__);
 
     py_bindmethod(tp_bytes, "decode", bytes_decode);
     return type;

+ 2 - 0
tests/46_bytes.py

@@ -9,6 +9,8 @@ assert b'' + b'' == b''
 
 assert b'\xff\xee' != b'1234'
 assert b'\xff\xee' == b'\xff\xee'
+assert len(b'\xff\xee') == 2  # two escaped bytes -> length 2
+assert len(b'') == 0  # empty bytes literal
 
 a = '测试123'
 assert a == a.encode().decode()

+ 16 - 0
tests/76_misc.py

@@ -34,3 +34,19 @@ for i in range(ord('a'), ord('z')+1):
 
 assert A.a == ord('a')
 assert A.z == ord('z')
+
+assert ord('测') == 27979  # 27979 == 0x6D4B, the code point of a single multi-byte character
+
+try:  # two multi-byte characters must be rejected with TypeError
+    assert ord('测试')  # the assert is only a vehicle for calling ord(); the call itself should raise
+    print("Should not reach here")
+    exit(1)
+except TypeError:
+    pass
+
+try:  # two ASCII characters must be rejected with TypeError as well
+    assert ord('12')
+    print("Should not reach here")
+    exit(1)
+except TypeError:
+    pass