diff --git a/daggy/include/daggy/Scheduler.hpp b/daggy/include/daggy/Scheduler.hpp
index 52fada7..f81f79e 100644
--- a/daggy/include/daggy/Scheduler.hpp
+++ b/daggy/include/daggy/Scheduler.hpp
@@ -32,7 +32,8 @@ namespace daggy {
         ~Scheduler();
 
         // returns DagRun ID
-        void scheduleDAG(std::string runName
+        std::future<void>
+        scheduleDAG(std::string runName
                        , std::vector<Task> tasks
                        , std::unordered_map parameters
                        , DAG dag = {} // Allows for loading of an existing DAG
@@ -61,6 +62,7 @@ namespace daggy {
         std::vector runTask(const Task & task);
 
         std::unordered_map<std::string, DAGRun> runs_;
+        std::vector<std::future<void>> futs_;
         Executor & executor_;
         ThreadPool schedulers_;
         ThreadPool executors_;
diff --git a/daggy/include/daggy/ThreadPool.hpp b/daggy/include/daggy/ThreadPool.hpp
index 9dd434e..389d6db 100644
--- a/daggy/include/daggy/ThreadPool.hpp
+++ b/daggy/include/daggy/ThreadPool.hpp
@@ -10,6 +10,8 @@
 #include
 #include
 
+using namespace std::chrono_literals;
+
 namespace daggy {
 
   /*
@@ -66,6 +68,7 @@ namespace daggy {
      : tqit_(taskQueues_.begin())
      , stop_(false)
+     , drain_(false)
    {
      resize(nWorkers);
    }
 
@@ -82,35 +85,45 @@ namespace daggy {
    }
 
    void drain() {
-     resize(workers_.size());
+     drain_ = true;
+     while (true) {
+       {
+         std::lock_guard<std::mutex> guard(mtx_);
+         if (taskQueues_.empty()) break;
+       }
+       std::this_thread::sleep_for(250ms);
+     }
+   }
+
+   void restart() {
+     drain_ = false;
    }
 
    void resize(size_t nWorkers) {
      shutdown();
      workers_.clear();
+     stop_ = false;
      for(size_t i = 0;i< nWorkers;++i)
        workers_.emplace_back(
          [&] {
-          for(;;) {
+          while (true) {
            std::packaged_task task;
            {
              std::unique_lock<std::mutex> lock(mtx_);
-             cv_.wait(lock, [&]{ return stop_ || !taskQueues_.empty(); });
+             cv_.wait(lock, [&]{ return stop_ || ! taskQueues_.empty(); });
              if(taskQueues_.empty()) {
                if(stop_) return;
                continue;
              }
              if (tqit_ == taskQueues_.end()) tqit_ = taskQueues_.begin();
-             if (not (*tqit_)->empty()) {
-               task = std::move((*tqit_)->pop());
-               if ((*tqit_)->empty()) {
-                 tqit_ = taskQueues_.erase(tqit_);
-               } else {
-                 tqit_++;
-               }
+             task = std::move((*tqit_)->pop());
+             if ((*tqit_)->empty()) {
+               tqit_ = taskQueues_.erase(tqit_);
+             } else {
+               tqit_++;
              }
            }
-             task();
+           task();
          }
        }
      );
@@ -118,6 +131,7 @@ namespace daggy {
 
    template <typename F, typename... Args>
    decltype(auto) addTask(F&& f, Args&&... args) {
+     if (drain_) throw std::runtime_error("Unable to add task to draining pool");
      auto tq = std::make_shared();
 
      auto fut = tq->addTask(f, args...);
@@ -130,7 +144,8 @@ namespace daggy {
      return fut;
    }
 
-   void addTaskQueue(std::shared_ptr tq) {
+   void addTasks(std::shared_ptr tq) {
+     if (drain_) throw std::runtime_error("Unable to add task to draining pool");
      std::lock_guard<std::mutex> guard(mtx_);
      taskQueues_.push_back(tq);
      cv_.notify_one();
@@ -140,7 +155,6 @@ namespace daggy {
    // need to keep track of threads so we can join them
    std::vector< std::thread > workers_;
    // the task queue
-   std::queue< std::packaged_task > tasks_;
    std::list> taskQueues_;
    std::list>::iterator tqit_;
 
@@ -148,6 +162,7 @@ namespace daggy {
    std::mutex mtx_;
    std::condition_variable cv_;
    std::atomic<bool> stop_;
+   std::atomic<bool> drain_;
 
  };
 }
diff --git a/daggy/src/Scheduler.cpp b/daggy/src/Scheduler.cpp
index 295f8a0..aba1bc6 100644
--- a/daggy/src/Scheduler.cpp
+++ b/daggy/src/Scheduler.cpp
@@ -17,7 +17,8 @@ namespace daggy {
        schedulers_.shutdown();
    }
 
-   void Scheduler::scheduleDAG(std::string runName
+   std::future<void>
+   Scheduler::scheduleDAG(std::string runName
                              , std::vector<Task> tasks
                              , std::unordered_map parameters
                              , DAG dag
@@ -42,17 +43,16 @@ namespace daggy {
        }
 
        // Create the DAGRun
-       {
-           std::lock_guard<std::mutex> guard(mtx_);
-           auto & dr = runs_[runName];
+       std::lock_guard<std::mutex> guard(mtx_);
+       auto & dr = runs_[runName];
 
-           dr.tasks = tasks;
-           dr.parameters = parameters;
-           dr.dag = dag;
-           dr.taskRuns = std::vector{tasks.size()};
+       dr.tasks = tasks;
+       dr.parameters = parameters;
+       dr.dag = dag;
+       dr.taskRuns = std::vector{tasks.size()};
 
-           schedulers_.addTask([&]() { runDAG(runName, dr); });
-       }
+       // return std::move(schedulers_.addTask([&]() { runDAG(runName, dr); }));
+       return std::move(schedulers_.addTask([&]() { runDAG(runName, dr); }));
    }
 
    void Scheduler::runDAG(const std::string & name, DAGRun & run)
@@ -65,20 +65,15 @@ namespace daggy {
 
        std::vector tasks;
 
-       std::cout << "Running dag " << name << std::endl;
 
       while (! run.dag.allVisited()) {
           // Check for any completed tasks
-          std::cout << "Polling completed" << std::endl;
           for (auto & task : tasks) {
               if (task.complete) continue;
               if (task.fut.valid()) {
-                  std::cout << "Checking tid " << task.tid << std::endl;
                   auto ars = task.fut.get();
-                  std::cout << "Got" << std::endl;
 
                   if (ars.back().rc == 0) {
-                      std::cout << "Completing node " << task.tid << std::endl;
                       run.dag.completeVisit(task.tid);
                   }
                   task.complete = true;
@@ -86,10 +81,8 @@ namespace daggy {
           }
 
           // Get the next dag to run
-          std::cout << "Polling scheduling" << std::endl;
           auto t = run.dag.visitNext();
           while (t.has_value()) {
-              std::cout << "Scheduling " << t.value() << std::endl;
               // Schedule the task to run
               Task tsk{ .tid = t.value()
                       , .fut = executors_.addTask([&](){return runTask(run.tasks[t.value()]);})
@@ -103,7 +96,6 @@ namespace daggy {
               t.emplace(nt.value());
           }
 
-          std::cout << "sleeping" << std::endl;
           std::this_thread::sleep_for(250ms);
       }
   }
@@ -113,7 +105,6 @@ namespace daggy {
 
       std::vector attempts;
       while (attempts.size() < task.max_retries) {
-          std::cout << "Attempt " << attempts.size() << ": Running " << task.command.front() << std::endl;
           attempts.push_back(executor_.runCommand(task.command));
           if (attempts.back().rc == 0) break;
       }
diff --git a/tests/unit_scheduler.cpp b/tests/unit_scheduler.cpp
index 780cf99..c2aa450 100644
--- a/tests/unit_scheduler.cpp
+++ b/tests/unit_scheduler.cpp
@@ -17,7 +17,7 @@ TEST_CASE("Basic Scheduler Execution", "[scheduler]") {
    };
 
    SECTION("Simple Run") {
-       sched.scheduleDAG("Simple", tasks, {});
-       sched.drain();
+       auto fut = sched.scheduleDAG("Simple", tasks, {});
+       fut.get();
    }
 }
diff --git a/tests/unit_threadpool.cpp b/tests/unit_threadpool.cpp
index a836fdd..0c64a26 100644
--- a/tests/unit_threadpool.cpp
+++ b/tests/unit_threadpool.cpp
@@ -13,12 +13,12 @@ TEST_CASE("Threadpool Construction", "[threadpool]") {
 
    std::vector> rets;
 
-   SECTION("Simple runs") {
+   SECTION("Adding large tasks queues with return values") {
        auto tq = std::make_shared();
        std::vector> res;
        for (size_t i = 0; i < 100; ++i)
            res.emplace_back(std::move(tq->addTask([&cnt]() { cnt++; return cnt.load(); })));
-       tp.addTaskQueue(tq);
+       tp.addTasks(tq);
        for (auto & r : res) r.get();
        REQUIRE(cnt == 100);
    }
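
For reference, below is a minimal, self-contained sketch of the drain-flag pattern these ThreadPool changes introduce: submissions are rejected while the pool is draining, drain() polls until the queue is empty, and the caller blocks on a returned std::future (as the updated unit_scheduler test now does with scheduleDAG). The class and member names here (MiniPool, tasks_, work()) are illustrative only, not the daggy API, and, like the patch, drain() only waits for queued work to be picked up, not for a task that is already executing to finish.

// Illustrative sketch only -- not the daggy API.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>

using namespace std::chrono_literals;

class MiniPool {
public:
    explicit MiniPool(std::size_t n) {
        for (std::size_t i = 0; i < n; ++i)
            workers_.emplace_back([this] { work(); });
    }

    ~MiniPool() {
        {
            std::lock_guard<std::mutex> guard(mtx_);
            stop_ = true;
        }
        cv_.notify_all();
        for (auto & w : workers_) w.join();
    }

    // Reject new work while draining, mirroring the check added to addTask/addTasks.
    std::future<void> addTask(std::function<void()> f) {
        if (drain_) throw std::runtime_error("Unable to add task to draining pool");
        std::packaged_task<void()> task(std::move(f));
        auto fut = task.get_future();
        {
            std::lock_guard<std::mutex> guard(mtx_);
            tasks_.push(std::move(task));
        }
        cv_.notify_one();
        return fut;
    }

    // Stop accepting work, then poll until the queue has been consumed.
    void drain() {
        drain_ = true;
        while (true) {
            {
                std::lock_guard<std::mutex> guard(mtx_);
                if (tasks_.empty()) break;
            }
            std::this_thread::sleep_for(250ms);
        }
    }

    void restart() { drain_ = false; }

private:
    void work() {
        while (true) {
            std::packaged_task<void()> task;
            {
                std::unique_lock<std::mutex> lock(mtx_);
                cv_.wait(lock, [&] { return stop_ || !tasks_.empty(); });
                if (tasks_.empty()) {
                    if (stop_) return;
                    continue;
                }
                task = std::move(tasks_.front());
                tasks_.pop();
            }
            task();  // run the task outside the lock
        }
    }

    std::vector<std::thread> workers_;
    std::queue<std::packaged_task<void()>> tasks_;
    std::mutex mtx_;
    std::condition_variable cv_;
    std::atomic<bool> stop_{false};
    std::atomic<bool> drain_{false};
};

int main() {
    MiniPool pool(2);
    auto fut = pool.addTask([] { std::this_thread::sleep_for(50ms); });
    pool.drain();  // returns once the queue is empty; further addTask calls throw
    fut.get();     // wait on the future, as the updated scheduler test does
    return 0;
}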