238 lines
6.7 KiB
C++
238 lines
6.7 KiB
C++
#include <daggy/DAGRunner.hpp>

#include <chrono>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <utility>

using namespace std::chrono_literals;
|
|
|
|
namespace daggy {
|
|
/// Construct a runner bound to one DAG run.
///
/// @param runID      identifier used for all logging/executor calls
/// @param executor   task executor (borrowed; must outlive the runner)
/// @param logger     run/task state logger (borrowed; must outlive the runner)
/// @param dag        the DAG to execute; taken by value and moved into place
/// @param taskParams parameters (defaults/variables) applied to tasks
DAGRunner::DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor,
                     loggers::dag_run::DAGRunLogger &logger, TaskDAG dag,
                     const TaskParameters &taskParams)
    : runID_(runID)
    , executor_(executor)
    , logger_(logger)
    // `dag` already arrived as a by-value copy; move it instead of
    // copying the (potentially large) DAG a second time.
    , dag_(std::move(dag))
    , taskParams_(taskParams)
    // NOTE(review): running_/kill_ start true and are re-initialized at
    // the top of run() — presumably intentional so that resetRunning()
    // refuses to run before the first run(); confirm.
    , running_(true)
    , kill_(true)
    , nRunningTasks_(0)
    , nErroredTasks_(0)
{
}
|
|
|
|
/// Destructor: briefly acquires the run mutex so destruction serializes
/// with any thread still inside a runGuard_-protected section, rather
/// than tearing members down mid-iteration.
DAGRunner::~DAGRunner()
{
    const std::lock_guard<std::mutex> sync(runGuard_);
}
|
|
|
|
/// Drive the DAG to completion on the calling thread.
///
/// Polls every 100ms; each pass (under runGuard_):
///   1. kills in-flight tasks if a kill-stop was requested,
///   2. harvests finished futures (collectFinished),
///   3. submits newly-unblocked tasks (queuePending),
/// until every vertex has been visited (COMPLETED), a stop() drains the
/// run (KILLED), or every still-tracked task has errored (ERRORED).
///
/// @return the DAG in its final state (possibly expanded by generator
///         tasks during the run)
TaskDAG DAGRunner::run()
{
    // Reset the control flags; stop() flips them from another thread.
    kill_ = false;
    running_ = true;
    logger_.updateDAGRunState(runID_, RunState::RUNNING);

    bool allVisited;
    {
        std::lock_guard<std::mutex> lock(runGuard_);
        allVisited = dag_.allVisited();
    }
    while (!allVisited) {
        {
            std::lock_guard<std::mutex> runLock(runGuard_);
            // stop(kill=true) was requested: tell the executor to stop
            // everything still in flight before harvesting results.
            if (!running_ and kill_) {
                killRunning();
            }
            collectFinished();
            queuePending();

            // Stop requested and no non-errored task remains in flight.
            if (!running_ and (nRunningTasks_ - nErroredTasks_ <= 0)) {
                logger_.updateDAGRunState(runID_, RunState::KILLED);
                break;
            }

            // Every task still counted as running has errored out, so no
            // further progress is possible.
            if (nRunningTasks_ > 0 and nErroredTasks_ == nRunningTasks_) {
                logger_.updateDAGRunState(runID_, RunState::ERRORED);
                break;
            }
        }

        std::this_thread::sleep_for(100ms);
        {
            std::lock_guard<std::mutex> lock(runGuard_);
            allVisited = dag_.allVisited();
        }
    }

    // NOTE(review): this read of dag_ (and the running_ store below) is
    // outside runGuard_ — presumably safe because the polling loop has
    // exited, but confirm no other thread can still mutate state here.
    if (dag_.allVisited()) {
        logger_.updateDAGRunState(runID_, RunState::COMPLETED);
    }

    running_ = false;
    return dag_;
}
|
|
|
|
void DAGRunner::resetRunning()
|
|
{
|
|
if (running_)
|
|
throw std::runtime_error("Unable to reset while DAG is running.");
|
|
|
|
std::lock_guard<std::mutex> lock(runGuard_);
|
|
nRunningTasks_ = 0;
|
|
nErroredTasks_ = 0;
|
|
runningTasks_.clear();
|
|
taskAttemptCounts_.clear();
|
|
dag_.resetRunning();
|
|
}
|
|
|
|
void DAGRunner::killRunning()
|
|
{
|
|
for (const auto &[taskName, _] : runningTasks_) {
|
|
executor_.stop(runID_, taskName);
|
|
}
|
|
}
|
|
|
|
/// Forward a stop request for a single task to the executor.
/// Bookkeeping for the stopped task is presumably reconciled later by
/// collectFinished() (its "Task was killed" branch) — confirm.
void DAGRunner::stopTask(const std::string &taskName)
{
    executor_.stop(runID_, taskName);
}
|
|
|
|
void DAGRunner::queuePending()
|
|
{
|
|
if (!running_)
|
|
return;
|
|
|
|
const size_t MAX_SUBMITS = 100;
|
|
size_t n_submitted = 0;
|
|
|
|
/*
|
|
In cases where there are many tasks ready to execute,
|
|
the blocking nature of executor_.execute(...) means
|
|
that tasks will get executed and the futures resolved,
|
|
but we won't know until all pending tasks are enqueued.
|
|
|
|
To avoid this, submit at most MAX_SUBMITS tasks before
|
|
returning to allow completed tasks to be updated.
|
|
*/
|
|
|
|
while (n_submitted < MAX_SUBMITS) {
|
|
auto t = dag_.visitNext();
|
|
if (!t)
|
|
break;
|
|
auto &taskName = t.value().first;
|
|
taskAttemptCounts_[taskName] = 1;
|
|
|
|
logger_.updateTaskState(runID_, taskName, RunState::RUNNING);
|
|
try {
|
|
auto &task = t.value().second;
|
|
auto fut = executor_.execute(runID_, taskName, task);
|
|
runningTasks_.emplace(taskName, std::move(fut));
|
|
}
|
|
catch (std::exception &e) {
|
|
std::cout << "Unable to execute task: " << e.what() << std::endl;
|
|
}
|
|
++nRunningTasks_;
|
|
++n_submitted;
|
|
}
|
|
}
|
|
|
|
/// Harvest every ready future in runningTasks_: log the attempt, then
/// either complete the task (rc == 0, expanding generator output into new
/// DAG vertices), retry it (rc != 0 with attempts remaining), or mark it
/// errored/killed. Ready tasks are removed from runningTasks_ at the end.
void DAGRunner::collectFinished()
{
    std::vector<std::string> completedTasks;
    for (auto &[taskName, fut] : runningTasks_) {
        if (fut->ready()) {
            auto attempt = fut->get();
            logger_.logTaskAttempt(runID_, taskName, attempt);
            // NOTE(review): every ready task lands here — including the
            // retry branch below, whose freshly-assigned future is then
            // extracted from runningTasks_ at the bottom of this
            // function. Confirm retried tasks are actually still tracked.
            completedTasks.push_back(taskName);

            // Not a reference, since adding tasks will invalidate references
            auto vert = dag_.getVertex(taskName);
            auto &task = vert.data;
            if (attempt.rc == 0) {
                logger_.updateTaskState(runID_, taskName, RunState::COMPLETED);
                if (task.isGenerator) {
                    // Parse the output and update the DAGs
                    try {
                        auto parsedTasks =
                            tasksFromJSON(attempt.outputLog, taskParams_.jobDefaults);
                        auto newTasks =
                            expandTaskSet(parsedTasks, executor_, taskParams_.variables);
                        updateDAGFromTasks(dag_, newTasks);

                        // Add in dependencies from current task to new tasks
                        for (const auto &[ntName, ntTask] : newTasks) {
                            logger_.addTask(runID_, ntName, ntTask);
                            task.children.insert(ntName);
                        }

                        // Efficiently add new edges from generator task
                        // to children
                        std::unordered_set<std::string> baseNames;
                        for (const auto &[k, v] : parsedTasks) {
                            baseNames.insert(v.definedName);
                        }
                        dag_.addEdgeIf(taskName, [&](const auto &v) {
                            return baseNames.count(v.data.definedName) > 0;
                        });

                        logger_.updateTask(runID_, taskName, task);
                    }
                    catch (std::exception &e) {
                        // Generator produced unparseable output: record a
                        // synthetic attempt and mark the task errored.
                        logger_.logTaskAttempt(
                            runID_, taskName,
                            AttemptRecord{
                                .executorLog =
                                    std::string{"Failed to parse JSON output: "} +
                                    e.what()});
                        logger_.updateTaskState(runID_, taskName, RunState::ERRORED);
                        ++nErroredTasks_;
                    }
                }
                dag_.completeVisit(taskName);
                --nRunningTasks_;
            }
            else {
                // RC isn't 0
                if (taskAttemptCounts_[taskName] <= task.maxRetries) {
                    // Attempts remain: resubmit and bump the attempt count.
                    logger_.updateTaskState(runID_, taskName, RunState::RETRY);
                    runningTasks_[taskName] = executor_.execute(runID_, taskName, task);
                    ++taskAttemptCounts_[taskName];
                }
                else {
                    if (logger_.getTaskState(runID_, taskName) == +RunState::RUNNING or
                        logger_.getTaskState(runID_, taskName) == +RunState::RETRY) {
                        // Out of retries while still nominally running.
                        // NOTE(review): nRunningTasks_ is deliberately NOT
                        // decremented here — run()'s ERRORED/KILLED checks
                        // compare nErroredTasks_ against nRunningTasks_;
                        // confirm this asymmetry with the killed branch
                        // below is intended.
                        logger_.updateTaskState(runID_, taskName, RunState::ERRORED);
                        ++nErroredTasks_;
                    }
                    else {
                        // Task was killed
                        --nRunningTasks_;
                    }
                }
            }
        }
    }

    // Drop harvested entries; extract() avoids rehashing the mapped values.
    for (const auto &taskName : completedTasks) {
        runningTasks_.extract(taskName);
    }
}
|
|
|
|
void DAGRunner::stop(bool kill, bool blocking)
|
|
{
|
|
kill_ = kill;
|
|
running_ = false;
|
|
|
|
logger_.updateDAGRunState(runID_, RunState::KILLED);
|
|
if (blocking) {
|
|
while (true) {
|
|
{
|
|
std::lock_guard<std::mutex> lock(runGuard_);
|
|
if (nRunningTasks_ - nErroredTasks_ == 0)
|
|
break;
|
|
}
|
|
std::this_thread::sleep_for(250ms);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace daggy
|