blueloveTH 5 дней назад
Родитель
Сommit
2a2e9ae1f2
1 измененных файлов с 12 добавлено и 11 удалено
  1. 12 11
      src/modules/lz4.c

+ 12 - 11
src/modules/lz4.c

@@ -2,6 +2,7 @@
 
 #include <string.h>
 #include <assert.h>
+#include "pocketpy/common/utils.h"
 #include "pocketpy/pocketpy.h"
 #include "lz4/lib/lz4.h"
 
@@ -10,13 +11,13 @@ static bool lz4_compress(int argc, py_Ref argv) {
     PY_CHECK_ARG_TYPE(0, tp_bytes);
     int src_size;
     const void* src = py_tobytes(argv, &src_size);
-    int dst_capacity = LZ4_compressBound(src_size);
-    char* p = (char*)py_newbytes(py_retval(), sizeof(int) + dst_capacity);
-    memcpy(p, &src_size, sizeof(int));
-    char* dst = p + sizeof(int);
+    uint32_t dst_capacity = LZ4_compressBound(src_size);
+    char* p = (char*)py_newbytes(py_retval(), sizeof(uint32_t) + dst_capacity);
+    memcpy(p, &src_size, sizeof(uint32_t));
+    char* dst = p + sizeof(uint32_t);
     int dst_size = LZ4_compress_default(src, dst, src_size, dst_capacity);
     if(dst_size <= 0) return ValueError("LZ4 compression failed");
-    py_bytes_resize(py_retval(), sizeof(int) + dst_size);
+    py_bytes_resize(py_retval(), sizeof(uint32_t) + dst_size);
     return true;
 }
 
@@ -24,15 +25,15 @@ static bool lz4_decompress(int argc, py_Ref argv) {
     PY_CHECK_ARGC(1);
     PY_CHECK_ARG_TYPE(0, tp_bytes);
     int total_size;
-    const int* p = (int*)py_tobytes(argv, &total_size);
+    const uint32_t* p = (uint32_t*)py_tobytes(argv, &total_size);
     const char* src = (const char*)(p + 1);
-    if(total_size < sizeof(int)) return ValueError("invalid LZ4 data");
-    int uncompressed_size = *p;
-    if(uncompressed_size < 0) return ValueError("invalid LZ4 data");
+    if(total_size < sizeof(uint32_t)) return ValueError("invalid LZ4 data");
+    uint32_t uncompressed_size = *p;
+    if(uncompressed_size >= INT32_MAX) return ValueError("invalid LZ4 data");
     char* dst = (char*)py_newbytes(py_retval(), uncompressed_size);
-    int dst_size = LZ4_decompress_safe(src, dst, total_size - sizeof(int), uncompressed_size);
+    int dst_size = LZ4_decompress_safe(src, dst, total_size - sizeof(uint32_t), uncompressed_size);
     if(dst_size < 0) return ValueError("LZ4 decompression failed");
-    assert(dst_size == uncompressed_size);
+    c11__rtassert(dst_size == uncompressed_size);
     return true;
 }