memory.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  #pragma once
  #include <cstdlib>
  #include <new>
  #include "common.h"
  3. struct PyObject;
  4. namespace pkpy{
  /// Raw storage policy for pkpy::shared_ptr.
  ///
  /// Each allocation is one malloc'd block laid out as
  ///   [ int reference count ][ object ]
  /// and is identified by a pointer to the count; the object lives at
  /// `counter + 1`.
  /// NOTE(review): `counter + 1` is only guaranteed sizeof(int)-aligned, which is
  /// less than alignof() of types holding pointers/doubles (e.g. anything with a
  /// vtable). Fixing this needs a coordinated layout change here, in shared_ptr
  /// and in make_shared.
  template<typename T>
  struct SpAllocator {
      /// Allocates an uninitialised block big enough for the count plus a U.
      /// The caller must set the count and placement-new the object at
      /// `result + 1`.
      /// @throws std::bad_alloc if malloc fails, rather than returning a null
      ///         pointer that callers would dereference straight away.
      template<typename U>
      inline static int* alloc(){
          void* block = std::malloc(sizeof(int) + sizeof(U));
          if(block == nullptr) throw std::bad_alloc();
          return static_cast<int*>(block);
      }

      /// Runs T's destructor on the object stored after `counter` and releases
      /// the whole block. When the stored object is a subclass of T, T's
      /// destructor must be virtual (make_shared<T, U> static_asserts this).
      inline static void dealloc(int* counter){
          reinterpret_cast<T*>(counter + 1)->~T();
          std::free(counter);
      }
  };
  16. template <typename T>
  17. class shared_ptr {
  18. int* counter;
  19. inline T* _t() const {
  20. if constexpr(std::is_same_v<T, PyObject>){
  21. if(is_tagged()) UNREACHABLE();
  22. }
  23. return (T*)(counter + 1);
  24. }
  25. inline void _inc_counter() const {
  26. if constexpr(std::is_same_v<T, PyObject>){
  27. if(is_tagged()) return;
  28. }
  29. if(counter) ++(*counter);
  30. }
  31. inline void _dec_counter() const {
  32. if constexpr(std::is_same_v<T, PyObject>){
  33. if(is_tagged()) return;
  34. }
  35. if(counter && --(*counter) == 0){
  36. SpAllocator<T>::dealloc(counter);
  37. }
  38. }
  39. public:
  40. shared_ptr() : counter(nullptr) {}
  41. shared_ptr(int* counter) : counter(counter) {}
  42. shared_ptr(const shared_ptr& other) : counter(other.counter) {
  43. _inc_counter();
  44. }
  45. shared_ptr(shared_ptr&& other) noexcept : counter(other.counter) {
  46. other.counter = nullptr;
  47. }
  48. ~shared_ptr() { _dec_counter(); }
  49. bool operator==(const shared_ptr& other) const { return counter == other.counter; }
  50. bool operator!=(const shared_ptr& other) const { return counter != other.counter; }
  51. bool operator<(const shared_ptr& other) const { return counter < other.counter; }
  52. bool operator>(const shared_ptr& other) const { return counter > other.counter; }
  53. bool operator<=(const shared_ptr& other) const { return counter <= other.counter; }
  54. bool operator>=(const shared_ptr& other) const { return counter >= other.counter; }
  55. bool operator==(std::nullptr_t) const { return counter == nullptr; }
  56. bool operator!=(std::nullptr_t) const { return counter != nullptr; }
  57. shared_ptr& operator=(const shared_ptr& other) {
  58. _dec_counter();
  59. counter = other.counter;
  60. _inc_counter();
  61. return *this;
  62. }
  63. shared_ptr& operator=(shared_ptr&& other) noexcept {
  64. _dec_counter();
  65. counter = other.counter;
  66. other.counter = nullptr;
  67. return *this;
  68. }
  69. T& operator*() const { return *_t(); }
  70. T* operator->() const { return _t(); }
  71. T* get() const { return _t(); }
  72. int use_count() const {
  73. if(is_tagged()) return 1;
  74. return counter ? *counter : 0;
  75. }
  76. void reset(){
  77. _dec_counter();
  78. counter = nullptr;
  79. }
  80. template <typename __VAL>
  81. inline __VAL cast() const {
  82. static_assert(std::is_same_v<T, PyObject>, "T must be PyObject");
  83. return reinterpret_cast<__VAL>(counter);
  84. }
  85. inline bool is_tagged() const { return (cast<i64>() & 0b11) != 0b00; }
  86. inline bool is_tag_00() const { return (cast<i64>() & 0b11) == 0b00; }
  87. inline bool is_tag_01() const { return (cast<i64>() & 0b11) == 0b01; }
  88. inline bool is_tag_10() const { return (cast<i64>() & 0b11) == 0b10; }
  89. inline bool is_tag_11() const { return (cast<i64>() & 0b11) == 0b11; }
  90. };
  91. template <typename T, typename U, typename... Args>
  92. shared_ptr<T> make_shared(Args&&... args) {
  93. static_assert(std::is_base_of_v<T, U>, "U must be derived from T");
  94. static_assert(std::has_virtual_destructor_v<T>, "T must have virtual destructor");
  95. static_assert(!std::is_same_v<T, PyObject> || (!std::is_same_v<U, i64> && !std::is_same_v<U, f64>));
  96. int* p = SpAllocator<T>::template alloc<U>(); *p = 1;
  97. new(p+1) U(std::forward<Args>(args)...);
  98. return shared_ptr<T>(p);
  99. }
  100. template <typename T, typename... Args>
  101. shared_ptr<T> make_shared(Args&&... args) {
  102. int* p = SpAllocator<T>::template alloc<T>(); *p = 1;
  103. new(p+1) T(std::forward<Args>(args)...);
  104. return shared_ptr<T>(p);
  105. }
  106. };
  107. static_assert(sizeof(i64) == sizeof(pkpy::shared_ptr<PyObject>));
  108. static_assert(sizeof(f64) == sizeof(pkpy::shared_ptr<PyObject>));