iter.h 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #pragma once
  2. #include "ceval.h"
  3. #include "frame.h"
  4. namespace pkpy{
  5. class RangeIter final: public BaseIter {
  6. i64 current;
  7. Range r; // copy by value, so we don't need to keep ref
  8. public:
  9. RangeIter(VM* vm, PyObject* ref) : BaseIter(vm) {
  10. this->r = OBJ_GET(Range, ref);
  11. this->current = r.start;
  12. }
  13. bool _has_next(){
  14. return r.step > 0 ? current < r.stop : current > r.stop;
  15. }
  16. PyObject* next(){
  17. if(!_has_next()) return vm->StopIteration;
  18. current += r.step;
  19. return VAR(current-r.step);
  20. }
  21. };
  22. template <typename T>
  23. class ArrayIter final: public BaseIter {
  24. PyObject* ref;
  25. T* array;
  26. int index;
  27. public:
  28. ArrayIter(VM* vm, PyObject* ref) : BaseIter(vm), ref(ref) {
  29. array = &OBJ_GET(T, ref);
  30. index = 0;
  31. }
  32. PyObject* next() override{
  33. if(index >= array->size()) return vm->StopIteration;
  34. return array->operator[](index++);
  35. }
  36. void _gc_mark() const{
  37. OBJ_MARK(ref);
  38. }
  39. };
  40. class StringIter final: public BaseIter {
  41. PyObject* ref;
  42. int index;
  43. public:
  44. StringIter(VM* vm, PyObject* ref) : BaseIter(vm), ref(ref), index(0) {}
  45. PyObject* next() override{
  46. // TODO: optimize this to use iterator
  47. // operator[] is O(n) complexity
  48. Str* str = &OBJ_GET(Str, ref);
  49. if(index == str->u8_length()) return vm->StopIteration;
  50. return VAR(str->u8_getitem(index++));
  51. }
  52. void _gc_mark() const{
  53. OBJ_MARK(ref);
  54. }
  55. };
  56. inline PyObject* Generator::next(){
  57. if(state == 2) return vm->StopIteration;
  58. // reset frame._sp_base
  59. frame._sp_base = frame._s->_sp;
  60. frame._locals.a = frame._s->_sp;
  61. // restore the context
  62. for(PyObject* obj: s_backup) frame._s->push(obj);
  63. s_backup.clear();
  64. vm->callstack.push(std::move(frame));
  65. PyObject* ret = vm->_run_top_frame();
  66. if(ret == PY_OP_YIELD){
  67. // backup the context
  68. frame = std::move(vm->callstack.top());
  69. PyObject* ret = frame._s->popx();
  70. for(PyObject* obj: frame.stack_view()) s_backup.push_back(obj);
  71. vm->_pop_frame();
  72. state = 1;
  73. if(ret == vm->StopIteration) state = 2;
  74. return ret;
  75. }else{
  76. state = 2;
  77. return vm->StopIteration;
  78. }
  79. }
  80. inline void Generator::_gc_mark() const{
  81. frame._gc_mark();
  82. for(PyObject* obj: s_backup) OBJ_MARK(obj);
  83. }
  84. } // namespace pkpy