// memory.h
#pragma once

#include <cstdlib>
#include <new>
#include <type_traits>
#include <utility>

#include "common.h"
  3. namespace pkpy{
  4. template <typename T>
  5. class shared_ptr {
  6. int* counter = nullptr;
  7. #define _t() ((T*)(counter + 1))
  8. #define _inc_counter() if(counter) ++(*counter)
  9. #define _dec_counter() if(counter && --(*counter) == 0){ _t()->~T(); free(counter); }
  10. public:
  11. shared_ptr() {}
  12. shared_ptr(int* block) : counter(block) {}
  13. shared_ptr(const shared_ptr& other) : counter(other.counter) {
  14. _inc_counter();
  15. }
  16. shared_ptr(shared_ptr&& other) noexcept : counter(other.counter) {
  17. other.counter = nullptr;
  18. }
  19. ~shared_ptr() { _dec_counter(); }
  20. bool operator==(const shared_ptr& other) const {
  21. return counter == other.counter;
  22. }
  23. bool operator!=(const shared_ptr& other) const {
  24. return counter != other.counter;
  25. }
  26. bool operator==(std::nullptr_t) const {
  27. return counter == nullptr;
  28. }
  29. bool operator!=(std::nullptr_t) const {
  30. return counter != nullptr;
  31. }
  32. shared_ptr& operator=(const shared_ptr& other) {
  33. _dec_counter();
  34. counter = other.counter;
  35. _inc_counter();
  36. return *this;
  37. }
  38. shared_ptr& operator=(shared_ptr&& other) noexcept {
  39. _dec_counter();
  40. counter = other.counter;
  41. other.counter = nullptr;
  42. return *this;
  43. }
  44. T& operator*() const { return *_t(); }
  45. T* operator->() const { return _t(); }
  46. T* get() const { return _t(); }
  47. int use_count() const { return counter ? *counter : 0; }
  48. void reset(){
  49. _dec_counter();
  50. counter = nullptr;
  51. }
  52. };
  53. #undef _t
  54. #undef _inc_counter
  55. #undef _dec_counter
  56. template <typename T, typename U, typename... Args>
  57. shared_ptr<T> make_shared(Args&&... args) {
  58. static_assert(std::is_base_of_v<T, U>, "U must be derived from T");
  59. static_assert(std::has_virtual_destructor_v<T>, "T must have virtual destructor");
  60. int* p = (int*)malloc(sizeof(int) + sizeof(U));
  61. *p = 1;
  62. new(p+1) U(std::forward<Args>(args)...);
  63. return shared_ptr<T>(p);
  64. }
  65. template <typename T, typename... Args>
  66. shared_ptr<T> make_shared(Args&&... args) {
  67. int* p = (int*)malloc(sizeof(int) + sizeof(T));
  68. *p = 1;
  69. new(p+1) T(std::forward<Args>(args)...);
  70. return shared_ptr<T>(p);
  71. }
  72. };