Skip to content

Commit

Permalink
Revert "ThreadPool: Spend less time busy waiting. (#21545)"
Browse files Browse the repository at this point in the history
This reverts commit 4e15b22.
  • Loading branch information
snnn authored Oct 4, 2024
1 parent d0ddfa9 commit 46e73a2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 83 deletions.
65 changes: 10 additions & 55 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ class RunQueue {

static std::atomic<uint32_t> next_tag{1};

template <typename Environment, bool kIsHybrid>
template <typename Environment>
class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInterface {
private:
struct PerThread;
Expand Down Expand Up @@ -767,29 +767,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
typedef std::function<void()> Task;
typedef RunQueue<Task, Tag, 1024> Queue;

// Class for waiting w/ exponential backoff.
// Template argument is maximum number of spins in backoff loop; it selects
// the waiting policy at compile time:
//   kMaxBackoff == 0 : wait() does nothing (no pausing at all).
//   kMaxBackoff == 1 : wait() issues exactly one SpinPause() per call.
//   kMaxBackoff  > 1 : wait() issues pause_time_ + 1 SpinPause() calls and
//                      then updates pause_time_ to (2 * spins) % kMaxBackoff.
template <unsigned kMaxBackoff>
class ThreadPoolWaiter {
  // Current number of spins in backoff loop (the next wait() spins one more
  // than this). Initialized in-class so that a waiter declared without `{}`
  // never reads an indeterminate value in wait().
  unsigned pause_time_ = 0;

 public:
  void wait() {
    if constexpr (kMaxBackoff == 1) {
      // Fixed policy: a single pause, no backoff state is touched.
      onnxruntime::concurrency::SpinPause();
    } else if constexpr (kMaxBackoff > 1) {
      // Exponential backoff
      const unsigned pause_time = pause_time_ + 1U;
      for (unsigned i = 0; i < pause_time; ++i) {
        onnxruntime::concurrency::SpinPause();
      }
      // The modulo keeps the stored count below kMaxBackoff.
      pause_time_ = (pause_time * 2U) % kMaxBackoff;
    }
    // kMaxBackoff == 0 falls through both branches: no pausing at all.
  }
};

ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env,
const ThreadOptions& thread_options)
: profiler_(num_threads, name),
Expand Down Expand Up @@ -931,9 +908,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
// finish dispatch work. This avoids new tasks being started
// concurrently with us attempting to end the parallel section.
if (ps.dispatch_q_idx != -1) {
ThreadPoolWaiter<4> waiter{};
while (!ps.dispatch_done.load(std::memory_order_acquire)) {
waiter.wait();
onnxruntime::concurrency::SpinPause();
}
}

Expand All @@ -955,17 +931,15 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

// Wait for the dispatch task's own work...
if (ps.dispatch_q_idx > -1) {
ThreadPoolWaiter<kIsHybrid ? 0 : 1> waiter{};
while (!ps.work_done.load(std::memory_order_acquire)) {
waiter.wait();
onnxruntime::concurrency::SpinPause();
}
}

// ...and wait for any other tasks not revoked to finish their work
auto tasks_to_wait_for = tasks_started - ps.tasks_revoked;
ThreadPoolWaiter<kIsHybrid ? 0 : 1> waiter{};
while (ps.tasks_finished < tasks_to_wait_for) {
waiter.wait();
onnxruntime::concurrency::SpinPause();
}

// Clear status to allow the ThreadPoolParallelSection to be
Expand Down Expand Up @@ -1283,10 +1257,9 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
// Increase the worker count if needed. Each worker will pick up
// loops to execute from the current parallel section.
std::function<void(unsigned)> worker_fn = [&ps](unsigned par_idx) {
ThreadPoolWaiter<kIsHybrid ? 4 : 0> waiter{};
while (ps.active) {
if (ps.current_loop.load() == nullptr) {
waiter.wait();
onnxruntime::concurrency::SpinPause();
} else {
ps.workers_in_loop++;
ThreadPoolLoop* work_item = ps.current_loop;
Expand All @@ -1307,9 +1280,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

// Wait for workers to exit the loop
ps.current_loop = 0;
ThreadPoolWaiter<kIsHybrid ? 1 : 4> waiter{};
while (ps.workers_in_loop) {
waiter.wait();
onnxruntime::concurrency::SpinPause();
}
profiler_.LogEnd(ThreadPoolProfiler::WAIT);
}
Expand Down Expand Up @@ -1560,30 +1532,13 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

assert(td.GetStatus() == WorkerData::ThreadStatus::Spinning);

// The exact value of spin_count and steal_count are arbitrary and
// were experimentally determined. These numbers yielded the best
// performance across a range of workloads and
// machines. Generally, the goal of tuning spin_count is to make
// the number as small as possible while ensuring there is enough
// slack so that if each core is doing the same amount of work it
// won't sleep before they have all finished. The idea here is
// that in pipelined workloads, it won't sleep during each stage
// if it's done a bit faster than its neighbors, but that if there
// are non-equal sizes of work distributed, it won't take too long
// to reach sleep giving power (and thus frequency/performance) to
// its neighbors. Since hybrid has P/E cores, a lower value is
// chosen. On hybrid systems, even with equal sized workloads
// distributed the compute time won't stay synced. Typically in
// the hybrid case the P cores finish first (and are thus waiting)
// which is essentially a priority inversion.
constexpr int pref_spin_count = kIsHybrid ? 5000 : 10000;
const int spin_count = allow_spinning_ ? pref_spin_count : 0;
constexpr int steal_count = pref_spin_count / (kIsHybrid ? 25 : 100);
constexpr int log2_spin = 20;
const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0;
const int steal_count = spin_count / 100;

SetDenormalAsZero(set_denormal_as_zero_);
profiler_.LogThreadId(thread_id);

ThreadPoolWaiter<kIsHybrid ? 1 : 8> waiter{};
while (!should_exit) {
Task t = q.PopFront();
if (!t) {
Expand All @@ -1599,7 +1554,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
if (spin_loop_status_.load(std::memory_order_relaxed) == SpinLoopStatus::kIdle) {
break;
}
waiter.wait();
onnxruntime::concurrency::SpinPause();
}

// Attempt to block
Expand Down
5 changes: 2 additions & 3 deletions include/onnxruntime/core/platform/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct TensorOpCost {

namespace concurrency {

template <typename Environment, bool kIsHybrid>
template <typename Environment>
class ThreadPoolTempl;

class ExtendedThreadPoolInterface;
Expand Down Expand Up @@ -424,8 +424,7 @@ class ThreadPool {
ExtendedThreadPoolInterface* underlying_threadpool_ = nullptr;

// If used, underlying_threadpool_ is instantiated and owned by the ThreadPool.
std::unique_ptr<ThreadPoolTempl<Env, true>> extended_eigen_hybrid_threadpool_;
std::unique_ptr<ThreadPoolTempl<Env, false>> extended_eigen_normal_threadpool_;
std::unique_ptr<ThreadPoolTempl<Env> > extended_eigen_threadpool_;

// Force the thread pool to run in hybrid mode on a normal cpu.
bool force_hybrid_ = false;
Expand Down
38 changes: 13 additions & 25 deletions onnxruntime/core/common/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,23 +389,13 @@ ThreadPool::ThreadPool(Env* env,
assert(thread_options_.affinities.size() >= size_t(threads_to_create));
}

if (force_hybrid_) {
extended_eigen_hybrid_threadpool_ =
std::make_unique<ThreadPoolTempl<Env, true> >(name,
threads_to_create,
low_latency_hint,
*env,
thread_options_);
underlying_threadpool_ = extended_eigen_hybrid_threadpool_.get();
} else {
extended_eigen_normal_threadpool_ =
std::make_unique<ThreadPoolTempl<Env, false> >(name,
threads_to_create,
low_latency_hint,
*env,
thread_options_);
underlying_threadpool_ = extended_eigen_normal_threadpool_.get();
}
extended_eigen_threadpool_ =
std::make_unique<ThreadPoolTempl<Env> >(name,
threads_to_create,
low_latency_hint,
*env,
thread_options_);
underlying_threadpool_ = extended_eigen_threadpool_.get();
}
}

Expand Down Expand Up @@ -674,17 +664,15 @@ std::string ThreadPool::StopProfiling(concurrency::ThreadPool* tp) {
}

void ThreadPool::EnableSpinning() {
if (extended_eigen_hybrid_threadpool_)
extended_eigen_hybrid_threadpool_->EnableSpinning();
else if (extended_eigen_normal_threadpool_)
extended_eigen_normal_threadpool_->EnableSpinning();
if (extended_eigen_threadpool_) {
extended_eigen_threadpool_->EnableSpinning();
}
}

void ThreadPool::DisableSpinning() {
if (extended_eigen_hybrid_threadpool_)
extended_eigen_hybrid_threadpool_->DisableSpinning();
else if (extended_eigen_normal_threadpool_)
extended_eigen_normal_threadpool_->DisableSpinning();
if (extended_eigen_threadpool_) {
extended_eigen_threadpool_->DisableSpinning();
}
}

// Return the number of threads created by the pool.
Expand Down

0 comments on commit 46e73a2

Please sign in to comment.