blueloveTH 2 anos atrás
pai
commit
15fb9f9337
1 arquivos alterados com 24 adições e 15 exclusões
  1. 24 15
      src/modules.cpp

+ 24 - 15
src/modules.cpp

@@ -243,25 +243,21 @@ void add_module_gc(VM* vm){
     vm->bind_func<0>(mod, "collect", PK_LAMBDA(VAR(vm->heap.collect())));
 }
 
+struct LineProfilerW;
+struct _LpGuard{
+    PK_ALWAYS_PASS_BY_POINTER(_LpGuard)
+    LineProfilerW* lp;
+    VM* vm;
+    _LpGuard(LineProfilerW* lp, VM* vm);
+    ~_LpGuard();
+};
+
 // line_profiler wrapper
 struct LineProfilerW{
     PY_CLASS(LineProfilerW, line_profiler, LineProfiler)
 
     LineProfiler profiler;
 
-    void enable_by_count(VM* vm){
-        if(vm->_profiler){
-            vm->ValueError("only one profiler can be enabled at a time");
-        }
-        vm->_profiler = &profiler;
-        profiler.begin();
-    }
-
-    void disable_by_count(VM* vm){
-        vm->_profiler = nullptr;
-        profiler.end();
-    }
-
     static void _register(VM* vm, PyObject* mod, PyObject* type){
         vm->bind_default_constructor<LineProfilerW>(type);
 
@@ -272,14 +268,13 @@ struct LineProfilerW{
 
         vm->bind(type, "runcall(self, func, *args)", [](VM* vm, ArgsView view){
             LineProfilerW& self = PK_OBJ_GET(LineProfilerW, view[0]);
-            self.enable_by_count(vm);
             PyObject* func = view[1];
             const Tuple& args = CAST(Tuple&, view[2]);
             vm->s_data.push(func);
             vm->s_data.push(PY_NULL);
             for(PyObject* arg : args) vm->s_data.push(arg);
+            _LpGuard guard(&self, vm);
             PyObject* ret = vm->vectorcall(args.size());
-            self.disable_by_count(vm);
             return ret;
         });
 
@@ -291,6 +286,20 @@ struct LineProfilerW{
     }
 };
 
+
+_LpGuard::_LpGuard(LineProfilerW* lp, VM* vm): lp(lp), vm(vm) {
+    if(vm->_profiler){
+        vm->ValueError("only one profiler can be enabled at a time");
+    }
+    vm->_profiler = &lp->profiler;
+    lp->profiler.begin();
+}
+
+_LpGuard::~_LpGuard(){
+    vm->_profiler = nullptr;
+    lp->profiler.end();
+}
+
 void add_module_line_profiler(VM *vm){
     PyObject* mod = vm->new_module("line_profiler");
     LineProfilerW::register_class(vm, mod);