Parcourir la source

impl `yield` stmt

BLUELOVETH il y a 3 ans
Parent
commit
4e4ed4ddbd
8 fichiers modifiés avec 55 ajouts et 10 suppressions
  1. 1 0
      src/codeobject.h
  2. 9 2
      src/compiler.h
  3. 33 5
      src/iter.h
  4. 0 1
      src/obj.h
  5. 1 0
      src/opcodes.h
  6. 1 1
      src/parser.h
  7. 3 1
      src/vm.h
  8. 7 0
      tests/_yield.py

+ 1 - 0
src/codeobject.h

@@ -51,6 +51,7 @@ struct CodeBlock {
 struct CodeObject {
 struct CodeObject {
     pkpy::shared_ptr<SourceData> src;
     pkpy::shared_ptr<SourceData> src;
     Str name;
     Str name;
+    bool is_generator = false;
 
 
     CodeObject(pkpy::shared_ptr<SourceData> src, Str name) {
     CodeObject(pkpy::shared_ptr<SourceData> src, Str name) {
         this->src = src;
         this->src = src;

+ 9 - 2
src/compiler.h

@@ -849,9 +849,16 @@ __LISTCOMP:
             if (!co()->_is_curr_block_loop()) SyntaxError("'continue' not properly in loop");
             if (!co()->_is_curr_block_loop()) SyntaxError("'continue' not properly in loop");
             consume_end_stmt();
             consume_end_stmt();
             emit(OP_LOOP_CONTINUE);
             emit(OP_LOOP_CONTINUE);
+        } else if (match(TK("yield"))) {
+            if (codes.size() == 1) SyntaxError("'yield' outside function");
+            co()->_rvalue = true;
+            EXPR_TUPLE();
+            co()->_rvalue = false;
+            consume_end_stmt();
+            co()->is_generator = true;
+            emit(OP_YIELD_VALUE, -1, true);
         } else if (match(TK("return"))) {
         } else if (match(TK("return"))) {
-            if (codes.size() == 1)
-                SyntaxError("'return' outside function");
+            if (codes.size() == 1) SyntaxError("'return' outside function");
             if(match_end_stmt()){
             if(match_end_stmt()){
                 emit(OP_LOAD_NONE);
                 emit(OP_LOAD_NONE);
             }else{
             }else{

+ 33 - 5
src/iter.h

@@ -11,11 +11,12 @@ public:
         this->current = r.start;
         this->current = r.start;
     }
     }
 
 
-    bool has_next(){
+    inline bool _has_next(){
         return r.step > 0 ? current < r.stop : current > r.stop;
         return r.step > 0 ? current < r.stop : current > r.stop;
     }
     }
 
 
     PyVar next(){
     PyVar next(){
+        if(!_has_next()) return nullptr;
         current += r.step;
         current += r.step;
         return vm->PyInt(current-r.step);
         return vm->PyInt(current-r.step);
     }
     }
@@ -27,8 +28,10 @@ class ArrayIter : public BaseIter {
     const T* p;
     const T* p;
 public:
 public:
     ArrayIter(VM* vm, PyVar _ref) : BaseIter(vm, _ref) { p = &OBJ_GET(T, _ref);}
     ArrayIter(VM* vm, PyVar _ref) : BaseIter(vm, _ref) { p = &OBJ_GET(T, _ref);}
-    bool has_next(){ return index < p->size(); }
-    PyVar next(){ return p->operator[](index++); }
+    PyVar next(){
+        if(index == p->size()) return nullptr;
+        return p->operator[](index++); 
+    }
 };
 };
 
 
 class StringIter : public BaseIter {
 class StringIter : public BaseIter {
@@ -39,6 +42,31 @@ public:
         str = OBJ_GET(Str, _ref);
         str = OBJ_GET(Str, _ref);
     }
     }
 
 
-    bool has_next(){ return index < str.u8_length(); }
-    PyVar next() { return vm->PyStr(str.u8_getitem(index++)); }
+    PyVar next() {
+        if(index == str.u8_length()) return nullptr;
+        return vm->PyStr(str.u8_getitem(index++));
+    }
 };
 };
+
+class Generator: public BaseIter {
+    std::unique_ptr<Frame> frame;
+    int state; // 0,1,2
+public:
+    Generator(VM* vm, std::unique_ptr<Frame>&& frame)
+        : BaseIter(vm, nullptr), frame(std::move(frame)), state(0) {}
+
+    PyVar next() {
+        if(state == 2) return nullptr;
+        vm->callstack.push(std::move(frame));
+        PyVar ret = vm->_exec();
+        if(ret == vm->_py_op_yield){
+            frame = std::move(vm->callstack.top());
+            vm->callstack.pop();
+            state = 1;
+            return frame->pop_value(vm);
+        }else{
+            state = 2;
+            return nullptr;
+        }
+    }
+};

+ 0 - 1
src/obj.h

@@ -68,7 +68,6 @@ protected:
     PyVar _ref;     // keep a reference to the object so it will not be deleted while iterating
     PyVar _ref;     // keep a reference to the object so it will not be deleted while iterating
 public:
 public:
     virtual PyVar next() = 0;
     virtual PyVar next() = 0;
-    virtual bool has_next() = 0;
     PyVarRef var;
     PyVarRef var;
     BaseIter(VM* vm, PyVar _ref) : vm(vm), _ref(_ref) {}
     BaseIter(VM* vm, PyVar _ref) : vm(vm), _ref(_ref) {}
     virtual ~BaseIter() = default;
     virtual ~BaseIter() = default;

+ 1 - 0
src/opcodes.h

@@ -68,6 +68,7 @@ OPCODE(DELETE_REF)
 OPCODE(TRY_BLOCK_ENTER)
 OPCODE(TRY_BLOCK_ENTER)
 OPCODE(TRY_BLOCK_EXIT)
 OPCODE(TRY_BLOCK_EXIT)
 
 
+OPCODE(YIELD_VALUE)
 //OPCODE(FAST_INDEX_0)      // a[0]
 //OPCODE(FAST_INDEX_0)      // a[0]
 //OPCODE(FAST_INDEX_1)      // a[i]
 //OPCODE(FAST_INDEX_1)      // a[i]
 
 

+ 1 - 1
src/parser.h

@@ -12,7 +12,7 @@ constexpr const char* kTokens[] = {
     "==", "!=", ">=", "<=",
     "==", "!=", ">=", "<=",
     "+=", "-=", "*=", "/=", "//=", "%=", "&=", "|=", "^=",
     "+=", "-=", "*=", "/=", "//=", "%=", "&=", "|=", "^=",
     /** KW_BEGIN **/
     /** KW_BEGIN **/
-    "class", "import", "as", "def", "lambda", "pass", "del", "from", "with",
+    "class", "import", "as", "def", "lambda", "pass", "del", "from", "with", "yield",
     "None", "in", "is", "and", "or", "not", "True", "False", "global", "try", "except", "finally",
     "None", "in", "is", "and", "or", "not", "True", "False", "global", "try", "except", "finally",
     "goto", "label",      // extended keywords, not available in cpython
     "goto", "label",      // extended keywords, not available in cpython
     "while", "for", "if", "elif", "else", "break", "continue", "return", "assert", "raise",
     "while", "for", "if", "elif", "else", "break", "continue", "return", "assert", "raise",

Fichier diff supprimé car celui-ci est trop grand
+ 3 - 1
src/vm.h


+ 7 - 0
tests/_yield.py

@@ -0,0 +1,7 @@
+def f(n):
+    for i in range(n):
+        yield i
+
+a = [i for i in f(6)]
+
+assert a == [0,1,2,3,4,5]

Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff