Files
daggy/libdaggy/src/DAGRunner.cpp
2022-02-22 10:32:31 -04:00

238 lines
6.7 KiB
C++

#include <chrono>
#include <daggy/DAGRunner.hpp>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <utility>
namespace daggy {
/**
 * Construct a runner for a single DAG run.
 *
 * @param runID      Identifier of this run; passed to every executor and
 *                   logger call.
 * @param executor   Executor used to launch and stop tasks. Received by
 *                   reference; presumably it must outlive the runner (confirm
 *                   against the member declaration in the header).
 * @param logger     Logger receiving run/task state updates. Same lifetime
 *                   caveat as `executor`.
 * @param dag        Task graph to execute. Taken by value and moved into the
 *                   runner so callers passing an rvalue avoid a copy.
 * @param taskParams Job defaults and variables used when expanding the output
 *                   of generator tasks.
 *
 * running_ and kill_ start out true; run() resets both before its loop.
 */
DAGRunner::DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor,
                     loggers::dag_run::DAGRunLogger &logger, TaskDAG dag,
                     const TaskParameters &taskParams)
  : runID_(runID)
  , executor_(executor)
  , logger_(logger)
  , dag_(std::move(dag))
  , taskParams_(taskParams)
  , running_(true)
  , kill_(true)
  , nRunningTasks_(0)
  , nErroredTasks_(0)
{
}
/**
 * Destructor. Acquires and immediately releases runGuard_, so destruction
 * waits until whichever code currently holds the guard (for example an
 * in-progress iteration of run()) has let go of it.
 */
DAGRunner::~DAGRunner()
{
  const std::scoped_lock waitForGuard{runGuard_};
}
/**
 * Drive the DAG until it is fully visited, the run is stopped, or every task
 * still counted as running has errored.
 *
 * Each iteration runs under runGuard_. It:
 *  - asks the executor to stop running tasks, if a kill was requested;
 *  - collects finished tasks;
 *  - submits newly-ready tasks.
 * It then sleeps 100ms outside the lock.
 *
 * Run-state logging:
 *  - RUNNING is logged at start.
 *  - KILLED is logged when the run was stopped and no non-errored task
 *    remains counted as running.
 *  - ERRORED is logged when every task counted as running has errored.
 *  - COMPLETED is logged whenever the DAG ends up fully visited.
 *
 * @return A copy of the DAG in its final state.
 */
TaskDAG DAGRunner::run()
{
  kill_ = false;
  running_ = true;
  logger_.updateDAGRunState(runID_, RunState::RUNNING);
  bool allVisited;
  {
    std::lock_guard<std::mutex> lock(runGuard_);
    allVisited = dag_.allVisited();
  }
  while (!allVisited) {
    {
      std::lock_guard<std::mutex> runLock(runGuard_);
      if (!running_ and kill_) {
        killRunning();
      }
      collectFinished();
      queuePending();
      // Equivalent to `nRunningTasks_ - nErroredTasks_ <= 0`, but written as a
      // comparison so it cannot wrap around if the counters are unsigned.
      if (!running_ and nRunningTasks_ <= nErroredTasks_) {
        logger_.updateDAGRunState(runID_, RunState::KILLED);
        break;
      }
      if (nRunningTasks_ > 0 and nErroredTasks_ == nRunningTasks_) {
        logger_.updateDAGRunState(runID_, RunState::ERRORED);
        break;
      }
    }
    std::this_thread::sleep_for(100ms);
    {
      std::lock_guard<std::mutex> lock(runGuard_);
      allVisited = dag_.allVisited();
    }
  }
  // Final inspection and copy of dag_ under the same guard as every other
  // access in this function.
  std::lock_guard<std::mutex> lock(runGuard_);
  if (dag_.allVisited()) {
    logger_.updateDAGRunState(runID_, RunState::COMPLETED);
  }
  running_ = false;
  return dag_;
}
/**
 * Discard all per-run bookkeeping so the DAG can be run again. This clears:
 *  - tracked task futures;
 *  - attempt counters;
 *  - the running and errored tallies.
 * It then calls dag_.resetRunning().
 *
 * @throws std::runtime_error while running_ is set. Because the constructor
 *         initializes running_ to true, this includes calls made before the
 *         first run()/stop().
 */
