iter.h 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #pragma once
  2. #include "ceval.h"
  3. namespace pkpy{
  4. class RangeIter : public BaseIter {
  5. i64 current;
  6. Range r; // copy by value, so we don't need to keep ref
  7. public:
  8. RangeIter(VM* vm, PyObject* ref) : BaseIter(vm) {
  9. this->r = OBJ_GET(Range, ref);
  10. this->current = r.start;
  11. }
  12. bool _has_next(){
  13. return r.step > 0 ? current < r.stop : current > r.stop;
  14. }
  15. PyObject* next(){
  16. if(!_has_next()) return nullptr;
  17. current += r.step;
  18. return VAR(current-r.step);
  19. }
  20. };
  21. template <typename T>
  22. class ArrayIter : public BaseIter {
  23. PyObject* ref;
  24. int index;
  25. public:
  26. ArrayIter(VM* vm, PyObject* ref) : BaseIter(vm), ref(ref), index(0) {}
  27. PyObject* next() override{
  28. const T* p = &OBJ_GET(T, ref);
  29. if(index == p->size()) return nullptr;
  30. return p->operator[](index++);
  31. }
  32. void _gc_mark() const override {
  33. OBJ_MARK(ref);
  34. }
  35. };
  36. class StringIter : public BaseIter {
  37. PyObject* ref;
  38. int index;
  39. public:
  40. StringIter(VM* vm, PyObject* ref) : BaseIter(vm), ref(ref), index(0) {}
  41. PyObject* next() override{
  42. Str* str = &OBJ_GET(Str, ref);
  43. if(index == str->u8_length()) return nullptr;
  44. return VAR(str->u8_getitem(index++));
  45. }
  46. void _gc_mark() const override {
  47. OBJ_MARK(ref);
  48. }
  49. };
  50. inline PyObject* Generator::next(){
  51. if(state == 2) return nullptr;
  52. vm->callstack.push(std::move(frame));
  53. PyObject* ret = vm->_exec();
  54. if(ret == vm->_py_op_yield){
  55. frame = std::move(vm->callstack.top());
  56. vm->callstack.pop();
  57. state = 1;
  58. return frame->popx();
  59. }else{
  60. state = 2;
  61. return nullptr;
  62. }
  63. }
  64. inline void Generator::_gc_mark() const{
  65. if(frame != nullptr) frame->_gc_mark();
  66. }
  67. template<typename T>
  68. void _gc_mark(T& t) {
  69. if constexpr(std::is_base_of_v<BaseIter, T>){
  70. t._gc_mark();
  71. }
  72. }
  73. } // namespace pkpy