From 9b9409d504f79daa13e07076962d095e6944b1de Mon Sep 17 00:00:00 2001
From: Ian Roddis
Date: Mon, 5 Jul 2021 15:37:29 -0300
Subject: [PATCH] Things mostly work, just a strange hang when executing code with forking executor

---
 .gitignore                         |   1 +
 daggy/include/daggy/Scheduler.hpp  |   2 +-
 daggy/include/daggy/ThreadPool.hpp | 159 ++++++++++++++++++++++++-----
 daggy/src/Scheduler.cpp            |  53 +++++-----
 daggy/src/ThreadPool.cpp           |  55 ----------
 tests/unit_scheduler.cpp           |  23 +++++
 tests/unit_threadpool.cpp          |   9 +-
 7 files changed, 195 insertions(+), 107 deletions(-)
 delete mode 100644 daggy/src/ThreadPool.cpp
 create mode 100644 tests/unit_scheduler.cpp

diff --git a/.gitignore b/.gitignore
index 378eac2..9785597 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 build
+.cache
diff --git a/daggy/include/daggy/Scheduler.hpp b/daggy/include/daggy/Scheduler.hpp
index 83f7cc4..52fada7 100644
--- a/daggy/include/daggy/Scheduler.hpp
+++ b/daggy/include/daggy/Scheduler.hpp
@@ -63,7 +63,7 @@ namespace daggy {
         std::unordered_map runs_;
         Executor & executor_;
         ThreadPool schedulers_;
-        ThreadPool executorWorkers_;
+        ThreadPool executors_;
         std::unordered_map> jobs;
         std::mutex mtx_;
         std::condition_variable cv_;
diff --git a/daggy/include/daggy/ThreadPool.hpp b/daggy/include/daggy/ThreadPool.hpp
index c24fced..9dd434e 100644
--- a/daggy/include/daggy/ThreadPool.hpp
+++ b/daggy/include/daggy/ThreadPool.hpp
@@ -6,39 +6,148 @@
 #include
 #include
 #include
-#include
-
-/*
-TODO: There's probably a better implementation of this at
-
-  https://github.com/vit-vit/CTPL/blob/master/ctpl_stl.h
-
-  but for now assume that all work is done using closures.
-*/
+#include
+#include
+#include
 
 namespace daggy {
-    using AsyncTask = std::function;
-
-    class ThreadPool {
+
+    /*
+      A Task Queue is a collection of async tasks to be executed by the
+      thread pool. Using individual task queues allows for a rough QoS
+      when a single thread may be submitting batches of requests --
+      one producer won't starve out another, but all tasks will be run
+      as quickly as possible.
+    */
+    class TaskQueue {
     public:
-        ThreadPool(size_t nWorkers);
-        ~ThreadPool();
+        template <typename F, typename... Args>
+        decltype(auto) addTask(F&& f, Args&&... args) {
+            // using return_type = std::invoke_result<F, Args...>::type;
+            using return_type = std::invoke_result_t<F, Args...>;
 
-        ThreadPool(const ThreadPool & other) = delete;
-        ThreadPool(ThreadPool & other) = delete;
+            std::packaged_task<return_type()> task(
+                std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+            );
 
-        void shutdown();
+            std::future<return_type> res = task.get_future();
+            {
+                std::lock_guard guard(mtx_);
+                tasks_.emplace(std::move(task));
+            }
+            return res;
+        }
 
-        std::future addTask(AsyncTask fn);
+        std::packaged_task pop() {
+            std::lock_guard guard(mtx_);
+            auto task = std::move(tasks_.front());
+            tasks_.pop();
+            return task;
+        }
 
-        size_t queueSize();
+        size_t size() {
+            std::lock_guard guard(mtx_);
+            return tasks_.size();
+        }
+
+        bool empty() {
+            std::lock_guard guard(mtx_);
+            return tasks_.empty();
+        }
 
     private:
-        using QueuedAsyncTask = std::shared_ptr>;
-
-        std::atomic shutdown_;
-        std::mutex guard_;
-        std::condition_variable cv_;
-        std::vector workers_;
-        std::deque taskQueue_;
+        std::queue< std::packaged_task > tasks_;
+        std::mutex mtx_;
     };
+
+    class ThreadPool {
+    public:
+        explicit ThreadPool(size_t nWorkers)
+            :
+            tqit_(taskQueues_.begin())
+            , stop_(false)
+        {
+            resize(nWorkers);
+        }
+
+        ~ThreadPool() { shutdown(); }
+
+        void shutdown() {
+            stop_ = true;
+            cv_.notify_all();
+            for (std::thread& worker : workers_) {
+                if (worker.joinable())
+                    worker.join();
+            }
+        }
+
+        void drain() {
+            resize(workers_.size());
+        }
+
+        void resize(size_t nWorkers) {
+            shutdown();
+            workers_.clear();
+
+            for(size_t i = 0;i< nWorkers;++i)
+                workers_.emplace_back( [&] {
+                    for(;;) {
+                        std::packaged_task task;
+                        {
+                            std::unique_lock lock(mtx_);
+                            cv_.wait(lock, [&]{ return stop_ || !taskQueues_.empty(); });
+                            if(taskQueues_.empty()) {
+                                if(stop_) return;
+                                continue;
+                            }
+                            if (tqit_ == taskQueues_.end()) tqit_ = taskQueues_.begin();
+                            if (not (*tqit_)->empty()) {
+                                task = std::move((*tqit_)->pop());
+                                if ((*tqit_)->empty()) {
+                                    tqit_ = taskQueues_.erase(tqit_);
+                                } else {
+                                    tqit_++;
+                                }
+                            }
+                        }
+                        task();
+                    }
+                }
+            );
+        };
+
+        template <typename F, typename... Args>
+        decltype(auto) addTask(F&& f, Args&&... args) {
+            auto tq = std::make_shared<TaskQueue>();
+
+            auto fut = tq->addTask(f, args...);
+
+            {
+                std::lock_guard guard(mtx_);
+                taskQueues_.push_back(tq);
+            }
+            cv_.notify_one();
+            return fut;
+        }
+
+        void addTaskQueue(std::shared_ptr<TaskQueue> tq) {
+            std::lock_guard guard(mtx_);
+            taskQueues_.push_back(tq);
+            cv_.notify_one();
+        }
+
+    private:
+        // need to keep track of threads so we can join them
+        std::vector< std::thread > workers_;
+        // the task queue
+        std::queue< std::packaged_task > tasks_;
+        std::list<std::shared_ptr<TaskQueue>> taskQueues_;
+        std::list<std::shared_ptr<TaskQueue>>::iterator tqit_;
+
+        // synchronization
+        std::mutex mtx_;
+        std::condition_variable cv_;
+        std::atomic<bool> stop_;
+    };
+
 }
diff --git a/daggy/src/Scheduler.cpp b/daggy/src/Scheduler.cpp
index 0260470..295f8a0 100644
--- a/daggy/src/Scheduler.cpp
+++ b/daggy/src/Scheduler.cpp
@@ -8,12 +8,12 @@ namespace daggy {
                          , size_t schedulerThreads)
         : executor_(executor)
         , schedulers_(schedulerThreads)
-        , executorWorkers_(executorThreads)
+        , executors_(executorThreads)
     {
     }
 
     Scheduler::~Scheduler() {
-        executorWorkers_.shutdown();
+        executors_.shutdown();
         schedulers_.shutdown();
     }
 
@@ -57,34 +57,45 @@ namespace daggy {
     void Scheduler::runDAG(const std::string & name, DAGRun & run) {
-        std::unordered_map>> tasks;
+        struct Task {
+            size_t tid;
+            std::future> fut;
+            bool complete;
+        };
+
+        std::vector<Task> tasks;
 
         std::cout << "Running dag " << name << std::endl;
         while (! run.dag.allVisited()) {
+
             // Check for any completed tasks
-            for (auto & [tid, fut] : tasks) {
-                std::cout << "Checking tid " << tid << std::endl;
-                if (fut.valid()) {
-                    auto ars = fut.get();
+            std::cout << "Polling completed" << std::endl;
+            for (auto & task : tasks) {
+                if (task.complete) continue;
+
+                if (task.fut.valid()) {
+                    std::cout << "Checking tid " << task.tid << std::endl;
+                    auto ars = task.fut.get();
+                    std::cout << "Got" << std::endl;
                     if (ars.back().rc == 0) {
-                        std::cout << "Completing node " << tid << std::endl;
-                        run.dag.completeVisit(tid);
+                        std::cout << "Completing node " << task.tid << std::endl;
+                        run.dag.completeVisit(task.tid);
                     }
+                    task.complete = true;
                 }
             }
 
             // Get the next dag to run
+            std::cout << "Polling scheduling" << std::endl;
             auto t = run.dag.visitNext();
             while (t.has_value()) {
                 std::cout << "Scheduling " << t.value() << std::endl;
                 // Schedule the task to run
-
-                std::packaged_task()> node([&]() {
-                    return runTask(run.tasks[t.value()]);
-                });
-
-                tasks.emplace(t.value(), node.get_future());
-                node();
+                Task tsk{ .tid = t.value()
+                        , .fut = executors_.addTask([&](){return runTask(run.tasks[t.value()]);})
+                        , .complete = false
+                        };
+                tasks.push_back(std::move(tsk));
 
                 // auto nt = run.dag.visitNext();
@@ -102,10 +113,8 @@ namespace daggy {
         std::vector attempts;
         while (attempts.size() < task.max_retries) {
-            auto fut = executorWorkers_.addTask([&]() {
-                attempts.push_back(executor_.runCommand(task.command));
-            });
-            fut.get();
+            std::cout << "Attempt " << attempts.size() << ": Running " << task.command.front() << std::endl;
+            attempts.push_back(executor_.runCommand(task.command));
             if (attempts.back().rc == 0) break;
         }
     }
 
     void Scheduler::drain() {
-        while (schedulers_.queueSize()) {
-            std::this_thread::sleep_for(250ms);
-        }
+        schedulers_.drain();
     }
 }
diff --git a/daggy/src/ThreadPool.cpp b/daggy/src/ThreadPool.cpp
deleted file mode 100644
index 6168d3d..0000000
--- a/daggy/src/ThreadPool.cpp
+++ /dev/null
@@ -1,55 +0,0 @@
-#include
-
-using namespace daggy;
-
-ThreadPool::ThreadPool(size_t nWorkers) {
-    shutdown_ = false;
-    std::lock_guard lk(guard_);
-    for (size_t i = 0; i < nWorkers; ++i) {
-        workers_.emplace_back([&]() {
-            while (true) {
-                QueuedAsyncTask tsk;
-                {
-                    std::unique_lock lk(guard_);
-                    cv_.wait(lk, []{ return true; });
-                    if (shutdown_) return;
-                    if (taskQueue_.empty()) continue;
-
-                    tsk = taskQueue_.front();
-                    taskQueue_.pop_front();
-                }
-
-                (*tsk)();
-            }
-        });
-    }
-}
-
-ThreadPool::~ThreadPool() {
-    shutdown();
-}
-
-void ThreadPool::shutdown() {
-    shutdown_ = true;
-    cv_.notify_all();
-
-    for (auto & w : workers_) {
-        if (w.joinable()) w.join();
-    }
-}
-
-std::future ThreadPool::addTask(std::function fn) {
-    auto task = std::make_shared>(fn);
-    std::future result = task->get_future();
-    {
-        std::unique_lock lk(guard_);
-        taskQueue_.push_back(task);
-    }
-    cv_.notify_one();
-    return result;
-}
-
-size_t ThreadPool::queueSize() {
-    std::unique_lock lk(guard_);
-    return taskQueue_.size();
-}
diff --git a/tests/unit_scheduler.cpp b/tests/unit_scheduler.cpp
new file mode 100644
index 0000000..780cf99
--- /dev/null
+++ b/tests/unit_scheduler.cpp
@@ -0,0 +1,23 @@
+#include
+#include
+
+#include "daggy/executors/ForkingExecutor.hpp"
+#include "daggy/Scheduler.hpp"
+
+#include "catch.hpp"
+
+TEST_CASE("Basic Scheduler Execution", "[scheduler]") {
+    daggy::executor::ForkingExecutor ex;
+    daggy::Scheduler sched(ex);
+
+    std::vector<daggy::Task> tasks {
+        daggy::Task{ "task_a", { "/usr/bin/echo", "task_a"}, 3, 30, { "task_c"} }
daggy::Task{ "task_b", { "/usr/bin/echo", "task_b"}, 3, 30, { "task_c" } } + , daggy::Task{ "task_c", { "/usr/bin/echo", "task_c"}, 3, 30, {} } + }; + + SECTION("Simple Run") { + sched.scheduleDAG("Simple", tasks, {}); + sched.drain(); + } +} diff --git a/tests/unit_threadpool.cpp b/tests/unit_threadpool.cpp index 7163009..a836fdd 100644 --- a/tests/unit_threadpool.cpp +++ b/tests/unit_threadpool.cpp @@ -11,21 +11,24 @@ TEST_CASE("Threadpool Construction", "[threadpool]") { std::atomic cnt(0); ThreadPool tp(10); - std::vector> res; + std::vector> rets; SECTION("Simple runs") { + auto tq = std::make_shared(); + std::vector> res; for (size_t i = 0; i < 100; ++i) - res.push_back(tp.addTask([&cnt]() { cnt++; return; })); + res.emplace_back(std::move(tq->addTask([&cnt]() { cnt++; return cnt.load(); }))); + tp.addTaskQueue(tq); for (auto & r : res) r.get(); REQUIRE(cnt == 100); } SECTION("Slow runs") { + std::vector> res; using namespace std::chrono_literals; for (size_t i = 0; i < 100; ++i) res.push_back(tp.addTask([&cnt]() { std::this_thread::sleep_for(20ms); cnt++; return; })); for (auto & r : res) r.get(); REQUIRE(cnt == 100); } - }