diff --git a/daggy/include/daggy/Scheduler.hpp b/daggy/include/daggy/Scheduler.hpp index 80ff82c..5be9bc7 100644 --- a/daggy/include/daggy/Scheduler.hpp +++ b/daggy/include/daggy/Scheduler.hpp @@ -7,24 +7,68 @@ #include "DAG.hpp" #include "Executor.hpp" +#include "ThreadPool.hpp" namespace daggy { using ParameterValue = std::variant>; + using TaskRun = std::vector; class Scheduler { public: + enum class DAGState : uint32_t { + UNKNOWN = 0, + QUEUED, + RUNNING, + ERRORED, + COMPLETE + }; + + public: + Scheduler(size_t schedulerThreads = 10); + // Register an executor void registerExecutor(std::shared_ptr executor , size_t maxParallelTasks ); - void runTasks(std::vector tasks + + // returns DagRun ID + void scheduleDAG(std::string runName + , std::vector tasks , std::unordered_map parameters , DAG dag = {} // Allows for loading of an existing DAG ); + // get the current status of a DAG + DAGState dagRunStatus(std::string runName); + + // get the current DAG + DAG dagRunState(); + private: - std::unordered_map> executors_; - std::unordered_map>> jobs_; - std::unordered_map parameters_; + + + struct ExecutionPool { + std::shared_ptr executor; + ThreadPool workers; + + // taskid -> result + std::unordered_map> jobs; + }; + + struct DAGRun { + std::vector tasks; + std::unordered_map parameters; + DAG dag; + std::vector taskRuns; + std::mutex taskGuard_; + }; + + void runDAG(DAGRun & dagRun); + + std::unordered_map executorPools_; + std::unordered_map runs_; + ThreadPool schedulers_; + std::mutex mtx_; + std::condition_variable cv_; }; } diff --git a/daggy/src/Scheduler.cpp b/daggy/src/Scheduler.cpp index 0490484..8e01e1b 100644 --- a/daggy/src/Scheduler.cpp +++ b/daggy/src/Scheduler.cpp @@ -1,7 +1,71 @@ #include namespace daggy { + Scheduler::Scheduler(size_t schedulerThreads = 10) + : schedulers_(schedulerThreads) + { } + void Scheduler::registerExecutor(std::shared_ptr executor, size_t maxParallelTasks) { - executors_.emplace(executor->getName(), executor); + executorPools_.emplace(executor->getName() + , ExecutionPool{ + .executor = executor + , .workers = ThreadPool{maxParallelTasks} + , .jobs = {} + }); + } + + void Scheduler::scheduleDAG(std::string runName + , std::vector tasks + , std::unordered_map parameters + , DAG dag + ) + { + // Initialize the dag + if (dag.empty()) { + std::unordered_map tids; + + // Add all the vertices + for (size_t i = 0; i < tasks.size(); ++i) { + tids[tasks[i].name] = dag.addVertex(); + } + + // Add edges + for (size_t i = 0; i < tasks.size(); ++i) { + for (const auto & c : tasks[i].children) { + dag.addEdge(i, tids[c]); + } + } + dag.reset(); + } + + // Create the DAGRun + DAGRun run{ + .tasks = tasks + , .parameters = parameters + , .dag = dag + , .taskRuns = TaskRun(tasks.size()) + }; + + { + std::lock_guard guard(mtx_); + runs_.emplace(runName, std::move(run)); + auto & dr = runs_[runName]; + schedulers_.addTask([&]() { runDAG(dr); }); + } + } + + void Scheduler::runDAG(DAGRun & run) + { + using namespace std::chrono_literals; + + while (! run.dag.allVisited()) { + // Check for any completed tasks + + auto t = run.dag.visitNext(); + if (! t.has_value()) { + std::this_thread::sleep_for(250ms); + continue; + } + } } }