void DAGRunner::resetRunning()
{
  if (!running_) {
    std::scoped_lock guard{runGuard_};
    runningTasks_.clear();
    taskAttemptCounts_.clear();
    nErroredTasks_ = 0;
    nRunningTasks_ = 0;
    dag_.resetRunning();
    return;
  }
  throw std::runtime_error("Unable to reset while DAG is running.");
}
/**
 * Ask the executor to stop every task that still has an entry in
 * runningTasks_. run() calls this, while holding runGuard_, on each
 * iteration after a kill has been requested.
 */
void DAGRunner::killRunning()
{
  for (const auto &entry : runningTasks_) {
    const auto &name = entry.first;
    executor_.stop(runID_, name);
  }
}
void DAGRunner::stopTask(const std::string &taskName)
{
executor_.stop(runID_, taskName);
}
void DAGRunner::queuePending()
{
if (!running_)
return;
const size_t MAX_SUBMITS = 100;
size_t n_submitted = 0;
/*
In cases where there are many tasks ready to execute,
the blocking nature of executor_.execute(...) means
that tasks will get executed and the futures resolved,
but we won't know until all pending tasks are enqueued.
To avoid this, submit at most MAX_SUBMITS tasks before
returning to allow completed tasks to be updated.
*/
while (n_submitted < MAX_SUBMITS) {
auto t = dag_.visitNext();
if (!t)
break;
auto &taskName = t.value().first;
taskAttemptCounts_[taskName] = 1;
logger_.updateTaskState(runID_, taskName, RunState::RUNNING);
try {
auto &task = t.value().second;
auto fut = executor_.execute(runID_, taskName, task);
runningTasks_.emplace(taskName, std::move(fut));
}
catch (std::exception &e) {
std::cout << "Unable to execute task: " << e.what() << std::endl;
}
++nRunningTasks_;
++n_submitted;
}
}
/**
 * Process every tracked task whose future has resolved. The attempt is
 * logged, then one of the following applies.
 *
 * rc == 0:
 *  - The task is logged COMPLETED.
 *  - Generator tasks additionally have their output parsed as JSON task
 *    definitions, expanded, and spliced into the DAG as children of the
 *    generator.
 *  - The DAG visit is completed and the task leaves nRunningTasks_.
 *
 * rc == 0 but the generator output cannot be parsed or expanded:
 *  - The task is logged ERRORED and counted in nErroredTasks_.
 *  - Its visit is NOT completed (so its children do not run), and it stays
 *    counted in nRunningTasks_. This is the same accounting as a task that
 *    is out of retries.
 *
 * rc != 0 with retries left:
 *  - The task is logged RETRY and re-submitted.
 *  - It stays tracked.
 *
 * rc != 0 with no retries left:
 *  - If the logger still reports RUNNING/RETRY, the task is logged ERRORED
 *    and counted in nErroredTasks_.
 *  - Otherwise it is treated as killed and leaves nRunningTasks_.
 *
 * Every task that is not re-submitted is removed from runningTasks_, so each
 * resolved future is processed exactly once.
 */
void DAGRunner::collectFinished()
{
  // Tasks whose futures are spent. Erasure is deferred until after the loop
  // so the range-for over runningTasks_ stays valid.
  std::vector<std::string> finishedTasks;
  for (auto &[taskName, fut] : runningTasks_) {
    if (!fut->ready())
      continue;

    auto attempt = fut->get();
    logger_.logTaskAttempt(runID_, taskName, attempt);
    // Not a reference, since adding tasks will invalidate references
    auto vert = dag_.getVertex(taskName);
    auto &task = vert.data;

    if (attempt.rc == 0) {
      finishedTasks.push_back(taskName);
      logger_.updateTaskState(runID_, taskName, RunState::COMPLETED);
      bool generatorFailed = false;
      if (task.isGenerator) {
        // Parse the output and update the DAGs
        try {
          auto parsedTasks =
            tasksFromJSON(attempt.outputLog, taskParams_.jobDefaults);
          auto newTasks =
            expandTaskSet(parsedTasks, executor_, taskParams_.variables);
          updateDAGFromTasks(dag_, newTasks);
          // Add in dependencies from current task to new tasks
          for (const auto &[ntName, ntTask] : newTasks) {
            logger_.addTask(runID_, ntName, ntTask);
            task.children.insert(ntName);
          }
          // Efficiently add new edges from generator task
          // to children
          std::unordered_set<std::string> baseNames;
          for (const auto &[k, v] : parsedTasks) {
            baseNames.insert(v.definedName);
          }
          dag_.addEdgeIf(taskName, [&](const auto &v) {
            return baseNames.count(v.data.definedName) > 0;
          });
          logger_.updateTask(runID_, taskName, task);
        }
        catch (const std::exception &e) {
          logger_.logTaskAttempt(
            runID_, taskName,
            AttemptRecord{.executorLog =
                            std::string{"Failed to parse JSON output: "} +
                            e.what()});
          logger_.updateTaskState(runID_, taskName, RunState::ERRORED);
          ++nErroredTasks_;
          generatorFailed = true;
        }
      }
      // An errored generator must not release its children, and it stays in
      // nRunningTasks_ so that run() can see nErroredTasks_ == nRunningTasks_.
      if (!generatorFailed) {
        dag_.completeVisit(taskName);
        --nRunningTasks_;
      }
      continue;
    }

    // RC isn't 0: re-submit while retries remain.
    if (taskAttemptCounts_[taskName] <= task.maxRetries) {
      logger_.updateTaskState(runID_, taskName, RunState::RETRY);
      // Replace the spent future in place; assigning to an existing mapped
      // value does not invalidate the iteration.
      fut = executor_.execute(runID_, taskName, task);
      ++taskAttemptCounts_[taskName];
      continue;
    }

    // Out of retries. Stop tracking the spent future. Otherwise, if it keeps
    // reporting ready, later calls would re-log the attempt and fall into the
    // "killed" branch below, decrementing nRunningTasks_ again each time.
    finishedTasks.push_back(taskName);
    const auto state = logger_.getTaskState(runID_, taskName);
    if (state == +RunState::RUNNING or state == +RunState::RETRY) {
      logger_.updateTaskState(runID_, taskName, RunState::ERRORED);
      // Stays counted in nRunningTasks_; run() reports ERRORED once every
      // task still counted as running has errored.
      ++nErroredTasks_;
    }
    else {
      // Task was killed
      --nRunningTasks_;
    }
  }
  for (const auto &taskName : finishedTasks) {
    runningTasks_.erase(taskName);
  }
}
/**
 * Stop the run and log the DAG run state as KILLED.
 *
 * After this call queuePending() submits nothing new, because running_ is
 * false.
 *
 * @param kill     If true, run() asks the executor to stop tasks that are
 *                 still running. If false, in-flight tasks are left to finish.
 * @param blocking If true, poll every 250ms until no non-errored task remains
 *                 counted as running. The counters only change inside
 *                 run()'s loop (via collectFinished/queuePending) or in
 *                 resetRunning(), so a blocking stop relies on run() making
 *                 progress on another thread.
 */
void DAGRunner::stop(bool kill, bool blocking)
{
  kill_ = kill;
  running_ = false;
  logger_.updateDAGRunState(runID_, RunState::KILLED);
  if (!blocking)
    return;
  while (true) {
    {
      std::lock_guard<std::mutex> lock(runGuard_);
      // Same termination condition run() uses for a stopped run. It is
      // written as a comparison so it neither wraps around (unsigned
      // counters) nor misses a negative difference (signed counters).
      if (nRunningTasks_ <= nErroredTasks_)
        return;
    }
    std::this_thread::sleep_for(250ms);
  }
}
} // namespace daggy