Explorar el Código

refactor threads

blueloveTH hace 3 meses
padre
commit
a05eb008a4
Se han modificado 3 ficheros con 5 adiciones y 2 borrados
  1. 1 0
      include/pocketpy/common/threads.h
  2. 3 2
      src/common/threads.c
  3. 1 0
      src2/test_threads.c

+ 1 - 0
include/pocketpy/common/threads.h

@@ -76,5 +76,6 @@ typedef struct c11_thrdpool {
 void c11_thrdpool__ctor(c11_thrdpool* pool, int length);
 void c11_thrdpool__dtor(c11_thrdpool* pool);
 void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args, int num_tasks);
+void c11_thrdpool__join(c11_thrdpool* pool);
 
 #endif

+ 3 - 2
src/common/threads.c

@@ -193,7 +193,6 @@ void c11_thrdpool__dtor(c11_thrdpool* pool) {
 }
 
 void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args, int num_tasks) {
-    if(num_tasks == 0) return;
     c11_thrdpool_debug_log(-1, "c11_thrdpool__map() called on %d tasks...", num_tasks);
     while(atomic_load_explicit(&pool->ready_workers_num, memory_order_relaxed) < pool->length) {
         c11_thrd__yield();
@@ -210,13 +209,15 @@ void c11_thrdpool__map(c11_thrdpool* pool, c11_thrdpool_func_t func, void** args
     atomic_store_explicit(&pool->tasks.completed_count, 0, memory_order_relaxed);
     c11_cond__broadcast(&pool->workers_cond);
     c11_mutex__unlock(&pool->workers_mutex);
+}
 
+void c11_thrdpool__join(c11_thrdpool *pool) {
     // wait for complete
+    int num_tasks = pool->tasks.length;
     c11_thrdpool_debug_log(-1, "Waiting for %d tasks to complete...", num_tasks);
     while(atomic_load_explicit(&pool->tasks.completed_count, memory_order_acquire) < num_tasks) {
         c11_thrd__yield();
     }
-
     atomic_store_explicit(&pool->tasks.sync_val, 0, memory_order_relaxed);
     c11_thrdpool_debug_log(-1, "All %d tasks completed, `sync_val` was reset.", num_tasks);
 }

+ 1 - 0
src2/test_threads.c

@@ -33,6 +33,7 @@ int main(int argc, char** argv) {
         printf("==> %dth run\n", i + 1);
         int64_t start_ns = time_ns();
         c11_thrdpool__map(&pool, func, args, num_tasks);
+        c11_thrdpool__join(&pool);
         int64_t end_ns = time_ns();
         double elapsed = (end_ns - start_ns) / 1e9;
         printf("  Results: %lld, %lld, %lld, %lld, %lld\n",