diff --git a/lib_fiber/cpp/src/wait_group.cpp b/lib_fiber/cpp/src/wait_group.cpp index d102e3c6f..6324b79ca 100644 --- a/lib_fiber/cpp/src/wait_group.cpp +++ b/lib_fiber/cpp/src/wait_group.cpp @@ -17,8 +17,7 @@ wait_group::~wait_group(void) void wait_group::add(int n) { - state_ += (long long)n << 32; - long long state = state_; + long long state = state_.add_fetch((long long)n << 32); int c = (int)(state >> 32); uint32_t w = (uint32_t)state; if(c < 0){ @@ -52,22 +51,26 @@ void wait_group::done(void) void wait_group::wait(void) { - long long state = state_; - int c = (int)(state >> 32); - uint32_t w = (uint32_t)state; - if(c == 0) return; - state_++; - bool found; + for(;;){ + long long state = state_; + int c = (int)(state >> 32); + uint32_t w = (uint32_t)state; + if(c == 0) return; + if(state_.cas(state, state + 1) == state){ + bool found; #ifdef _DEBUG - unsigned long* tid = box_->pop(-1, &found); - assert(found); - delete tid; + unsigned long* tid = box_->pop(-1, &found); + assert(found); + delete tid; #else - (void) box_->pop(-1, &found); - assert(found); + (void) box_->pop(-1, &found); + assert(found); #endif - if(state_ != 0){ - acl_msg_fatal("wait_group: wait_group is reused before previous wait has returned"); + if(state_ != 0){ + acl_msg_fatal("wait_group: wait_group is reused before previous wait has returned"); + } + return; + } } }