Jelajahi Sumber

improve `io`

blueloveTH 1 tahun lalu
induk
komit
79cafcf32c
2 file diubah dengan 74 tambahan dan 19 penghapusan
  1. 44 17
      src/io.cpp
  2. 30 2
      tests/70_file.py

+ 44 - 17
src/io.cpp

@@ -12,12 +12,10 @@ namespace pkpy{
 struct FileIO {
     PY_CLASS(FileIO, io, FileIO)
 
-    Str file;
-    Str mode;
     FILE* fp;
+    bool is_text;
 
-    bool is_text() const { return mode != "rb" && mode != "wb" && mode != "ab"; }
-    FileIO(VM* vm, std::string file, std::string mode);
+    FileIO(VM* vm, const Str& file, const Str& mode);
     void close();
     static void _register(VM* vm, PyObject* mod, PyObject* type);
 };
@@ -62,27 +60,34 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
     vm->bind_constructor<3>(type, [](VM* vm, ArgsView args){
         Type cls = PK_OBJ_GET(Type, args[0]);
         return vm->heap.gcnew<FileIO>(cls, vm,
-                    py_cast<Str&>(vm, args[1]).str(),
-                    py_cast<Str&>(vm, args[2]).str());
+                    py_cast<Str&>(vm, args[1]),
+                    py_cast<Str&>(vm, args[2]));
     });
 
-    vm->bind_method<0>(type, "read", [](VM* vm, ArgsView args){
-        FileIO& io = CAST(FileIO&, args[0]);
-        fseek(io.fp, 0, SEEK_END);
-        int buffer_size = ftell(io.fp);
+    vm->bind(type, "read(self, size=-1)", [](VM* vm, ArgsView args){
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
+        i64 size = CAST(i64, args[1]);
+        i64 buffer_size;
+        if(size < 0){
+            long current = ftell(io.fp);
+            fseek(io.fp, 0, SEEK_END);
+            buffer_size = ftell(io.fp);
+            fseek(io.fp, current, SEEK_SET);
+        }else{
+            buffer_size = size;
+        }
         unsigned char* buffer = new unsigned char[buffer_size];
-        fseek(io.fp, 0, SEEK_SET);
         size_t actual_size = io_fread(buffer, 1, buffer_size, io.fp);
         PK_ASSERT(actual_size <= buffer_size);
         // in text mode, CR may be dropped, which may cause `actual_size < buffer_size`
         Bytes b(buffer, actual_size);
-        if(io.is_text()) return VAR(b.str());
+        if(io.is_text) return VAR(b.str());
         return VAR(std::move(b));
     });
 
     vm->bind_method<1>(type, "write", [](VM* vm, ArgsView args){
-        FileIO& io = CAST(FileIO&, args[0]);
-        if(io.is_text()){
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
+        if(io.is_text){
             Str& s = CAST(Str&, args[1]);
             fwrite(s.data, 1, s.length(), io.fp);
         }else{
@@ -92,14 +97,30 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
         return vm->None;
     });
 
+    vm->bind_method<0>(type, "tell", [](VM* vm, ArgsView args){
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
+        long pos = ftell(io.fp);
+        if(pos == -1) vm->IOError(strerror(errno));
+        return VAR(pos);
+    });
+
+    vm->bind_method<2>(type, "seek", [](VM* vm, ArgsView args){
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
+        long offset = CAST(long, args[1]);
+        int whence = CAST(int, args[2]);
+        int ret = fseek(io.fp, offset, whence);
+        if(ret != 0) vm->IOError(strerror(errno));
+        return vm->None;
+    });
+
     vm->bind_method<0>(type, "close", [](VM* vm, ArgsView args){
-        FileIO& io = CAST(FileIO&, args[0]);
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
         io.close();
         return vm->None;
     });
 
     vm->bind_method<0>(type, "__exit__", [](VM* vm, ArgsView args){
-        FileIO& io = CAST(FileIO&, args[0]);
+        FileIO& io = PK_OBJ_GET(FileIO, args[0]);
         io.close();
         return vm->None;
     });
@@ -107,7 +128,8 @@ void FileIO::_register(VM* vm, PyObject* mod, PyObject* type){
     vm->bind_method<0>(type, "__enter__", PK_LAMBDA(args[0]));
 }
 
-FileIO::FileIO(VM* vm, std::string file, std::string mode): file(file), mode(mode) {
+FileIO::FileIO(VM* vm, const Str& file, const Str& mode){
+    this->is_text = mode.sv().find("b") == std::string::npos;
     fp = io_fopen(file.c_str(), mode.c_str());
     if(!fp) vm->IOError(strerror(errno));
 }
@@ -121,6 +143,11 @@ void FileIO::close(){
 void add_module_io(VM* vm){
     PyObject* mod = vm->new_module("io");
     FileIO::register_class(vm, mod);
+
+    mod->attr().set("SEEK_SET", VAR(SEEK_SET));
+    mod->attr().set("SEEK_CUR", VAR(SEEK_CUR));
+    mod->attr().set("SEEK_END", VAR(SEEK_END));
+
     vm->bind(vm->builtins, "open(path, mode='r')", [](VM* vm, ArgsView args){
         PK_LOCAL_STATIC StrName m_io("io");
         PK_LOCAL_STATIC StrName m_FileIO("FileIO");

+ 30 - 2
tests/70_file.py

@@ -13,6 +13,34 @@ a.close()
 with open('123.txt', 'rt') as f:
     assert f.read() == '123456'
 
+with open('123.txt', 'rt') as f:
+    assert f.read(3) == '123'
+    assert f.tell() == 3
+    assert f.read(3) == '456'
+    assert f.tell() == 6
+    assert f.read(3) == ''      # EOF
+    assert f.tell() == 6
+
+with open('123.txt', 'rb') as f:
+    assert f.read(2) == b'12'
+    assert f.tell() == 2
+    assert f.read(2) == b'34'
+    assert f.tell() == 4
+    assert f.read(2) == b'56'
+    assert f.tell() == 6
+    assert f.read(2) == b''     # EOF
+    assert f.tell() == 6
+
+# test fseek
+with open('123.txt', 'rt') as f:
+    f.seek(0, io.SEEK_END)
+    assert f.tell() == 6
+    assert f.read() == ''
+    f.seek(3, io.SEEK_SET)
+    assert f.tell() == 3
+    assert f.read() == '456'
+    assert f.tell() == 6
+
 with open('123.txt', 'a') as f:
     f.write('测试')
 
@@ -29,13 +57,13 @@ with open('123.bin', 'wb') as f:
     f.write('123'.encode())
     f.write('测试'.encode())
 
-def f():
+def f_():
     with open('123.bin', 'rb') as f:
         b = f.read()
         assert isinstance(b, bytes)
         assert b == '123测试'.encode()
 
-f()
+f_()
 
 assert os.path.exists('123.bin')
 os.remove('123.bin')