Checkpointing work

This commit is contained in:
Ian Roddis
2021-07-05 11:57:38 -03:00
parent b7b8d5b6a1
commit 468993edb5
11 changed files with 125 additions and 51 deletions

View File

@@ -28,6 +28,7 @@ namespace daggy {
class Executor { class Executor {
public: public:
Executor() = default;
virtual const std::string getName() const = 0; virtual const std::string getName() const = 0;
// This will block if the executor is full // This will block if the executor is full

View File

@@ -24,12 +24,12 @@ namespace daggy {
}; };
public: public:
Scheduler(size_t schedulerThreads = 10); Scheduler(
Executor & executor
, size_t executorThreads = 30
, size_t schedulerThreads = 10);
// Register an executor ~Scheduler();
void registerExecutor(std::shared_ptr<Executor> executor
, size_t maxParallelTasks
);
// returns DagRun ID // returns DagRun ID
void scheduleDAG(std::string runName void scheduleDAG(std::string runName
@@ -44,17 +44,11 @@ namespace daggy {
// get the current DAG // get the current DAG
DAG dagRunState(); DAG dagRunState();
// Complete running DAGs and shutdown
void drain();
private: private:
struct ExecutionPool {
std::shared_ptr<Executor> executor;
ThreadPool workers;
// taskid -> result
std::unordered_map<std::string, std::future<void>> jobs;
};
struct DAGRun { struct DAGRun {
std::vector<Task> tasks; std::vector<Task> tasks;
std::unordered_map<std::string, ParameterValue> parameters; std::unordered_map<std::string, ParameterValue> parameters;
@@ -63,11 +57,14 @@ namespace daggy {
std::mutex taskGuard_; std::mutex taskGuard_;
}; };
void runDAG(DAGRun & dagRun); void runDAG(const std::string & name, DAGRun & dagRun);
std::vector<AttemptRecord> runTask(const Task & task);
std::unordered_map<std::string, ExecutionPool> executorPools_;
std::unordered_map<std::string, DAGRun> runs_; std::unordered_map<std::string, DAGRun> runs_;
Executor & executor_;
ThreadPool schedulers_; ThreadPool schedulers_;
ThreadPool executorWorkers_;
std::unordered_map<std::string, std::future<void>> jobs;
std::mutex mtx_; std::mutex mtx_;
std::condition_variable cv_; std::condition_variable cv_;
}; };

View File

@@ -4,7 +4,7 @@
#include <pistache/endpoint.h> #include <pistache/endpoint.h>
#include <pistache/http.h> #include <pistache/http.h>
#include <pistache/thirdparty/serializer/rapidjson.h> // #include <pistache/thirdparty/serializer/rapidjson.h>
namespace daggy { namespace daggy {
class Server { class Server {

View File

@@ -7,7 +7,7 @@
namespace daggy { namespace daggy {
struct Task { struct Task {
std::string name; std::string name;
std::string command; std::vector<std::string> command;
uint8_t max_retries; uint8_t max_retries;
uint32_t retry_interval_seconds; // Time to wait between retries uint32_t retry_interval_seconds; // Time to wait between retries
std::vector<std::string> children; std::vector<std::string> children;

View File

@@ -29,6 +29,9 @@ namespace daggy {
void shutdown(); void shutdown();
std::future<void> addTask(AsyncTask fn); std::future<void> addTask(AsyncTask fn);
size_t queueSize();
private: private:
using QueuedAsyncTask = std::shared_ptr<std::packaged_task<void()>>; using QueuedAsyncTask = std::shared_ptr<std::packaged_task<void()>>;

View File

@@ -5,8 +5,9 @@
namespace daggy { namespace daggy {
namespace executor { namespace executor {
class ForkingExecutor : Executor { class ForkingExecutor : public Executor {
public: public:
ForkingExecutor() = default;
const std::string getName() const override { return "ForkingExecutor"; } const std::string getName() const override { return "ForkingExecutor"; }
AttemptRecord runCommand(std::vector<std::string> cmd) override; AttemptRecord runCommand(std::vector<std::string> cmd) override;

View File

@@ -1,4 +1,5 @@
#include <daggy/DAG.hpp> #include <daggy/DAG.hpp>
#include <stdexcept>
namespace daggy { namespace daggy {
size_t DAG::size() const { return vertices_.size(); } size_t DAG::size() const { return vertices_.size(); }
@@ -6,22 +7,26 @@ namespace daggy {
size_t DAG::addVertex() { size_t DAG::addVertex() {
vertices_.push_back(Vertex{.state = VertexState::UNVISITED, .depCount = 0}); vertices_.push_back(Vertex{.state = VertexState::UNVISITED, .depCount = 0});
return vertices_.size(); return vertices_.size() - 1;
} }
void DAG::dropEdge(const size_t from, const size_t to) { void DAG::dropEdge(const size_t from, const size_t to) {
if (from >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(from));
if (to >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(to));
vertices_[from].children.extract(to); vertices_[from].children.extract(to);
} }
void DAG::addEdge(const size_t from, const size_t to) { void DAG::addEdge(const size_t from, const size_t to) {
if (from >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(from));
if (to >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(to));
if (hasPath(to, from)) if (hasPath(to, from))
throw std::runtime_error("Adding edge would result in a cycle"); throw std::runtime_error("Adding edge would result in a cycle");
vertices_[from].children.insert(to); vertices_[from].children.insert(to);
} }
bool DAG::hasPath(const size_t from, const size_t to) const { bool DAG::hasPath(const size_t from, const size_t to) const {
bool pathFound = false; if (from >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(from));
if (to >= vertices_.size()) throw std::runtime_error("No such vertex " + std::to_string(to));
for (const auto & child : vertices_[from].children) { for (const auto & child : vertices_[from].children) {
if (child == to) return true; if (child == to) return true;
if (hasPath(child, to)) return true; if (hasPath(child, to)) return true;

View File

@@ -1,17 +1,20 @@
#include <daggy/Scheduler.hpp> #include <daggy/Scheduler.hpp>
using namespace std::chrono_literals;
namespace daggy { namespace daggy {
Scheduler::Scheduler(size_t schedulerThreads = 10) Scheduler::Scheduler(Executor & executor
: schedulers_(schedulerThreads) , size_t executorThreads
, size_t schedulerThreads)
: executor_(executor)
, schedulers_(schedulerThreads)
, executorWorkers_(executorThreads)
{ } { }
void Scheduler::registerExecutor(std::shared_ptr<Executor> executor, size_t maxParallelTasks) {
executorPools_.emplace(executor->getName() Scheduler::~Scheduler() {
, ExecutionPool{ executorWorkers_.shutdown();
.executor = executor schedulers_.shutdown();
, .workers = ThreadPool{maxParallelTasks}
, .jobs = {}
});
} }
void Scheduler::scheduleDAG(std::string runName void Scheduler::scheduleDAG(std::string runName
@@ -39,33 +42,79 @@ namespace daggy {
} }
// Create the DAGRun // Create the DAGRun
DAGRun run{
.tasks = tasks
, .parameters = parameters
, .dag = dag
, .taskRuns = TaskRun(tasks.size())
};
{ {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
runs_.emplace(runName, std::move(run));
auto & dr = runs_[runName]; auto & dr = runs_[runName];
schedulers_.addTask([&]() { runDAG(dr); });
dr.tasks = tasks;
dr.parameters = parameters;
dr.dag = dag;
dr.taskRuns = std::vector<TaskRun>{tasks.size()};
schedulers_.addTask([&]() { runDAG(runName, dr); });
} }
} }
void Scheduler::runDAG(DAGRun & run) void Scheduler::runDAG(const std::string & name, DAGRun & run)
{ {
using namespace std::chrono_literals; std::unordered_map<size_t, std::future<std::vector<AttemptRecord>>> tasks;
std::cout << "Running dag " << name << std::endl;
while (! run.dag.allVisited()) { while (! run.dag.allVisited()) {
// Check for any completed tasks // Check for any completed tasks
for (auto & [tid, fut] : tasks) {
std::cout << "Checking tid " << tid << std::endl;
if (fut.valid()) {
auto ars = fut.get();
if (ars.back().rc == 0) {
std::cout << "Completing node " << tid << std::endl;
run.dag.completeVisit(tid);
}
}
}
// Get the next dag to run
auto t = run.dag.visitNext(); auto t = run.dag.visitNext();
if (! t.has_value()) { while (t.has_value()) {
std::cout << "Scheduling " << t.value() << std::endl;
// Schedule the task to run
std::packaged_task<std::vector<AttemptRecord>()> node([&]() {
return runTask(run.tasks[t.value()]);
});
tasks.emplace(t.value(), node.get_future());
node();
//
auto nt = run.dag.visitNext();
if (not nt.has_value()) break;
t.emplace(nt.value());
}
std::cout << "sleeping" << std::endl;
std::this_thread::sleep_for(250ms);
}
}
/// Run `task` through the executor until it succeeds or the attempt budget
/// is used up.
///
/// Every attempt is queued on executorWorkers_ and waited on with
/// future::get(), so the calling thread blocks for the duration of each
/// attempt and `attempts` is only touched by one thread at a time.
///
/// @param task  Task whose `command` is passed to executor_.runCommand().
/// @return One AttemptRecord per attempt, in execution order. At least one
///         attempt is always made, so the result is non-empty whenever this
///         function returns normally.
std::vector<AttemptRecord>
Scheduler::runTask(const Task & task) {
    std::vector<AttemptRecord> attempts;

    // do/while rather than while: with max_retries == 0 the previous loop
    // never ran the command and handed back an empty vector, on which the
    // caller's `.back()` is undefined behaviour.
    // NOTE(review): for max_retries >= 1 the cap is still max_retries total
    // attempts, as before — confirm whether 1 + max_retries was intended.
    do {
        auto fut = executorWorkers_.addTask([&]() {
            attempts.push_back(executor_.runCommand(task.command));
        });
        // Wait for the worker to finish before inspecting `attempts`.
        fut.get();
    } while (attempts.back().rc != 0 && attempts.size() < task.max_retries);

    return attempts;
}
void Scheduler::drain() {
while (schedulers_.queueSize()) {
std::this_thread::sleep_for(250ms); std::this_thread::sleep_for(250ms);
continue;
}
} }
} }
} }

View File

@@ -34,7 +34,7 @@ void ThreadPool::shutdown() {
cv_.notify_all(); cv_.notify_all();
for (auto & w : workers_) { for (auto & w : workers_) {
w.join(); if (w.joinable()) w.join();
} }
} }
@@ -48,3 +48,8 @@ std::future<void> ThreadPool::addTask(std::function<void()> fn) {
cv_.notify_one(); cv_.notify_one();
return result; return result;
} }
/// Report how many tasks are sitting in the pool's queue at the moment of
/// the call. The queue is read while guard_ is held.
size_t ThreadPool::queueSize() {
    const std::lock_guard<std::mutex> lock(guard_);
    const size_t pending = taskQueue_.size();
    return pending;
}

View File

@@ -33,7 +33,6 @@ std::string slurp(int fd) {
return result; return result;
} }
daggy::AttemptRecord daggy::AttemptRecord
ForkingExecutor::runCommand(std::vector<std::string> cmd) ForkingExecutor::runCommand(std::vector<std::string> cmd)
{ {

View File

@@ -21,6 +21,20 @@ TEST_CASE("DAG Construction Tests", "[dag]") {
// Cannot add an edge that would result in a cycle // Cannot add an edge that would result in a cycle
REQUIRE_THROWS(dag.addEdge(9, 5)); REQUIRE_THROWS(dag.addEdge(9, 5));
// Bounds checking: addEdge, dropEdge and hasPath must each throw when either
// endpoint is not a valid vertex index.
// NOTE(review): 20 is presumably past the last vertex created by the setup
// above this hunk — confirm against the start of the test case.
SECTION("addEdge Bounds Checking") {
// Invalid `from` first, then invalid `to`.
REQUIRE_THROWS(dag.addEdge(20, 0));
REQUIRE_THROWS(dag.addEdge(0, 20));
}
SECTION("dropEdge Bounds Checking") {
// Invalid `from` first, then invalid `to`.
REQUIRE_THROWS(dag.dropEdge(20, 0));
REQUIRE_THROWS(dag.dropEdge(0, 20));
}
SECTION("hasPath Bounds Checking") {
// Invalid `from` first, then invalid `to`.
REQUIRE_THROWS(dag.hasPath(20, 0));
REQUIRE_THROWS(dag.hasPath(0, 20));
}
} }
TEST_CASE("DAG Traversal Tests", "[dag]") { TEST_CASE("DAG Traversal Tests", "[dag]") {