diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3418100..f2107c7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,9 +1,18 @@
 cmake_minimum_required(VERSION 3.14)
 project(overall)
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE "Debug")
+endif()
+
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED True)
 set(CMAKE_EXPORT_COMPILE_COMMANDS True)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror")
+
+if(CMAKE_BUILD_TYPE MATCHES "Debug")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread -fno-omit-frame-pointer")
+endif()
 
 set(THIRD_PARTY_DIR ${CMAKE_BINARY_DIR}/third_party)
 
@@ -19,6 +28,8 @@ include(cmake/argparse.cmake)
 include(cmake/Catch2.cmake)
 include(cmake/daggy_features.cmake)
 
+message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}")
+
 # use, i.e. don't skip the full RPATH for the build tree
 set(CMAKE_SKIP_BUILD_RPATH FALSE)
 
diff --git a/README.md b/README.md
index 4601c13..fa0377b 100644
--- a/README.md
+++ b/README.md
@@ -28,11 +28,10 @@ graph LR
 
 Individual tasks (vertices) are run via a task executor. Daggy supports multiple executors,
 from local executor (via fork), to distributed work managers like [slurm](https://slurm.schedmd.com/overview.html)
-or [kubernetes](https://kubernetes.io/) (both planned).
+or [kubernetes](https://kubernetes.io/) (planned).
 
-State is maintained via state loggers. Currently daggy supports an in-memory state manager (OStreamLogger), and a
-filesystem logger (FileSystemLogger). Future plans include supporting [redis](https://redis.io)
-and [postgres](https://postgresql.org).
+State is maintained via state loggers. Currently daggy supports an in-memory state manager (OStreamLogger).
+Future plans include supporting [redis](https://redis.io) and [postgres](https://postgresql.org).
 
 Building
 ==
@@ -43,13 +42,17 @@ Building
 
 - cmake >= 3.14
 - gcc >= 8
+- libslurm (only needed when building with -DDAGGY_ENABLE_SLURM=ON)
+
 ```
 git clone https://gitlab.com/iroddis/daggy
 cd daggy
 mkdir build
 cd build
-cmake ..
+cmake [-DDAGGY_ENABLE_SLURM=ON] ..
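+# e.g. (illustrative): a Release build, which skips the ThreadSanitizer
+# flags that the default Debug CMAKE_BUILD_TYPE adds above:
+#   cmake -DDAGGY_ENABLE_SLURM=ON -DCMAKE_BUILD_TYPE=Release ..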
 make
+
+tests/tests # for unit tests
 ```
 
 DAG Run Definition
diff --git a/cmake/daggy_features.cmake b/cmake/daggy_features.cmake
index 67ca04c..d929413 100644
--- a/cmake/daggy_features.cmake
+++ b/cmake/daggy_features.cmake
@@ -1,5 +1,5 @@
 # SLURM
-message("DAGGY_ENABLED_SLURM is set to ${DAGGY_ENABLE_SLURM}")
+message("-- DAGGY_ENABLE_SLURM is set to ${DAGGY_ENABLE_SLURM}")
 if (DAGGY_ENABLE_SLURM)
   find_library(SLURM_LIB libslurm.so libslurm.a slurm REQUIRED)
   find_path(SLURM_INCLUDE_DIR "slurm/slurm.h" REQUIRED)
diff --git a/daggy/include/daggy/DAGRunner.hpp b/daggy/include/daggy/DAGRunner.hpp
new file mode 100644
index 0000000..f821de2
--- /dev/null
+++ b/daggy/include/daggy/DAGRunner.hpp
@@ -0,0 +1,55 @@
+#pragma once
+
+#include <sys/types.h>
+
+#include <atomic>
+#include <chrono>
+#include <future>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+#include "DAG.hpp"
+#include "Defines.hpp"
+#include "Serialization.hpp"
+#include "Utilities.hpp"
+#include "daggy/executors/task/TaskExecutor.hpp"
+#include "daggy/loggers/dag_run/DAGRunLogger.hpp"
+
+using namespace std::chrono_literals;
+
+namespace daggy {
+  class DAGRunner
+  {
+  public:
+    DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor,
+              loggers::dag_run::DAGRunLogger &logger, TaskDAG dag,
+              const TaskParameters &taskParams);
+
+    ~DAGRunner();
+
+    TaskDAG run();
+    void resetRunning();
+    void stop(bool kill = false, bool blocking = false);
+
+  private:
+    void collectFinished();
+    void queuePending();
+    void killRunning();
+
+    DAGRunID runID_;
+    executors::task::TaskExecutor &executor_;
+    loggers::dag_run::DAGRunLogger &logger_;
+    TaskDAG dag_;
+    const TaskParameters &taskParams_;
+    std::atomic<bool> running_;
+    std::atomic<bool> kill_;
+
+    ssize_t nRunningTasks_;
+    ssize_t nErroredTasks_;
+    std::unordered_map<std::string, std::future<AttemptRecord>> runningTasks_;
+    std::unordered_map<std::string, size_t> taskAttemptCounts_;
+
+    std::mutex runGuard_;
+  };
+} // namespace daggy
diff --git a/daggy/include/daggy/Defines.hpp b/daggy/include/daggy/Defines.hpp
index cbf4422..4e7e1ee 100644
--- a/daggy/include/daggy/Defines.hpp
+++ b/daggy/include/daggy/Defines.hpp
@@ -22,9 +22,8 @@ namespace daggy {
   // DAG Runs
   using DAGRunID = size_t;
 
-  BETTER_ENUM(RunState, uint32_t, QUEUED = 1 << 0, RUNNING = 1 << 1,
-              RETRY = 1 << 2, ERRORED = 1 << 3, KILLED = 1 << 4,
-              COMPLETED = 1 << 5);
+  BETTER_ENUM(RunState, uint32_t, QUEUED = 1, RUNNING, RETRY, ERRORED, KILLED,
+              PAUSED, COMPLETED);
 
   struct Task
   {
@@ -50,6 +49,20 @@ namespace daggy {
 
   using TaskSet = std::unordered_map<std::string, Task>;
 
+  // All the components required to define and run a DAG
+  struct TaskParameters
+  {
+    ConfigValues variables;
+    ConfigValues jobDefaults;
+  };
+
+  struct DAGSpec
+  {
+    std::string tag;
+    TaskSet tasks;
+    TaskParameters taskConfig;
+  };
+
   struct AttemptRecord
   {
     TimePoint startTime;
diff --git a/daggy/include/daggy/Serialization.hpp b/daggy/include/daggy/Serialization.hpp
index 0d4fb36..5259d3a 100644
--- a/daggy/include/daggy/Serialization.hpp
+++ b/daggy/include/daggy/Serialization.hpp
@@ -8,6 +8,7 @@
 #include <rapidjson/document.h>
 
 #include "Defines.hpp"
+#include "Utilities.hpp"
 
 namespace rj = rapidjson;
@@ -36,6 +37,10 @@
 
   std::string tasksToJSON(const TaskSet &tasks);
 
+  // Full specs
+  DAGSpec dagFromJSON(const rj::Value &spec);
+  DAGSpec dagFromJSON(const std::string &jsonSpec);
+
   // Attempt Records
   std::string attemptRecordToJSON(const AttemptRecord &attemptRecord);
diff --git a/daggy/include/daggy/Server.hpp b/daggy/include/daggy/Server.hpp
index f4f4231..2f6dbc8 100644
--- a/daggy/include/daggy/Server.hpp
+++ b/daggy/include/daggy/Server.hpp
@@ -6,10 +6,15 @@
 #include
 
+#include
"DAGRunner.hpp" #include "ThreadPool.hpp" #include "executors/task/TaskExecutor.hpp" #include "loggers/dag_run/DAGRunLogger.hpp" +#define DAGGY_REST_HANDLER(func) \ + void func(const Pistache::Rest::Request &request, \ + Pistache::Http::ResponseWriter response); + namespace fs = std::filesystem; namespace daggy { @@ -18,14 +23,8 @@ namespace daggy { public: Server(const Pistache::Address &listenSpec, loggers::dag_run::DAGRunLogger &logger, - executors::task::TaskExecutor &executor, size_t nDAGRunners) - : endpoint_(listenSpec) - , desc_("Daggy API", "0.1") - , logger_(logger) - , executor_(executor) - , runnerPool_(nDAGRunners) - { - } + executors::task::TaskExecutor &executor, size_t nDAGRunners); + ~Server(); Server &setSSLCertificates(const fs::path &cert, const fs::path &key); @@ -39,21 +38,21 @@ namespace daggy { private: void createDescription(); + void queueDAG_(DAGRunID runID, const TaskDAG &dag, + const TaskParameters &taskParameters); - void handleRunDAG(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); + DAGGY_REST_HANDLER(handleReady); // X + DAGGY_REST_HANDLER(handleQueryDAGs); // X + DAGGY_REST_HANDLER(handleRunDAG); // X + DAGGY_REST_HANDLER(handleValidateDAG); // X + DAGGY_REST_HANDLER(handleGetDAGRun); // X + DAGGY_REST_HANDLER(handleGetDAGRunState); // X + DAGGY_REST_HANDLER(handleSetDAGRunState); // X + DAGGY_REST_HANDLER(handleGetTask); // X + DAGGY_REST_HANDLER(handleGetTaskState); // X + DAGGY_REST_HANDLER(handleSetTaskState); // X - void handleGetDAGRuns(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - void handleGetDAGRun(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - void handleReady(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - bool handleAuth(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter &response); + bool handleAuth(const Pistache::Rest::Request &request); Pistache::Http::Endpoint endpoint_; Pistache::Rest::Description desc_; @@ -62,5 +61,8 @@ namespace daggy { loggers::dag_run::DAGRunLogger &logger_; executors::task::TaskExecutor &executor_; ThreadPool runnerPool_; + + std::mutex runnerGuard_; + std::unordered_map> runners_; }; } // namespace daggy diff --git a/daggy/include/daggy/Utilities.hpp b/daggy/include/daggy/Utilities.hpp index f5e9602..8c2f4dc 100644 --- a/daggy/include/daggy/Utilities.hpp +++ b/daggy/include/daggy/Utilities.hpp @@ -31,9 +31,5 @@ namespace daggy { void updateDAGFromTasks(TaskDAG &dag, const TaskSet &tasks); - TaskDAG runDAG(DAGRunID runID, executors::task::TaskExecutor &executor, - loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, - const ConfigValues job = {}); - std::ostream &operator<<(std::ostream &os, const TimePoint &tp); } // namespace daggy diff --git a/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp b/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp index 860fe04..573c4fe 100644 --- a/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp @@ -10,10 +10,8 @@ namespace daggy::executors::task { public: using Command = std::vector; - explicit ForkingTaskExecutor(size_t nThreads) - : tp_(nThreads) - { - } + explicit ForkingTaskExecutor(size_t nThreads); + ~ForkingTaskExecutor() override; // Validates the job to ensure that all required values are set and are of // the right type, @@ -23,11 +21,16 @@ namespace daggy::executors::task { const ConfigValues &job, 
const ConfigValues &expansionValues) override; // Runs the task - std::future execute(const std::string &taskName, + std::future execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + bool stop(DAGRunID runID, const std::string &taskName) override; + private: ThreadPool tp_; - AttemptRecord runTask(const Task &task); + std::mutex taskControlsGuard_; + AttemptRecord runTask(const Task &task, std::atomic &running); + std::unordered_map> taskControls_; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp b/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp index 751b255..92730cc 100644 --- a/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp @@ -16,7 +16,10 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) override; // Runs the task - std::future execute(const std::string &taskName, + std::future execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + + bool stop(DAGRunID runID, const std::string &taskName) override; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp b/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp index f8f159e..90e3e2e 100644 --- a/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp @@ -19,15 +19,20 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) override; // Runs the task - std::future execute(const std::string &taskName, + std::future execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + bool stop(DAGRunID runID, const std::string &taskName) override; + private: struct Job { std::promise prom; std::string stdoutFile; std::string stderrFile; + DAGRunID runID; + std::string taskName; }; std::mutex promiseGuard_; diff --git a/daggy/include/daggy/executors/task/TaskExecutor.hpp b/daggy/include/daggy/executors/task/TaskExecutor.hpp index 682bcea..f2c02ec 100644 --- a/daggy/include/daggy/executors/task/TaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/TaskExecutor.hpp @@ -27,7 +27,11 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) = 0; // Blocking execution of a task - virtual std::future execute(const std::string &taskName, + virtual std::future execute(DAGRunID runID, + const std::string &taskName, const Task &task) = 0; + + // Kill a currently executing task. This will resolve the future. 
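+  // Implementations should be idempotent: DAGRunner::killRunning() calls
+  // stop() for every task it still tracks, so stopping an unknown or
+  // already-finished task should simply return true rather than fail.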
+ virtual bool stop(DAGRunID runID, const std::string &taskName) = 0; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp b/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp index 0e97036..3232dc5 100644 --- a/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp +++ b/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp @@ -17,8 +17,8 @@ namespace daggy::loggers::dag_run { public: virtual ~DAGRunLogger() = default; - // Execution - virtual DAGRunID startDAGRun(std::string name, const TaskSet &tasks) = 0; + // Insertion / Updates + virtual DAGRunID startDAGRun(const DAGSpec &dagSpec) = 0; virtual void addTask(DAGRunID dagRunID, const std::string &taskName, const Task &task) = 0; @@ -35,8 +35,16 @@ namespace daggy::loggers::dag_run { RunState state) = 0; // Querying - virtual std::vector getDAGs(uint32_t stateMask) = 0; + virtual DAGSpec getDAGSpec(DAGRunID dagRunID) = 0; - virtual DAGRunRecord getDAGRun(DAGRunID dagRunID) = 0; + virtual std::vector queryDAGRuns(const std::string &tag = "", + bool all = false) = 0; + + virtual RunState getDAGRunState(DAGRunID dagRunID) = 0; + virtual DAGRunRecord getDAGRun(DAGRunID dagRunID) = 0; + + virtual Task &getTask(DAGRunID dagRunID, const std::string &taskName) = 0; + virtual RunState &getTaskState(DAGRunID dagRunID, + const std::string &taskName) = 0; }; } // namespace daggy::loggers::dag_run diff --git a/daggy/include/daggy/loggers/dag_run/Defines.hpp b/daggy/include/daggy/loggers/dag_run/Defines.hpp index eff8010..db2d810 100644 --- a/daggy/include/daggy/loggers/dag_run/Defines.hpp +++ b/daggy/include/daggy/loggers/dag_run/Defines.hpp @@ -25,8 +25,7 @@ namespace daggy::loggers::dag_run { // Pretty heavy weight, but struct DAGRunRecord { - std::string name; - TaskSet tasks; + DAGSpec dagSpec; std::unordered_map taskRunStates; std::unordered_map> taskAttempts; std::vector taskStateChanges; @@ -36,7 +35,7 @@ namespace daggy::loggers::dag_run { struct DAGRunSummary { DAGRunID runID; - std::string name; + std::string tag; RunState runState; TimePoint startTime; TimePoint lastUpdate; diff --git a/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp b/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp index ad0979c..53c7bb8 100644 --- a/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp +++ b/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp @@ -15,9 +15,10 @@ namespace daggy::loggers::dag_run { { public: explicit OStreamLogger(std::ostream &os); + ~OStreamLogger() override; // Execution - DAGRunID startDAGRun(std::string name, const TaskSet &tasks) override; + DAGRunID startDAGRun(const DAGSpec &dagSpec) override; void addTask(DAGRunID dagRunID, const std::string &taskName, const Task &task) override; @@ -34,10 +35,18 @@ namespace daggy::loggers::dag_run { RunState state) override; // Querying - std::vector getDAGs(uint32_t stateMask) override; + DAGSpec getDAGSpec(DAGRunID dagRunID) override; + std::vector queryDAGRuns(const std::string &tag = "", + bool all = false) override; + + RunState getDAGRunState(DAGRunID dagRunID) override; DAGRunRecord getDAGRun(DAGRunID dagRunID) override; + Task &getTask(DAGRunID dagRunID, const std::string &taskName) override; + RunState &getTaskState(DAGRunID dagRunID, + const std::string &taskName) override; + private: std::mutex guard_; std::ostream &os_; diff --git a/daggy/src/CMakeLists.txt b/daggy/src/CMakeLists.txt index 1af1510..7dbe518 100644 --- a/daggy/src/CMakeLists.txt +++ b/daggy/src/CMakeLists.txt @@ -2,6 +2,7 @@ 
target_sources(${PROJECT_NAME} PRIVATE Serialization.cpp Server.cpp Utilities.cpp + DAGRunner.cpp ) add_subdirectory(executors) diff --git a/daggy/src/DAGRunner.cpp b/daggy/src/DAGRunner.cpp new file mode 100644 index 0000000..164fbc8 --- /dev/null +++ b/daggy/src/DAGRunner.cpp @@ -0,0 +1,213 @@ +#include +#include +#include +#include + +namespace daggy { + DAGRunner::DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor, + loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, + const TaskParameters &taskParams) + : runID_(runID) + , executor_(executor) + , logger_(logger) + , dag_(dag) + , taskParams_(taskParams) + , running_(true) + , kill_(true) + , nRunningTasks_(0) + , nErroredTasks_(0) + { + } + + DAGRunner::~DAGRunner() + { + std::lock_guard lock(runGuard_); + } + + TaskDAG DAGRunner::run() + { + kill_ = false; + running_ = true; + logger_.updateDAGRunState(runID_, RunState::RUNNING); + + bool allVisited; + { + std::lock_guard lock(runGuard_); + allVisited = dag_.allVisited(); + } + while (!allVisited) { + { + std::lock_guard runLock(runGuard_); + if (!running_ and kill_) { + killRunning(); + } + collectFinished(); + queuePending(); + + if (!running_ and (nRunningTasks_ - nErroredTasks_ <= 0)) { + logger_.updateDAGRunState(runID_, RunState::KILLED); + break; + } + + if (nRunningTasks_ > 0 and nErroredTasks_ == nRunningTasks_) { + logger_.updateDAGRunState(runID_, RunState::ERRORED); + break; + } + } + + std::this_thread::sleep_for(250ms); + { + std::lock_guard lock(runGuard_); + allVisited = dag_.allVisited(); + } + } + + if (dag_.allVisited()) { + logger_.updateDAGRunState(runID_, RunState::COMPLETED); + } + + running_ = false; + return dag_; + } + + void DAGRunner::resetRunning() + { + if (running_) + throw std::runtime_error("Unable to reset while DAG is running."); + + std::lock_guard lock(runGuard_); + nRunningTasks_ = 0; + nErroredTasks_ = 0; + runningTasks_.clear(); + taskAttemptCounts_.clear(); + dag_.resetRunning(); + } + + void DAGRunner::killRunning() + { + for (const auto &[taskName, _] : runningTasks_) { + executor_.stop(runID_, taskName); + } + } + + void DAGRunner::queuePending() + { + if (!running_) + return; + + // Check for any completed tasks + // Add all remaining tasks in a task queue to avoid dominating the thread + // pool + auto t = dag_.visitNext(); + while (t.has_value()) { + // Schedule the task to run + auto &taskName = t.value().first; + auto &task = t.value().second; + taskAttemptCounts_[taskName] = 1; + + logger_.updateTaskState(runID_, taskName, RunState::RUNNING); + runningTasks_.emplace(taskName, + executor_.execute(runID_, taskName, task)); + ++nRunningTasks_; + + auto nextTask = dag_.visitNext(); + if (not nextTask.has_value()) + break; + t.emplace(nextTask.value()); + } + } + + void DAGRunner::collectFinished() + { + for (auto &[taskName, fut] : runningTasks_) { + if (fut.valid() and fut.wait_for(1ms) == std::future_status::ready) { + auto attempt = fut.get(); + logger_.logTaskAttempt(runID_, taskName, attempt); + + // Not a reference, since adding tasks will invalidate references + auto vert = dag_.getVertex(taskName); + auto &task = vert.data; + if (attempt.rc == 0) { + logger_.updateTaskState(runID_, taskName, RunState::COMPLETED); + if (task.isGenerator) { + // Parse the output and update the DAGs + try { + auto parsedTasks = + tasksFromJSON(attempt.outputLog, taskParams_.jobDefaults); + auto newTasks = + expandTaskSet(parsedTasks, executor_, taskParams_.variables); + updateDAGFromTasks(dag_, newTasks); + + // Add in dependencies 
from current task to new tasks + for (const auto &[ntName, ntTask] : newTasks) { + logger_.addTask(runID_, ntName, ntTask); + task.children.insert(ntName); + } + + // Efficiently add new edges from generator task + // to children + std::unordered_set baseNames; + for (const auto &[k, v] : parsedTasks) { + baseNames.insert(v.definedName); + } + dag_.addEdgeIf(taskName, [&](const auto &v) { + return baseNames.count(v.data.definedName) > 0; + }); + + logger_.updateTask(runID_, taskName, task); + } + catch (std::exception &e) { + logger_.logTaskAttempt( + runID_, taskName, + AttemptRecord{ + .executorLog = + std::string{"Failed to parse JSON output: "} + + e.what()}); + logger_.updateTaskState(runID_, taskName, RunState::ERRORED); + ++nErroredTasks_; + } + } + dag_.completeVisit(taskName); + --nRunningTasks_; + } + else { + // RC isn't 0 + if (taskAttemptCounts_[taskName] <= task.maxRetries) { + logger_.updateTaskState(runID_, taskName, RunState::RETRY); + runningTasks_[taskName] = executor_.execute(runID_, taskName, task); + ++taskAttemptCounts_[taskName]; + } + else { + if (logger_.getTaskState(runID_, taskName) == +RunState::RUNNING or + logger_.getTaskState(runID_, taskName) == +RunState::RETRY) { + logger_.updateTaskState(runID_, taskName, RunState::ERRORED); + ++nErroredTasks_; + } + else { + // Task was killed + --nRunningTasks_; + } + } + } + } + } + } + + void DAGRunner::stop(bool kill, bool blocking) + { + kill_ = kill; + running_ = false; + + if (blocking) { + while (true) { + { + std::lock_guard lock(runGuard_); + if (nRunningTasks_ - nErroredTasks_ == 0) + break; + } + std::this_thread::sleep_for(250ms); + } + } + } + +} // namespace daggy diff --git a/daggy/src/Serialization.cpp b/daggy/src/Serialization.cpp index 7540eaa..422670b 100644 --- a/daggy/src/Serialization.cpp +++ b/daggy/src/Serialization.cpp @@ -276,17 +276,56 @@ namespace daggy { std::string timePointToString(const TimePoint &tp) { - std::stringstream ss; - ss << tp; - return ss.str(); + return std::to_string(tp.time_since_epoch().count()); } TimePoint stringToTimePoint(const std::string &timeString) { - std::tm dt{}; - std::stringstream ss{timeString}; - ss >> std::get_time(&dt, "%Y-%m-%d %H:%M:%S %Z"); - return Clock::from_time_t(mktime(&dt)); + using namespace std::chrono; + + size_t nanos = std::stoull(timeString); + nanoseconds dur(nanos); + + return TimePoint(dur); + } + + DAGSpec dagFromJSON(const rj::Value &spec) + { + DAGSpec info; + + if (!spec.IsObject()) { + throw std::runtime_error("Payload is not a dictionary."); + } + if (!spec.HasMember("tag")) { + throw std::runtime_error("DAG Run is missing a name."); + } + if (!spec.HasMember("tasks")) { + throw std::runtime_error("DAG Run has no tasks."); + } + + info.tag = spec["tag"].GetString(); + + // Get parameters if there are any + if (spec.HasMember("parameters")) { + info.taskConfig.variables = configFromJSON(spec["parameters"]); + } + + // Job Defaults + if (spec.HasMember("jobDefaults")) { + info.taskConfig.jobDefaults = configFromJSON(spec["jobDefaults"]); + } + + // Get the tasks + info.tasks = tasksFromJSON(spec["tasks"], info.taskConfig.jobDefaults); + + return info; + } + + DAGSpec dagFromJSON(const std::string &jsonSpec) + { + rj::Document doc; + checkRJParse(doc.Parse(jsonSpec.c_str()), "Parsing config"); + return dagFromJSON(doc); } } // namespace daggy diff --git a/daggy/src/Server.cpp b/daggy/src/Server.cpp index 81cacf8..9062706 100644 --- a/daggy/src/Server.cpp +++ b/daggy/src/Server.cpp @@ -4,12 +4,18 @@ #include #include #include 
+#include +#include +#include +#include +#include -#define REQ_ERROR(code, msg) \ - response.send(Pistache::Http::Code::code, msg); \ +#define REQ_RESPONSE(code, msg) \ + std::stringstream ss; \ + ss << R"({"message": )" << std::quoted(msg) << "}"; \ + response.send(Pistache::Http::Code::code, ss.str()); \ return; -namespace rj = rapidjson; using namespace Pistache; namespace daggy { @@ -25,6 +31,22 @@ namespace daggy { createDescription(); } + Server::Server(const Pistache::Address &listenSpec, + loggers::dag_run::DAGRunLogger &logger, + executors::task::TaskExecutor &executor, size_t nDAGRunners) + : endpoint_(listenSpec) + , desc_("Daggy API", "0.1") + , logger_(logger) + , executor_(executor) + , runnerPool_(nDAGRunners) + { + } + + Server::~Server() + { + shutdown(); + } + void Server::start() { router_.initFromDescription(desc_); @@ -42,6 +64,7 @@ namespace daggy { void Server::shutdown() { endpoint_.shutdown(); + runnerPool_.shutdown(); } uint16_t Server::getPort() const @@ -55,7 +78,7 @@ namespace daggy { auto backendErrorResponse = desc_.response(Http::Code::Internal_Server_Error, - "An error occurred with the backend"); + R"({"error": "An error occurred with the backend"})"); desc_.schemes(Rest::Scheme::Http) .basePath("/v1") @@ -69,111 +92,131 @@ namespace daggy { auto versionPath = desc_.path("/v1"); - auto dagPath = versionPath.path("/dagrun"); + /* + DAG Run Summaries + */ + auto dagRunsPath = versionPath.path("/dagruns"); - // Run a DAG - dagPath.route(desc_.post("/")) + dagRunsPath.route(desc_.get("/")) + .bind(&Server::handleQueryDAGs, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "List summaries DAGs"); + + /* + Individual DAG Run routes + */ + auto dagRunPath = versionPath.path("/dagrun"); + + dagRunPath.route(desc_.post("/")) .bind(&Server::handleRunDAG, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) + .produces(MIME(Application, Json)) .response(Http::Code::Ok, "Run a DAG"); - // List detailed DAG run - dagPath.route(desc_.get("/:runID")) - .bind(&Server::handleGetDAGRun, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) - .response(Http::Code::Ok, "Details of a specific DAG run"); - // List all DAG runs - dagPath.route(desc_.get("/")) - .bind(&Server::handleGetDAGRuns, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) - .response(Http::Code::Ok, "The list of all known DAG Runs"); + dagRunPath.route(desc_.post("/validate")) + .bind(&Server::handleValidateDAG, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Validate a DAG Run Spec"); + + /* + Management of a specific DAG + */ + auto specificDAGRunPath = dagRunPath.path("/:runID"); + + specificDAGRunPath.route(desc_.get("/")) + .bind(&Server::handleGetDAGRun, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Full DAG Run"); + + specificDAGRunPath.route(desc_.get("/state")) + .bind(&Server::handleGetDAGRunState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, + "Structure of a DAG and DAG and Task run states"); + + specificDAGRunPath.route(desc_.patch("/state/:state")) + .bind(&Server::handleSetDAGRunState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Change the state of a DAG"); + + /* + Task paths + */ + auto taskPath = specificDAGRunPath.path("/task/:taskName"); + taskPath.route(desc_.get("/")) + .bind(&Server::handleGetTask, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Details of a specific task"); + + /* + Task State 
paths + */ + auto taskStatePath = taskPath.path("/state"); + + taskStatePath.route(desc_.get("/")) + .bind(&Server::handleGetTaskState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Get a task state"); + + taskStatePath.route(desc_.patch("/:state")) + .bind(&Server::handleSetTaskState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Set a task state"); } - /* - * { - * "name": "DAG Run Name" - * "job": {...} - * "tasks": {...} - */ void Server::handleRunDAG(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + if (!handleAuth(request)) return; - rj::Document doc; - try { - doc.Parse(request.body().c_str()); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, std::string{"Invalid JSON payload: "} + e.what()); - } - - if (!doc.IsObject()) { - REQ_ERROR(Bad_Request, "Payload is not a dictionary."); - } - if (!doc.HasMember("name")) { - REQ_ERROR(Bad_Request, "DAG Run is missing a name."); - } - if (!doc.HasMember("tasks")) { - REQ_ERROR(Bad_Request, "DAG Run has no tasks."); - } - - std::string runName = doc["name"].GetString(); - - // Get parameters if there are any - ConfigValues parameters; - if (doc.HasMember("parameters")) { - try { - auto parsedParams = configFromJSON(doc["parameters"].GetObject()); - parameters.swap(parsedParams); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } - } - - // Job Defaults - ConfigValues jobDefaults; - if (doc.HasMember("jobDefaults")) { - try { - auto parsedJobDefaults = configFromJSON(doc["jobDefaults"].GetObject()); - jobDefaults.swap(parsedJobDefaults); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } - } - - // Get the tasks - TaskSet tasks; - try { - auto taskTemplates = tasksFromJSON(doc["tasks"], jobDefaults); - auto expandedTasks = expandTaskSet(taskTemplates, executor_, parameters); - tasks.swap(expandedTasks); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } + auto dagSpec = dagFromJSON(request.body()); + dagSpec.tasks = + expandTaskSet(dagSpec.tasks, executor_, dagSpec.taskConfig.variables); // Get a run ID - auto runID = logger_.startDAGRun(runName, tasks); - auto dag = buildDAGFromTasks(tasks); - - runnerPool_.addTask([this, parameters, runID, dag]() { - runDAG(runID, this->executor_, this->logger_, dag, parameters); - }); + DAGRunID runID = logger_.startDAGRun(dagSpec); + auto dag = buildDAGFromTasks(dagSpec.tasks); + queueDAG_(runID, dag, dagSpec.taskConfig); response.send(Pistache::Http::Code::Ok, R"({"runID": )" + std::to_string(runID) + "}"); } - void Server::handleGetDAGRuns(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response) + void Server::handleValidateDAG(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + try { + dagFromJSON(request.body()); + response.send(Pistache::Http::Code::Ok, R"({"valid": true})"); + } + catch (std::exception &e) { + std::string error = e.what(); + response.send(Pistache::Http::Code::Ok, + std::string{R"({"valid": true, "error": })"} + error + "}"); + } + } + + void Server::handleQueryDAGs(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) return; - auto dagRuns = logger_.getDAGs(0); + + bool all = false; + std::string tag = ""; + + if (request.query().has("tag")) { + tag = request.query().get("tag").value(); + } + + if (request.hasParam(":all")) 
{ + auto val = request.query().get("all").value(); + if (val == "true" or val == "1") { + all = true; + } + } + + auto dagRuns = logger_.queryDAGRuns(tag, all); std::stringstream ss; ss << '['; @@ -187,8 +230,8 @@ namespace daggy { } ss << " {" - << R"("runID": )" << run.runID << ',' << R"("name": )" - << std::quoted(run.name) << "," + << R"("runID": )" << run.runID << ',' << R"("tag": )" + << std::quoted(run.tag) << "," << R"("startTime": )" << std::quoted(timePointToString(run.startTime)) << ',' << R"("lastUpdate": )" << std::quoted(timePointToString(run.lastUpdate)) << ',' @@ -214,10 +257,10 @@ namespace daggy { void Server::handleGetDAGRun(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + if (!handleAuth(request)) return; if (!request.hasParam(":runID")) { - REQ_ERROR(Not_Found, "No runID provided in URL"); + REQ_RESPONSE(Not_Found, "No runID provided in URL"); } auto runID = request.param(":runID").as(); auto run = logger_.getDAGRun(runID); @@ -225,9 +268,9 @@ namespace daggy { bool first = true; std::stringstream ss; ss << "{" - << R"("runID": )" << runID << ',' << R"("name": )" - << std::quoted(run.name) << ',' << R"("tasks": )" - << tasksToJSON(run.tasks) << ','; + << R"("runID": )" << runID << ',' << R"("tag": )" + << std::quoted(run.dagSpec.tag) << ',' << R"("tasks": )" + << tasksToJSON(run.dagSpec.tasks) << ','; // task run states ss << R"("taskStates": { )"; @@ -295,21 +338,179 @@ namespace daggy { response.send(Pistache::Http::Code::Ok, ss.str()); } + void Server::handleGetDAGRunState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + DAGRunID runID = request.param(":runID").as(); + RunState state = RunState::QUEUED; + try { + state = logger_.getDAGRunState(runID); + std::stringstream ss; + ss << R"({ "runID": )" << runID << R"(, "state": )" + << std::quoted(state._to_string()) << '}'; + response.send(Pistache::Http::Code::Ok, ss.str()); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::queueDAG_(DAGRunID runID, const TaskDAG &dag, + const TaskParameters &taskParameters) + { + std::lock_guard lock(runnerGuard_); + /* + auto it = runners_.emplace( + std::piecewise_construct, std::forward_as_tuple(runID), + std::forward_as_tuple(runID, executor_, logger_, dag, + taskParameters)); + */ + auto it = runners_.emplace( + runID, std::make_shared(runID, executor_, logger_, dag, + taskParameters)); + + if (!it.second) + throw std::runtime_error("A DAGRun with the same ID is already running"); + auto runner = it.first->second; + runnerPool_.addTask([runner, runID, this]() { + runner->run(); + std::lock_guard lock(this->runnerGuard_); + this->runners_.extract(runID); + }); + } + + void Server::handleSetDAGRunState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + // TODO handle state transition + DAGRunID runID = request.param(":runID").as(); + RunState newState = RunState::_from_string( + request.param(":state").as().c_str()); + + std::shared_ptr runner{nullptr}; + { + std::lock_guard lock(runnerGuard_); + auto it = runners_.find(runID); + if (runners_.find(runID) != runners_.end()) { + runner = it->second; + } + } + + if (runner) { + switch (newState) { + case RunState::PAUSED: + case RunState::KILLED: { + runner->stop(true, true); + logger_.updateDAGRunState(runID, newState); + break; + } + default: { + 
REQ_RESPONSE(Method_Not_Allowed, + std::string{"Cannot transition to state "} + + newState._to_string()); + } + } + } + else { + switch (newState) { + case RunState::QUEUED: { + auto dagRun = logger_.getDAGRun(runID); + auto dag = + buildDAGFromTasks(dagRun.dagSpec.tasks, dagRun.taskStateChanges); + dag.resetRunning(); + queueDAG_(runID, dag, dagRun.dagSpec.taskConfig); + break; + } + default: + REQ_RESPONSE( + Method_Not_Allowed, + std::string{"DAG not running, cannot transition to state "} + + newState._to_string()); + } + } + REQ_RESPONSE(Ok, ""); + } + + void Server::handleGetTask(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + auto runID = request.param(":runID").as(); + auto taskName = request.param(":taskName").as(); + + try { + auto task = logger_.getTask(runID, taskName); + response.send(Pistache::Http::Code::Ok, taskToJSON(task)); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::handleGetTaskState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + auto runID = request.param(":runID").as(); + auto taskName = request.param(":taskName").as(); + + try { + auto state = logger_.getTaskState(runID, taskName); + std::stringstream ss; + ss << R"({ "runID": )" << runID << R"(, "taskName": )" + << std::quoted(taskName) << R"(, "state": )" + << std::quoted(state._to_string()) << '}'; + response.send(Pistache::Http::Code::Ok, ss.str()); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::handleSetTaskState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + // TODO implement handling of task state + auto runID = request.param(":runID").as(); + auto taskName = request.param(":taskName").as(); + RunState state = RunState::_from_string( + request.param(":state").as().c_str()); + + try { + logger_.updateTaskState(runID, taskName, state); + response.send(Pistache::Http::Code::Ok, ""); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + void Server::handleReady(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - response.send(Pistache::Http::Code::Ok, "Ya like DAGs?"); + response.send(Pistache::Http::Code::Ok, R"({ "msg": "Ya like DAGs?"})"); } /* - * handleAuth will check any auth methods and handle any responses in the case - * of failed auth. If it returns false, callers should cease handling the - * response + * handleAuth will check any auth methods and handle any responses in the + * case of failed auth. 
If it returns false, callers should cease handling + * the response */ - bool Server::handleAuth(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter &response) + bool Server::handleAuth(const Pistache::Rest::Request &request) { - (void)response; return true; } } // namespace daggy diff --git a/daggy/src/Utilities.cpp b/daggy/src/Utilities.cpp index 29ecfff..99258d0 100644 --- a/daggy/src/Utilities.cpp +++ b/daggy/src/Utilities.cpp @@ -92,8 +92,9 @@ namespace daggy { } // Add edges - for (const auto &[name, task] : tasks) { - dag.addEdgeIf(name, [&](const auto &v) { + for (const auto &[name, t] : tasks) { + const auto &task = t; + dag.addEdgeIf(name, [&task](const auto &v) { return task.children.count(v.data.definedName) > 0; }); } @@ -115,10 +116,10 @@ namespace daggy { switch (update.newState) { case RunState::RUNNING: case RunState::RETRY: + case RunState::PAUSED: case RunState::ERRORED: case RunState::KILLED: dag.setVertexState(update.taskName, RunState::RUNNING); - dag.setVertexState(update.taskName, RunState::COMPLETED); break; case RunState::COMPLETED: case RunState::QUEUED: @@ -129,120 +130,9 @@ namespace daggy { return dag; } - TaskDAG runDAG(DAGRunID runID, executors::task::TaskExecutor &executor, - loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, - const ConfigValues parameters) - { - logger.updateDAGRunState(runID, RunState::RUNNING); - - std::unordered_map> runningTasks; - std::unordered_map taskAttemptCounts; - - size_t running = 0; - size_t errored = 0; - while (!dag.allVisited()) { - // Check for any completed tasks - for (auto &[taskName, fut] : runningTasks) { - if (fut.valid()) { - auto attempt = fut.get(); - logger.logTaskAttempt(runID, taskName, attempt); - - // Not a reference, since adding tasks will invalidate references - auto vert = dag.getVertex(taskName); - auto &task = vert.data; - if (attempt.rc == 0) { - logger.updateTaskState(runID, taskName, RunState::COMPLETED); - if (task.isGenerator) { - // Parse the output and update the DAGs - try { - auto parsedTasks = tasksFromJSON(attempt.outputLog); - auto newTasks = - expandTaskSet(parsedTasks, executor, parameters); - updateDAGFromTasks(dag, newTasks); - - // Add in dependencies from current task to new tasks - for (const auto &[ntName, ntTask] : newTasks) { - logger.addTask(runID, ntName, ntTask); - task.children.insert(ntName); - } - - // Efficiently add new edges from generator task - // to children - std::unordered_set baseNames; - for (const auto &[k, v] : parsedTasks) { - baseNames.insert(v.definedName); - } - dag.addEdgeIf(taskName, [&](const auto &v) { - return baseNames.count(v.data.definedName) > 0; - }); - - logger.updateTask(runID, taskName, task); - } - catch (std::exception &e) { - logger.logTaskAttempt( - runID, taskName, - AttemptRecord{ - .executorLog = - std::string{"Failed to parse JSON output: "} + - e.what()}); - logger.updateTaskState(runID, taskName, RunState::ERRORED); - ++errored; - } - } - dag.completeVisit(taskName); - --running; - } - else { - // RC isn't 0 - if (taskAttemptCounts[taskName] <= task.maxRetries) { - logger.updateTaskState(runID, taskName, RunState::RETRY); - runningTasks[taskName] = executor.execute(taskName, task); - ++taskAttemptCounts[taskName]; - } - else { - logger.updateTaskState(runID, taskName, RunState::ERRORED); - ++errored; - } - } - } - } - - // Add all remaining tasks in a task queue to avoid dominating the thread - // pool - auto t = dag.visitNext(); - while (t.has_value()) { - // Schedule the task to run - auto &taskName = 
t.value().first; - auto &task = t.value().second; - taskAttemptCounts[taskName] = 1; - - logger.updateTaskState(runID, taskName, RunState::RUNNING); - runningTasks.emplace(taskName, executor.execute(taskName, task)); - ++running; - - auto nextTask = dag.visitNext(); - if (not nextTask.has_value()) - break; - t.emplace(nextTask.value()); - } - if (running > 0 and errored == running) { - logger.updateDAGRunState(runID, RunState::ERRORED); - break; - } - std::this_thread::sleep_for(250ms); - } - - if (dag.allVisited()) { - logger.updateDAGRunState(runID, RunState::COMPLETED); - } - - return dag; - } - std::ostream &operator<<(std::ostream &os, const TimePoint &tp) { - auto t_c = Clock::to_time_t(tp); - os << std::put_time(std::localtime(&t_c), "%Y-%m-%d %H:%M:%S %Z"); + os << tp.time_since_epoch().count() << std::endl; return os; } } // namespace daggy diff --git a/daggy/src/executors/task/ForkingTaskExecutor.cpp b/daggy/src/executors/task/ForkingTaskExecutor.cpp index a3df342..8f1a5a0 100644 --- a/daggy/src/executors/task/ForkingTaskExecutor.cpp +++ b/daggy/src/executors/task/ForkingTaskExecutor.cpp @@ -36,13 +36,45 @@ std::string slurp(int fd) return result; } -std::future ForkingTaskExecutor::execute( - const std::string &taskName, const Task &task) +ForkingTaskExecutor::ForkingTaskExecutor(size_t nThreads) + : tp_(nThreads) { - return tp_.addTask([this, task]() { return this->runTask(task); }); } -daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task) +ForkingTaskExecutor::~ForkingTaskExecutor() +{ + std::lock_guard lock(taskControlsGuard_); + taskControls_.clear(); +} + +bool ForkingTaskExecutor::stop(DAGRunID runID, const std::string &taskName) +{ + std::string key = std::to_string(runID) + "_" + taskName; + std::lock_guard lock(taskControlsGuard_); + auto it = taskControls_.find(key); + if (it == taskControls_.end()) + return true; + it->second = false; + return true; +} + +std::future ForkingTaskExecutor::execute( + DAGRunID runID, const std::string &taskName, const Task &task) +{ + std::string key = std::to_string(runID) + "_" + taskName; + std::lock_guard lock(taskControlsGuard_); + auto [it, ins] = taskControls_.emplace(key, true); + auto &running = it->second; + return tp_.addTask([this, task, &running, key]() { + auto ret = this->runTask(task, running); + std::lock_guard lock(this->taskControlsGuard_); + this->taskControls_.extract(key); + return ret; + }); +} + +daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task, + std::atomic &running) { AttemptRecord rec; @@ -81,23 +113,41 @@ daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task) exit(-1); } - std::atomic running = true; + std::atomic reading = true; std::thread stdoutReader([&]() { - while (running) + while (reading) rec.outputLog.append(slurp(stdoutPipe[0])); }); std::thread stderrReader([&]() { - while (running) + while (reading) rec.errorLog.append(slurp(stderrPipe[0])); }); - int rc = 0; - waitpid(child, &rc, 0); - running = false; + siginfo_t childInfo; + while (running) { + childInfo.si_pid = 0; + waitid(P_PID, child, &childInfo, WEXITED | WNOHANG); + if (childInfo.si_pid > 0) { + break; + } + std::this_thread::sleep_for(250ms); + } + + if (!running) { + rec.executorLog = "Killed"; + // Send the kills until pid is dead + while (kill(child, SIGKILL) != -1) { + // Need to collect the child to avoid a zombie process + waitid(P_PID, child, &childInfo, WEXITED | WNOHANG); + std::this_thread::sleep_for(50ms); + } + } + + reading = false; rec.stopTime = Clock::now(); - if 
(WIFEXITED(rc)) { - rec.rc = WEXITSTATUS(rc); + if (childInfo.si_pid > 0) { + rec.rc = childInfo.si_status; } else { rec.rc = -1; diff --git a/daggy/src/executors/task/NoopTaskExecutor.cpp b/daggy/src/executors/task/NoopTaskExecutor.cpp index 9239b8b..9e875f4 100644 --- a/daggy/src/executors/task/NoopTaskExecutor.cpp +++ b/daggy/src/executors/task/NoopTaskExecutor.cpp @@ -3,7 +3,7 @@ namespace daggy::executors::task { std::future NoopTaskExecutor::execute( - const std::string &taskName, const Task &task) + DAGRunID runID, const std::string &taskName, const Task &task) { std::promise promise; auto ts = Clock::now(); @@ -42,4 +42,10 @@ namespace daggy::executors::task { return newValues; } + + bool NoopTaskExecutor::stop(DAGRunID runID, const std::string &taskName) + { + return true; + } + } // namespace daggy::executors::task diff --git a/daggy/src/executors/task/SlurmTaskExecutor.cpp b/daggy/src/executors/task/SlurmTaskExecutor.cpp index e53e121..fb07635 100644 --- a/daggy/src/executors/task/SlurmTaskExecutor.cpp +++ b/daggy/src/executors/task/SlurmTaskExecutor.cpp @@ -1,4 +1,5 @@ #include +#include #include #ifdef DAGGY_ENABLE_SLURM #include @@ -6,6 +7,7 @@ #include #include +#include #include #include #include @@ -115,7 +117,7 @@ namespace daggy::executors::task { } std::future SlurmTaskExecutor::execute( - const std::string &taskName, const Task &task) + DAGRunID runID, const std::string &taskName, const Task &task) { std::stringstream executorLog; @@ -191,13 +193,39 @@ namespace daggy::executors::task { slurm_free_submit_response_response_msg(resp_msg); std::lock_guard lock(promiseGuard_); - Job newJob{.prom{}, .stdoutFile = stdoutFile, .stderrFile = stderrFile}; + Job newJob{.prom{}, + .stdoutFile = stdoutFile, + .stderrFile = stderrFile, + .runID = runID, + .taskName = taskName}; auto fut = newJob.prom.get_future(); runningJobs_.emplace(jobID, std::move(newJob)); return fut; } + bool SlurmTaskExecutor::stop(DAGRunID runID, const std::string &taskName) + { + // Hopefully this isn't a common thing, so just scrap the current jobs and + // kill them + size_t jobID = 0; + { + std::lock_guard lock(promiseGuard_); + for (const auto &[k, v] : runningJobs_) { + if (v.runID == runID and v.taskName == taskName) { + jobID = k; + break; + } + } + if (jobID == 0) + return true; + } + + // Send the kill message to slurm + slurm_kill_job(jobID, SIGKILL, KILL_HURRY); + return true; + } + void SlurmTaskExecutor::monitor() { std::unordered_set resolvedJobs; @@ -225,32 +253,40 @@ namespace daggy::executors::task { // Job has finished case JOB_COMPLETE: /* completed execution successfully */ case JOB_FAILED: /* completed execution unsuccessfully */ + record.rc = jobInfo.exit_code; record.executorLog = "Script errored.\n"; break; - case JOB_CANCELLED: /* cancelled by user */ + case JOB_CANCELLED: /* cancelled by user */ + record.rc = 9; // matches SIGKILL record.executorLog = "Job cancelled by user.\n"; break; case JOB_TIMEOUT: /* terminated on reaching time limit */ + record.rc = jobInfo.exit_code; record.executorLog = "Job exceeded time limit.\n"; break; case JOB_NODE_FAIL: /* terminated on node failure */ + record.rc = jobInfo.exit_code; record.executorLog = "Node failed during execution\n"; break; case JOB_PREEMPTED: /* terminated due to preemption */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to pre-emption.\n"; break; case JOB_BOOT_FAIL: /* terminated due to node boot failure */ + record.rc = jobInfo.exit_code; record.executorLog = - "Job failed to run due to failure 
of compute node to boot.\n"; + "Job failed to run due to failure of compute node to " + "boot.\n"; break; case JOB_DEADLINE: /* terminated on deadline */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to deadline.\n"; break; case JOB_OOM: /* experienced out of memory error */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to out-of-memory.\n"; break; } - record.rc = jobInfo.exit_code; slurm_free_job_info_msg(jobStatus); readAndClean(job.stdoutFile, record.outputLog); @@ -265,7 +301,7 @@ namespace daggy::executors::task { } } - std::this_thread::sleep_for(std::chrono::milliseconds(250)); + std::this_thread::sleep_for(std::chrono::seconds(1)); } } } // namespace daggy::executors::task diff --git a/daggy/src/loggers/dag_run/OStreamLogger.cpp b/daggy/src/loggers/dag_run/OStreamLogger.cpp index f21a09a..1fc9bb3 100644 --- a/daggy/src/loggers/dag_run/OStreamLogger.cpp +++ b/daggy/src/loggers/dag_run/OStreamLogger.cpp @@ -11,20 +11,26 @@ namespace daggy { namespace loggers { namespace dag_run { { } + OStreamLogger::~OStreamLogger() + { + std::lock_guard lock(guard_); + dagRuns_.clear(); + } + // Execution - DAGRunID OStreamLogger::startDAGRun(std::string name, const TaskSet &tasks) + DAGRunID OStreamLogger::startDAGRun(const DAGSpec &dagSpec) { std::lock_guard lock(guard_); size_t runID = dagRuns_.size(); - dagRuns_.push_back({.name = name, .tasks = tasks}); - for (const auto &[name, _] : tasks) { + dagRuns_.emplace_back(DAGRunRecord{.dagSpec = dagSpec}); + for (const auto &[name, _] : dagSpec.tasks) { _updateTaskState(runID, name, RunState::QUEUED); } _updateDAGRunState(runID, RunState::QUEUED); - os_ << "Starting new DAGRun named " << name << " with ID " << runID - << " and " << tasks.size() << " tasks" << std::endl; - for (const auto &[name, task] : tasks) { + os_ << "Starting new DAGRun tagged " << dagSpec.tag << " with ID " << runID + << " and " << dagSpec.tasks.size() << " tasks" << std::endl; + for (const auto &[name, task] : dagSpec.tasks) { os_ << "TASK (" << name << "): " << configToJSON(task.job); os_ << std::endl; } @@ -35,8 +41,8 @@ namespace daggy { namespace loggers { namespace dag_run { const Task &task) { std::lock_guard lock(guard_); - auto &dagRun = dagRuns_[dagRunID]; - dagRun.tasks[taskName] = task; + auto &dagRun = dagRuns_[dagRunID]; + dagRun.dagSpec.tasks[taskName] = task; _updateTaskState(dagRunID, taskName, RunState::QUEUED); } @@ -44,8 +50,8 @@ namespace daggy { namespace loggers { namespace dag_run { const Task &task) { std::lock_guard lock(guard_); - auto &dagRun = dagRuns_[dagRunID]; - dagRun.tasks[taskName] = task; + auto &dagRun = dagRuns_[dagRunID]; + dagRun.dagSpec.tasks[taskName] = task; } void OStreamLogger::updateDAGRunState(DAGRunID dagRunID, RunState state) @@ -101,15 +107,29 @@ namespace daggy { namespace loggers { namespace dag_run { } // Querying - std::vector OStreamLogger::getDAGs(uint32_t stateMask) + DAGSpec OStreamLogger::getDAGSpec(DAGRunID dagRunID) + { + std::lock_guard lock(guard_); + return dagRuns_.at(dagRunID).dagSpec; + }; + + std::vector OStreamLogger::queryDAGRuns(const std::string &tag, + bool all) { std::vector summaries; std::lock_guard lock(guard_); size_t i = 0; for (const auto &run : dagRuns_) { + if ((!all) && + (run.dagStateChanges.back().newState == +RunState::COMPLETED)) { + continue; + } + if (!tag.empty() and tag != run.dagSpec.tag) + continue; + DAGRunSummary summary{ .runID = i, - .name = run.name, + .tag = run.dagSpec.tag, .runState = run.dagStateChanges.back().newState, 
.startTime = run.dagStateChanges.front().time,
           .lastUpdate = std::max(run.taskStateChanges.back().time,
@@ -126,10 +146,26 @@ namespace daggy { namespace loggers { namespace dag_run {
 
   DAGRunRecord OStreamLogger::getDAGRun(DAGRunID dagRunID)
   {
-    if (dagRunID >= dagRuns_.size()) {
-      throw std::runtime_error("No such DAGRun ID");
-    }
     std::lock_guard lock(guard_);
-    return dagRuns_[dagRunID];
+    return dagRuns_.at(dagRunID);
   }
+
+  RunState OStreamLogger::getDAGRunState(DAGRunID dagRunID)
+  {
+    std::lock_guard lock(guard_);
+    return dagRuns_.at(dagRunID).dagStateChanges.back().newState;
+  }
+
+  Task &OStreamLogger::getTask(DAGRunID dagRunID, const std::string &taskName)
+  {
+    std::lock_guard lock(guard_);
+    return dagRuns_.at(dagRunID).dagSpec.tasks.at(taskName);
+  }
+
+  RunState &OStreamLogger::getTaskState(DAGRunID dagRunID,
+                                        const std::string &taskName)
+  {
+    std::lock_guard lock(guard_);
+    return dagRuns_.at(dagRunID).taskRunStates.at(taskName);
+  }
+
 }}} // namespace daggy::loggers::dag_run
diff --git a/endpoints.otl b/endpoints.otl
new file mode 100644
index 0000000..8308751
--- /dev/null
+++ b/endpoints.otl
@@ -0,0 +1,30 @@
+ready [handleReady]
+v1
+	dagruns
+		:GET - Summary of running DAGs [handleQueryDAGs]
+
+		{tag}
+
+	dagrun
+		:POST - submit a dag run [handleRunDAG]
+
+		validate
+			:POST - Ensure a submitted DAG run is valid [handleValidateDAG]
+
+		{runID}
+			:GET - Full DAG run information [handleGetDAGRun]
+
+			summary
+				:GET - RunState of DAG, and task RunState counts
+
+			state
+				:PATCH - Change the state of a DAG (paused, killed)
+				:GET - Summary of dag run (structure + runstates)
+
+			tasks
+				{taskName}
+					:GET - Full task definition and output
+
+					state
+						:GET -- Get the task state
+						:PATCH -- Set the task state
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 52d96fb..83618e0 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -2,6 +2,7 @@ project(tests)
 add_executable(tests main.cpp
         # unit tests
         unit_dag.cpp
+        unit_dagrunner.cpp
         unit_dagrun_loggers.cpp
         unit_executor_forkingexecutor.cpp
         unit_executor_slurmexecutor.cpp
@@ -14,4 +15,4 @@ add_executable(tests main.cpp
         # Performance checks
         perf_dag.cpp
         )
-target_link_libraries(tests libdaggy stdc++fs Catch2::Catch2)
+target_link_libraries(tests libdaggy stdc++fs Catch2::Catch2 curl)
diff --git a/tests/unit_dagrun_loggers.cpp b/tests/unit_dagrun_loggers.cpp
index c8bf95b..ed480e9 100644
--- a/tests/unit_dagrun_loggers.cpp
+++ b/tests/unit_dagrun_loggers.cpp
@@ -2,11 +2,10 @@
 #include
 #include
 #include
+#include
 
 #include "daggy/loggers/dag_run/OStreamLogger.hpp"
 
-namespace fs = std::filesystem;
-
 using namespace daggy;
 using namespace daggy::loggers::dag_run;
 
 const TaskSet SAMPLE_TASKS{
     {"work_c", Task{.job{{"command", std::vector<std::string>{"/bin/echo", "c"}}}}}};
 
-inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &name,
+inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &tag,
                                const TaskSet &tasks)
 {
-  auto runID = logger.startDAGRun(name, tasks);
-  auto dagRun = logger.getDAGRun(runID);
+  auto runID = logger.startDAGRun(DAGSpec{.tag = tag, .tasks = tasks});
 
-  REQUIRE(dagRun.tasks == tasks);
+  // Verify run shows up in the list
+  {
+    auto runs = logger.queryDAGRuns();
+    REQUIRE(!runs.empty());
+    auto it = std::find_if(runs.begin(), runs.end(),
+                           [runID](const auto &r) { return r.runID == runID; });
+    REQUIRE(it != runs.end());
+    REQUIRE(it->tag == tag);
+    REQUIRE(it->runState == +RunState::QUEUED);
+  }
 
-  REQUIRE(dagRun.taskRunStates.size() == tasks.size());
-  auto
nonQueuedTask = - std::find_if(dagRun.taskRunStates.begin(), dagRun.taskRunStates.end(), - [](const auto &a) { return a.second != +RunState::QUEUED; }); - REQUIRE(nonQueuedTask == dagRun.taskRunStates.end()); + // Verify states + { + REQUIRE(logger.getDAGRunState(runID) == +RunState::QUEUED); + for (const auto &[k, _] : tasks) { + REQUIRE(logger.getTaskState(runID, k) == +RunState::QUEUED); + } + } + + // Verify integrity of run + { + auto dagRun = logger.getDAGRun(runID); + + REQUIRE(dagRun.dagSpec.tag == tag); + REQUIRE(dagRun.dagSpec.tasks == tasks); + + REQUIRE(dagRun.taskRunStates.size() == tasks.size()); + auto nonQueuedTask = std::find_if( + dagRun.taskRunStates.begin(), dagRun.taskRunStates.end(), + [](const auto &a) { return a.second != +RunState::QUEUED; }); + REQUIRE(nonQueuedTask == dagRun.taskRunStates.end()); + REQUIRE(dagRun.dagStateChanges.size() == 1); + REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::QUEUED); + } + + // Update DAG state and ensure that it's updated; + { + logger.updateDAGRunState(runID, RunState::RUNNING); + auto dagRun = logger.getDAGRun(runID); + REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::RUNNING); + } + + // Update a task state + { + for (const auto &[k, v] : tasks) + logger.updateTaskState(runID, k, RunState::RUNNING); + auto dagRun = logger.getDAGRun(runID); + for (const auto &[k, v] : tasks) { + REQUIRE(dagRun.taskRunStates.at(k) == +RunState::RUNNING); + } + } - REQUIRE(dagRun.dagStateChanges.size() == 1); - REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::QUEUED); return runID; } TEST_CASE("ostream_logger", "[ostream_logger]") { - // cleanup(); std::stringstream ss; daggy::loggers::dag_run::OStreamLogger logger(ss); @@ -49,6 +88,4 @@ TEST_CASE("ostream_logger", "[ostream_logger]") { testDAGRunInit(logger, "init_test", SAMPLE_TASKS); } - - // cleanup(); } diff --git a/tests/unit_dagrunner.cpp b/tests/unit_dagrunner.cpp new file mode 100644 index 0000000..b5d7b8a --- /dev/null +++ b/tests/unit_dagrunner.cpp @@ -0,0 +1,256 @@ +#include +#include +#include + +#include "daggy/DAGRunner.hpp" +#include "daggy/executors/task/ForkingTaskExecutor.hpp" +#include "daggy/executors/task/NoopTaskExecutor.hpp" +#include "daggy/loggers/dag_run/OStreamLogger.hpp" + +namespace fs = std::filesystem; + +TEST_CASE("dagrunner", "[dagrunner_order_preservation]") +{ + daggy::executors::task::NoopTaskExecutor ex; + std::stringstream ss; + daggy::loggers::dag_run::OStreamLogger logger(ss); + + daggy::TimePoint globalStartTime = daggy::Clock::now(); + + daggy::DAGSpec dagSpec; + + std::string testParams{ + R"({"DATE": ["2021-05-06", "2021-05-07", "2021-05-08", "2021-05-09" ]})"}; + dagSpec.taskConfig.variables = daggy::configFromJSON(testParams); + + std::string taskJSON = R"({ + "A": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "B","D" ]}, + "B": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "C","D","E" ]}, + "C": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "D"]}, + "D": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "E"]}, + "E": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}} + })"; + + dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex, + dagSpec.taskConfig.variables); + + REQUIRE(dagSpec.tasks.size() == 20); + + auto dag = daggy::buildDAGFromTasks(dagSpec.tasks); + auto runID = logger.startDAGRun(dagSpec); + + daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig); + + auto endDAG = runner.run(); + + 
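+  // run() blocks until the DAG finishes (or errors out) and returns the
+  // final, possibly generator-expanded, DAG.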
+  REQUIRE(endDAG.allVisited());
+
+  // Ensure the run order
+  auto rec = logger.getDAGRun(runID);
+
+  daggy::TimePoint globalStopTime = daggy::Clock::now();
+  std::array<daggy::TimePoint, 5> minTimes;
+  minTimes.fill(globalStartTime);
+  std::array<daggy::TimePoint, 5> maxTimes;
+  maxTimes.fill(globalStopTime);
+
+  for (const auto &[k, v] : rec.taskAttempts) {
+    size_t idx = k[0] - 65;
+    auto &startTime = minTimes[idx];
+    auto &stopTime = maxTimes[idx];
+    startTime = std::max(startTime, v.front().startTime);
+    stopTime = std::min(stopTime, v.back().stopTime);
+  }
+
+  for (size_t i = 0; i < 5; ++i) {
+    for (size_t j = i + 1; j < 5; ++j) {
+      REQUIRE(maxTimes[i] < minTimes[j]);
+    }
+  }
+}
+
+TEST_CASE("DAGRunner simple execution", "[dagrunner_simple]")
+{
+  daggy::executors::task::ForkingTaskExecutor ex(10);
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+
+  daggy::DAGSpec dagSpec;
+
+  SECTION("Simple execution")
+  {
+    std::string prefix = (fs::current_path() / "asdlk").string();
+    std::unordered_map<std::string, std::string> files{
+        {"A", prefix + "_A"}, {"B", prefix + "_B"}, {"C", prefix + "_C"}};
+    std::string taskJSON =
+        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + files.at("A") +
+        R"("]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
+        files.at("B") +
+        R"("]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
+        files.at("C") + R"("]}}})";
+    dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
+    auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+    auto runID = logger.startDAGRun(dagSpec);
+    daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+    auto endDAG = runner.run();
+    REQUIRE(endDAG.allVisited());
+
+    for (const auto &[_, file] : files) {
+      REQUIRE(fs::exists(file));
+      fs::remove(file);
+    }
+
+    // Get the DAG Run Attempts
+    auto record = logger.getDAGRun(runID);
+    for (const auto &[_, attempts] : record.taskAttempts) {
+      REQUIRE(attempts.size() == 1);
+      REQUIRE(attempts.front().rc == 0);
+    }
+  }
+}
+
+TEST_CASE("DAG Runner Restart old DAG", "[dagrunner_restart]")
+{
+  daggy::executors::task::ForkingTaskExecutor ex(10);
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+  daggy::DAGSpec dagSpec;
+
+  SECTION("Recovery from Error")
+  {
+    auto cleanup = []() {
+      // Cleanup
+      std::vector<std::string> paths{"rec_error_A", "noexist"};
+      for (const auto &pth : paths) {
+        if (fs::exists(pth))
+          fs::remove_all(pth);
+      }
+    };
+
+    cleanup();
+
+    std::string goodPrefix = "rec_error_";
+    std::string badPrefix = "noexist/rec_error_";
+    std::string taskJSON =
+        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + goodPrefix +
+        R"(A"]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
+        badPrefix +
+        R"(B"]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
+        badPrefix + R"(C"]}}})";
+    dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
+    auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+
+    auto runID = logger.startDAGRun(dagSpec);
+
+    daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+    auto tryDAG = runner.run();
+
+    REQUIRE(!tryDAG.allVisited());
+
+    // Create the missing dir, then continue to run the DAG
+    fs::create_directory("noexist");
+    runner.resetRunning();
+    auto endDAG = runner.run();
+
+    REQUIRE(endDAG.allVisited());
+
+    // Get the DAG Run Attempts
+    auto record = logger.getDAGRun(runID);
+    REQUIRE(record.taskAttempts["A_0"].size() == 1);  // A ran fine
+    REQUIRE(record.taskAttempts["B_0"].size() == 2);  // B errored and had to be retried
+    REQUIRE(record.taskAttempts["C_0"].size() == 1);  // C wasn't run because B errored
+
+    cleanup();
+  }
+}
+
+TEST_CASE("DAG Runner Generator Tasks", "[dagrunner_generator]")
+{
+  daggy::executors::task::ForkingTaskExecutor ex(10);
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+  daggy::DAGSpec dagSpec;
+
+  SECTION("Generator tasks")
+  {
+    std::string testParams{R"({"DATE": ["2021-05-06", "2021-05-07" ]})"};
+    dagSpec.taskConfig.variables = daggy::configFromJSON(testParams);
+
+    std::string generatorOutput =
+        R"({"B": {"job": {"command": ["/usr/bin/echo", "-e", "{{DATE}}"]}, "children": ["C"]}})";
+    fs::path ofn = fs::current_path() / "generator_test_output.json";
+    std::ofstream ofh{ofn};
+    ofh << generatorOutput << std::endl;
+    ofh.close();
+
+    daggy::TimePoint globalStartTime = daggy::Clock::now();
+    std::stringstream jsonTasks;
+    jsonTasks
+        << R"({ "A": { "job": {"command": [ "/usr/bin/cat", )"
+        << std::quoted(ofn.string())
+        << R"(]}, "children": ["C"], "isGenerator": true},)"
+        << R"("C": { "job": {"command": [ "/usr/bin/echo", "hello!"]} } })";
+
+    dagSpec.tasks = daggy::tasksFromJSON(jsonTasks.str());
+    REQUIRE(dagSpec.tasks.size() == 2);
+    REQUIRE(dagSpec.tasks["A"].children ==
+            std::unordered_set<std::string>{"C"});
+    dagSpec.tasks =
+        daggy::expandTaskSet(dagSpec.tasks, ex, dagSpec.taskConfig.variables);
+    REQUIRE(dagSpec.tasks.size() == 2);
+    REQUIRE(dagSpec.tasks["A_0"].children ==
+            std::unordered_set<std::string>{"C"});
+    auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+    REQUIRE(dag.size() == 2);
+
+    auto runID = logger.startDAGRun(dagSpec);
+    daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+    auto finalDAG = runner.run();
+
+    REQUIRE(finalDAG.allVisited());
+    REQUIRE(finalDAG.size() == 4);
+
+    // Check the logger
+    auto record = logger.getDAGRun(runID);
+
+    REQUIRE(record.dagSpec.tasks.size() == 4);
+    REQUIRE(record.taskRunStates.size() == 4);
+    for (const auto &[taskName, attempts] : record.taskAttempts) {
+      REQUIRE(attempts.size() == 1);
+      REQUIRE(attempts.back().rc == 0);
+    }
+
+    // Ensure that children were updated properly
+    REQUIRE(record.dagSpec.tasks["A_0"].children ==
+            std::unordered_set<std::string>{"B_0", "B_1", "C"});
+    REQUIRE(record.dagSpec.tasks["B_0"].children ==
+            std::unordered_set<std::string>{"C"});
+    REQUIRE(record.dagSpec.tasks["B_1"].children ==
+            std::unordered_set<std::string>{"C"});
+    REQUIRE(record.dagSpec.tasks["C_0"].children.empty());
+
+    // Ensure they were run in the right order
+    // All A's get run before B's, which run before C's
+    daggy::TimePoint globalStopTime = daggy::Clock::now();
+    std::array<daggy::TimePoint, 3> minTimes;
+    minTimes.fill(globalStartTime);
+    std::array<daggy::TimePoint, 3> maxTimes;
+    maxTimes.fill(globalStopTime);
+
+    for (const auto &[k, v] : record.taskAttempts) {
+      size_t idx = k[0] - 65;
+      auto &startTime = minTimes[idx];
+      auto &stopTime = maxTimes[idx];
+      startTime = std::max(startTime, v.front().startTime);
+      stopTime = std::min(stopTime, v.back().stopTime);
+    }
+
+    for (size_t i = 0; i < 3; ++i) {
+      for (size_t j = i + 1; j < 3; ++j) {
+        REQUIRE(maxTimes[i] < minTimes[j]);
+      }
+    }
+  }
+}
diff --git a/tests/unit_executor_forkingexecutor.cpp b/tests/unit_executor_forkingexecutor.cpp
index 42e393d..276ca47 100644
--- a/tests/unit_executor_forkingexecutor.cpp
+++ b/tests/unit_executor_forkingexecutor.cpp
@@ -1,6 +1,7 @@
 #include
 #include
 #include
+#include
 
 #include "daggy/Serialization.hpp"
 #include "daggy/Utilities.hpp"
@@ -18,7 +19,7 @@ TEST_CASE("forking_executor", "[forking_executor]")
 
   REQUIRE(ex.validateTaskParameters(task.job));
 
-  auto recFuture = ex.execute("command", task);
+  auto recFuture = ex.execute(0, "command", task);
   auto rec = recFuture.get();
 
   REQUIRE(rec.rc == 0);
@@ -32,7 +33,7 @@ TEST_CASE("forking_executor", "[forking_executor]")
         .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{
                              "/usr/bin/expr", "1", "+", "+"}}}};
 
-    auto recFuture = ex.execute("command", task);
+    auto recFuture = ex.execute(0, "command", task);
    auto rec = recFuture.get();
 
    REQUIRE(rec.rc == 2);
@@ -40,6 +41,28 @@ TEST_CASE("forking_executor", "[forking_executor]")
    REQUIRE(rec.outputLog.empty());
  }
 
+  SECTION("Killing a long task")
+  {
+    daggy::Task task{
+        .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{
+                             "/usr/bin/sleep", "30"}}}};
+
+    auto start = daggy::Clock::now();
+    auto recFuture = ex.execute(0, "command", task);
+    std::this_thread::sleep_for(1s);
+    ex.stop(0, "command");
+    auto rec = recFuture.get();
+    auto stop = daggy::Clock::now();
+
+    REQUIRE(rec.rc == 9);
+    REQUIRE(rec.errorLog.empty());
+    REQUIRE(rec.outputLog.empty());
+    REQUIRE(rec.executorLog == "Killed");
+    REQUIRE(
+        std::chrono::duration_cast<std::chrono::seconds>(stop - start).count() <
+        20);
+  }
+
  SECTION("Large Output")
  {
    const std::vector<std::string> BIG_FILES{"/usr/share/dict/linux.words",
@@ -54,7 +77,7 @@ TEST_CASE("forking_executor", "[forking_executor]")
        .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{
                             "/usr/bin/cat", bigFile}}}};
 
-    auto recFuture = ex.execute("command", task);
+    auto recFuture = ex.execute(0, "command", task);
    auto rec = recFuture.get();
 
    REQUIRE(rec.rc == 0);
diff --git a/tests/unit_executor_slurmexecutor.cpp b/tests/unit_executor_slurmexecutor.cpp
index 2b63a72..3dc8778 100644
--- a/tests/unit_executor_slurmexecutor.cpp
+++ b/tests/unit_executor_slurmexecutor.cpp
@@ -34,7 +34,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]")
 
  REQUIRE(ex.validateTaskParameters(task.job));
 
-  auto recFuture = ex.execute("command", task);
+  auto recFuture = ex.execute(0, "command", task);
  auto rec = recFuture.get();
 
  REQUIRE(rec.rc == 0);
@@ -49,7 +49,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]")
                             "/usr/bin/expr", "1", "+", "+"}}}};
  task.job.merge(defaultJobValues);
 
-  auto recFuture = ex.execute("command", task);
+  auto recFuture = ex.execute(0, "command", task);
  auto rec = recFuture.get();
 
  REQUIRE(rec.rc != 0);
@@ -57,6 +57,23 @@ TEST_CASE("slurm_execution", "[slurm_executor]")
  REQUIRE(rec.outputLog.empty());
 }
 
+SECTION("Killing a long task")
+{
+  daggy::Task task{
+      .job{{"command", daggy::executors::task::SlurmTaskExecutor::Command{
+                           "/usr/bin/sleep", "30"}}}};
+  task.job.merge(defaultJobValues);
+
+  auto recFuture = ex.execute(0, "command", task);
+  ex.stop(0, "command");
+  auto rec = recFuture.get();
+
+  REQUIRE(rec.rc == 9);
+  REQUIRE(rec.errorLog.empty());
+  REQUIRE(rec.outputLog.empty());
+  REQUIRE(rec.executorLog == "Job cancelled by user.\n");
+}
+
 SECTION("Large Output")
 {
  const std::vector<std::string> BIG_FILES{"/usr/share/dict/linux.words",
@@ -72,7 +89,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]")
                             "/usr/bin/cat", bigFile}}}};
  task.job.merge(defaultJobValues);
 
-  auto recFuture = ex.execute("command", task);
+  auto recFuture = ex.execute(0, "command", task);
  auto rec = recFuture.get();
 
  REQUIRE(rec.rc == 0);
diff --git a/tests/unit_server.cpp b/tests/unit_server.cpp
index bb7f444..2d1bf01 100644
--- a/tests/unit_server.cpp
+++ b/tests/unit_server.cpp
@@ -1,51 +1,131 @@
+#include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 
 namespace rj = rapidjson;
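 
+// NOTE: the Pistache-based async client below is replaced with a small
+// synchronous libcurl helper: the same REQUEST(url, payload) shape, plus an
+// explicit method argument so the tests can issue PATCH as well as GET/POST.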
-Pistache::Http::Response REQUEST(const std::string &url,
-                                 const std::string &payload = "")
-{
-  Pistache::Http::Experimental::Client client;
-  client.init();
-  Pistache::Http::Response response;
-  auto reqSpec = (payload.empty() ? client.get(url) : client.post(url));
-  reqSpec.timeout(std::chrono::seconds(2));
-  if (!payload.empty()) {
-    reqSpec.body(payload);
-  }
-  auto request = reqSpec.send();
-  bool ok = false, error = false;
-  std::string msg;
-  request.then(
-      [&](Pistache::Http::Response rsp) {
-        ok = true;
-        response = std::move(rsp);
-      },
-      [&](std::exception_ptr ptr) {
-        error = true;
-        try {
-          std::rethrow_exception(std::move(ptr));
-        }
-        catch (std::exception &e) {
-          msg = e.what();
-        }
-      });
+using namespace daggy;
 
-  Pistache::Async::Barrier barrier(request);
-  barrier.wait_for(std::chrono::seconds(2));
-  client.shutdown();
-  if (error) {
-    throw std::runtime_error(msg);
+#ifdef DEBUG_HTTP
+static int my_trace(CURL *handle, curl_infotype type, char *data, size_t size,
+                    void *userp)
+{
+  const char *text;
+  (void)handle; /* prevent compiler warning */
+  (void)userp;
+
+  switch (type) {
+  case CURLINFO_TEXT:
+    fprintf(stderr, "== Info: %s", data);
+  default: /* in case a new one is introduced to shock us */
+    return 0;
+
+  case CURLINFO_HEADER_OUT:
+    text = "=> Send header";
+    break;
+  case CURLINFO_DATA_OUT:
+    text = "=> Send data";
+    break;
+  case CURLINFO_SSL_DATA_OUT:
+    text = "=> Send SSL data";
+    break;
+  case CURLINFO_HEADER_IN:
+    text = "<= Recv header";
+    break;
+  case CURLINFO_DATA_IN:
+    text = "<= Recv data";
+    break;
+  case CURLINFO_SSL_DATA_IN:
+    text = "<= Recv SSL data";
+    break;
+  }
+
+  std::cerr << "\n================== " << text
+            << " ==================" << std::endl
+            << data << std::endl;
+  return 0;
+}
+#endif
+
+enum HTTPCode : long
+{
+  Ok = 200,
+  Not_Found = 404
+};
+
+struct HTTPResponse
+{
+  HTTPCode code;
+  std::string body;
+};
+
+// libcurl write callback: append the received bytes to a stringstream.
+size_t curlWriter(char *in, size_t size, size_t nmemb, std::stringstream *out)
+{
+  size_t r = size * nmemb;
+  out->write(in, r);
+  return r;
+}
+
+HTTPResponse REQUEST(const std::string &url, const std::string &payload = "",
+                     const std::string &method = "GET")
+{
+  HTTPResponse response;
+
+  CURL *curl;
+  CURLcode res;
+  struct curl_slist *headers = NULL;
+
+  curl_global_init(CURL_GLOBAL_ALL);
+
+  curl = curl_easy_init();
+  if (curl) {
+    std::stringstream buffer;
+
+#ifdef DEBUG_HTTP
+    curl_easy_setopt(curl, CURLOPT_DEBUGFUNCTION, my_trace);
+    curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
+#endif
+
+    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curlWriter);
+    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer);
+
+    if (!payload.empty()) {
+      curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, payload.size());
+      curl_easy_setopt(curl, CURLOPT_POSTFIELDS, payload.c_str());
+      headers = curl_slist_append(headers, "Content-Type: application/json");
+    }
+    curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, method.c_str());
+    headers = curl_slist_append(headers, "Expect:");
+    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+
+    res = curl_easy_perform(curl);
+
+    if (res != CURLE_OK) {
+      curl_slist_free_all(headers);
+      curl_easy_cleanup(curl);
+      throw std::runtime_error(std::string{"CURL Failed: "} +
+                               curl_easy_strerror(res));
+    }
+
+    // Read the response code before the handle is cleaned up.
+    long code = 0;
+    curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &code);
+    response.code = static_cast<HTTPCode>(code);
+    response.body = buffer.str();
+
+    curl_slist_free_all(headers);
+    curl_easy_cleanup(curl);
+  }
+
+  curl_global_cleanup();
+  return response;
 }
@@ -68,19 +148,19 @@ TEST_CASE("rest_endpoint", "[server_basic]")
   SECTION("Ready Endpoint")
   {
     auto response = REQUEST(baseURL + "/ready");
-    REQUIRE(response.code() == Pistache::Http::Code::Ok);
+    REQUIRE(response.code == HTTPCode::Ok);
   }
 
   SECTION("Querying a non-existent dagrunid should fail ")
   {
     auto response = REQUEST(baseURL + "/v1/dagrun/100");
-    REQUIRE(response.code() != Pistache::Http::Code::Ok);
+    REQUIRE(response.code != HTTPCode::Ok);
   }
 
   SECTION("Simple DAGRun Submission")
   {
     std::string dagRun = R"({
-      "name": "unit_server",
+      "tag": "unit_server",
       "parameters": { "FILE": [ "A", "B" ] },
       "tasks": {
         "touch": { "job": { "command": [ "/usr/bin/touch", "dagrun_{{FILE}}" ]} },
@@ -90,14 +170,16 @@ TEST_CASE("rest_endpoint", "[server_basic]")
       }
     })";
 
+    auto dagSpec = daggy::dagFromJSON(dagRun);
+
     // Submit, and get the runID
     daggy::DAGRunID runID = 0;
     {
-      auto response = REQUEST(baseURL + "/v1/dagrun/", dagRun);
-      REQUIRE(response.code() == Pistache::Http::Code::Ok);
+      auto response = REQUEST(baseURL + "/v1/dagrun/", dagRun, "POST");
+      REQUIRE(response.code == HTTPCode::Ok);
 
       rj::Document doc;
-      daggy::checkRJParse(doc.Parse(response.body().c_str()));
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
 
       REQUIRE(doc.IsObject());
       REQUIRE(doc.HasMember("runID"));
@@ -106,11 +188,11 @@ TEST_CASE("rest_endpoint", "[server_basic]")
 
     // Ensure our runID shows up in the list of running DAGs
     {
-      auto response = REQUEST(baseURL + "/v1/dagrun/");
-      REQUIRE(response.code() == Pistache::Http::Code::Ok);
+      auto response = REQUEST(baseURL + "/v1/dagruns?all=1");
+      REQUIRE(response.code == HTTPCode::Ok);
 
       rj::Document doc;
-      daggy::checkRJParse(doc.Parse(response.body().c_str()));
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
 
       REQUIRE(doc.IsArray());
       REQUIRE(doc.Size() >= 1);
@@ -120,10 +202,10 @@ TEST_CASE("rest_endpoint", "[server_basic]")
       for (size_t i = 0; i < runs.Size(); ++i) {
         const auto &run = runs[i];
         REQUIRE(run.IsObject());
-        REQUIRE(run.HasMember("name"));
+        REQUIRE(run.HasMember("tag"));
         REQUIRE(run.HasMember("runID"));
 
-        std::string runName = run["name"].GetString();
+        std::string runName = run["tag"].GetString();
         if (runName == "unit_server") {
           REQUIRE(run["runID"].GetUint64() == runID);
           found = true;
@@ -133,13 +215,28 @@ TEST_CASE("rest_endpoint", "[server_basic]")
       REQUIRE(found);
     }
 
+    // Ensure we can get one of our tasks
+    {
+      auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID) +
+                              "/task/cat_0");
+      REQUIRE(response.code == HTTPCode::Ok);
+
+      rj::Document doc;
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
+
+      REQUIRE_NOTHROW(daggy::taskFromJSON("cat", doc));
+      auto task = daggy::taskFromJSON("cat", doc);
+
+      REQUIRE(task == dagSpec.tasks.at("cat"));
+    }
+
     // Wait until our DAG is complete
     bool complete = true;
     for (auto i = 0; i < 10; ++i) {
       auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID));
-      REQUIRE(response.code() == Pistache::Http::Code::Ok);
+      REQUIRE(response.code == HTTPCode::Ok);
 
       rj::Document doc;
-      daggy::checkRJParse(doc.Parse(response.body().c_str()));
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
 
       REQUIRE(doc.IsObject());
       REQUIRE(doc.HasMember("taskStates"));
@@ -173,6 +270,113 @@ TEST_CASE("rest_endpoint", "[server_basic]")
       fs::remove(pth);
     }
   }
+}
+
+TEST_CASE("Server cancels and resumes execution", "[server_resume]")
+{
+  std::stringstream ss;
+  daggy::executors::task::ForkingTaskExecutor executor(10);
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+  Pistache::Address listenSpec("localhost", Pistache::Port(0));
+
+  const size_t nDAGRunners = 10, nWebThreads = 10;
+
+  daggy::Server server(listenSpec, logger, executor, nDAGRunners);
+  server.init(nWebThreads);
+  server.start();
+
+  const std::string host = "localhost:";
+  const std::string baseURL = host + std::to_string(server.getPort());
+
+  SECTION("Cancel / Resume DAGRun")
+  {
+    std::string dagRunJSON = R"({
+      "tag": "unit_server",
+      "tasks": {
+        "touch_A": { "job": { "command": [ "/usr/bin/touch", "resume_touch_a" ]}, "children": ["touch_C"] },
+        "sleep_B": { "job": { "command": [ "/usr/bin/sleep", "3" ]}, "children": ["touch_C"] },
+        "touch_C": { "job": { "command": [ "/usr/bin/touch", "resume_touch_c" ]} }
+      }
+    })";
+
+    auto dagSpec = daggy::dagFromJSON(dagRunJSON);
+
+    // Submit, and get the runID
+    daggy::DAGRunID runID;
+    {
+      auto response = REQUEST(baseURL + "/v1/dagrun/", dagRunJSON, "POST");
+      REQUIRE(response.code == HTTPCode::Ok);
+
+      rj::Document doc;
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
+      REQUIRE(doc.IsObject());
+      REQUIRE(doc.HasMember("runID"));
+
+      runID = doc["runID"].GetUint64();
+    }
+
+    std::this_thread::sleep_for(1s);
+
+    // Stop the current run
+    {
+      auto response = REQUEST(
+          baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/KILLED", "",
+          "PATCH");
+      REQUIRE(response.code == HTTPCode::Ok);
+      REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::KILLED);
+    }
+
+    // Verify that the run still exists
+    {
+      auto dagRun = logger.getDAGRun(runID);
+      REQUIRE(dagRun.taskRunStates.at("touch_A_0") ==
+              +daggy::RunState::COMPLETED);
+      REQUIRE(fs::exists("resume_touch_a"));
+
+      REQUIRE(dagRun.taskRunStates.at("sleep_B_0") ==
+              +daggy::RunState::ERRORED);
+      REQUIRE(dagRun.taskRunStates.at("touch_C_0") == +daggy::RunState::QUEUED);
+    }
+
+    // Set the errored task state
+    {
+      auto url = baseURL + "/v1/dagrun/" + std::to_string(runID) +
+                 "/task/sleep_B_0/state/QUEUED";
+      auto response = REQUEST(url, "", "PATCH");
+      REQUIRE(response.code == HTTPCode::Ok);
+      REQUIRE(logger.getTaskState(runID, "sleep_B_0") ==
+              +daggy::RunState::QUEUED);
+    }
+
+    // Resume
+    {
+      struct stat s;
+
+      lstat("resume_touch_a", &s);
+      auto preMTime = s.st_mtim.tv_sec;
+
+      auto response = REQUEST(
+          baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/QUEUED", "",
+          "PATCH");
+
+      // Wait for run to complete
+      std::this_thread::sleep_for(5s);
+      REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::COMPLETED);
+
+      REQUIRE(fs::exists("resume_touch_c"));
+      REQUIRE(fs::exists("resume_touch_a"));
+
+      for (const auto &[taskName, task] : dagSpec.tasks) {
+        REQUIRE(logger.getTaskState(runID, taskName + "_0") ==
+                +daggy::RunState::COMPLETED);
+      }
+
+      // Ensure "touch_A" wasn't run again
+      lstat("resume_touch_a", &s);
+      auto postMTime = s.st_mtim.tv_sec;
+      REQUIRE(preMTime == postMTime);
+    }
+  }
 
   server.shutdown();
 }
diff --git a/tests/unit_utilities.cpp b/tests/unit_utilities.cpp
index 1c6a6f0..ddb6ae5 100644
--- a/tests/unit_utilities.cpp
+++ b/tests/unit_utilities.cpp
@@ -8,11 +8,6 @@
 #include "daggy/Serialization.hpp"
 #include "daggy/Utilities.hpp"
 
-#include "daggy/executors/task/ForkingTaskExecutor.hpp"
-#include "daggy/executors/task/NoopTaskExecutor.hpp"
-#include "daggy/loggers/dag_run/OStreamLogger.hpp"
-
-namespace fs = std::filesystem;
 
 TEST_CASE("string_utilities", "[utilities_string]")
 {
@@ -59,234 +54,3 @@ TEST_CASE("string_expansion", "[utilities_parameter_expansion]")
     REQUIRE(result.size() == 4);
   }
 }
-
-TEST_CASE("dag_runner_order", "[dagrun_order]")
-{
-  daggy::executors::task::NoopTaskExecutor ex;
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  daggy::TimePoint globalStartTime = daggy::Clock::now();
-
-  std::string testParams{
-      R"({"DATE": ["2021-05-06", "2021-05-07", "2021-05-08", "2021-05-09" ]})"};
-  auto params = daggy::configFromJSON(testParams);
-
-  std::string taskJSON = R"({
-    "A": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "B","D" ]},
-    "B": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "C","D","E" ]},
-    "C": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "D"]},
-    "D": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "E"]},
-    "E": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}}
-  })";
-
-  auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex, params);
-
-  REQUIRE(tasks.size() == 20);
-
-  auto dag = daggy::buildDAGFromTasks(tasks);
-  auto runID = logger.startDAGRun("test_run", tasks);
-  auto endDAG = daggy::runDAG(runID, ex, logger, dag);
-
-  REQUIRE(endDAG.allVisited());
-
-  // Ensure the run order
-  auto rec = logger.getDAGRun(runID);
-
-  daggy::TimePoint globalStopTime = daggy::Clock::now();
-  std::array<daggy::TimePoint, 5> minTimes;
-  minTimes.fill(globalStartTime);
-  std::array<daggy::TimePoint, 5> maxTimes;
-  maxTimes.fill(globalStopTime);
-
-  for (const auto &[k, v] : rec.taskAttempts) {
-    size_t idx = k[0] - 65;
-    auto &startTime = minTimes[idx];
-    auto &stopTime = maxTimes[idx];
-    startTime = std::max(startTime, v.front().startTime);
-    stopTime = std::min(stopTime, v.back().stopTime);
-  }
-
-  for (size_t i = 0; i < 5; ++i) {
-    for (size_t j = i + 1; j < 4; ++j) {
-      REQUIRE(maxTimes[i] < minTimes[j]);
-    }
-  }
-}
-
-TEST_CASE("dag_runner", "[utilities_dag_runner]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Simple execution")
-  {
-    std::string prefix = (fs::current_path() / "asdlk").string();
-    std::unordered_map<std::string, std::string> files{
-        {"A", prefix + "_A"}, {"B", prefix + "_B"}, {"C", prefix + "_C"}};
-    std::string taskJSON =
-        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + files.at("A") +
-        R"("]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
-        files.at("B") +
-        R"("]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
-        files.at("C") + R"("]}}})";
-    auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
-    auto dag = daggy::buildDAGFromTasks(tasks);
-    auto runID = logger.startDAGRun("test_run", tasks);
-    auto endDAG = daggy::runDAG(runID, ex, logger, dag);
-
-    REQUIRE(endDAG.allVisited());
-
-    for (const auto &[_, file] : files) {
-      REQUIRE(fs::exists(file));
-      fs::remove(file);
-    }
-
-    // Get the DAG Run Attempts
-    auto record = logger.getDAGRun(runID);
-    for (const auto &[_, attempts] : record.taskAttempts) {
-      REQUIRE(attempts.size() == 1);
-      REQUIRE(attempts.front().rc == 0);
-    }
-  }
-}
-
-TEST_CASE("runDAG_recovery", "[runDAG]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Recovery from Error")
-  {
-    auto cleanup = []() {
-      // Cleanup
-      std::vector<std::string> paths{"rec_error_A", "noexist"};
-      for (const auto &pth : paths) {
-        if (fs::exists(pth))
-          fs::remove_all(pth);
-      }
-    };
-
-    cleanup();
-
-    std::string goodPrefix = "rec_error_";
-    std::string badPrefix = "noexist/rec_error_";
-    std::string taskJSON =
-        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + goodPrefix +
-        R"(A"]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
-        badPrefix +
-        R"(B"]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
-        badPrefix + R"(C"]}}})";
-    auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
-    auto dag = daggy::buildDAGFromTasks(tasks);
-
-    auto runID = logger.startDAGRun("test_run", tasks);
-
-    auto tryDAG = daggy::runDAG(runID, ex, logger, dag);
-
-    REQUIRE(!tryDAG.allVisited());
-
-    // Create the missing dir, then continue to run the DAG
-    fs::create_directory("noexist");
-    tryDAG.resetRunning();
-    auto endDAG = daggy::runDAG(runID, ex, logger, tryDAG);
-
-    REQUIRE(endDAG.allVisited());
-
-    // Get the DAG Run Attempts
-    auto record = logger.getDAGRun(runID);
-    REQUIRE(record.taskAttempts["A_0"].size() == 1);  // A ran fine
-    REQUIRE(record.taskAttempts["B_0"].size() == 2);  // B errored and had to be retried
-    REQUIRE(record.taskAttempts["C_0"].size() == 1);  // C wasn't run because B errored
-
-    cleanup();
-  }
-}
-
-TEST_CASE("runDAG_generator", "[runDAG_generator]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Generator tasks")
-  {
-    std::string testParams{R"({"DATE": ["2021-05-06", "2021-05-07" ]})"};
-    auto params = daggy::configFromJSON(testParams);
-
-    std::string generatorOutput =
-        R"({"B": {"job": {"command": ["/usr/bin/echo", "-e", "{{DATE}}"]}, "children": ["C"]}})";
-    fs::path ofn = fs::current_path() / "generator_test_output.json";
-    std::ofstream ofh{ofn};
-    ofh << generatorOutput << std::endl;
-    ofh.close();
-
-    daggy::TimePoint globalStartTime = daggy::Clock::now();
-    std::stringstream jsonTasks;
-    jsonTasks
-        << R"({ "A": { "job": {"command": [ "/usr/bin/cat", )"
-        << std::quoted(ofn.string())
-        << R"(]}, "children": ["C"], "isGenerator": true},)"
-        << R"("C": { "job": {"command": [ "/usr/bin/echo", "hello!"]} } })";
-
-    auto baseTasks = daggy::tasksFromJSON(jsonTasks.str());
-    REQUIRE(baseTasks.size() == 2);
-    REQUIRE(baseTasks["A"].children == std::unordered_set<std::string>{"C"});
-    auto tasks = daggy::expandTaskSet(baseTasks, ex, params);
-    REQUIRE(tasks.size() == 2);
-    REQUIRE(tasks["A_0"].children == std::unordered_set<std::string>{"C"});
-    auto dag = daggy::buildDAGFromTasks(tasks);
-    REQUIRE(dag.size() == 2);
-
-    auto runID = logger.startDAGRun("generator_run", tasks);
-    auto finalDAG = daggy::runDAG(runID, ex, logger, dag, params);
-
-    REQUIRE(finalDAG.allVisited());
-    REQUIRE(finalDAG.size() == 4);
-
-    // Check the logger
-    auto record = logger.getDAGRun(runID);
-
-    REQUIRE(record.tasks.size() == 4);
-    REQUIRE(record.taskRunStates.size() == 4);
-    for (const auto &[taskName, attempts] : record.taskAttempts) {
-      REQUIRE(attempts.size() == 1);
-      REQUIRE(attempts.back().rc == 0);
-    }
-
-    // Ensure that children were updated properly
-    REQUIRE(record.tasks["A_0"].children ==
-            std::unordered_set<std::string>{"B_0", "B_1", "C"});
-    REQUIRE(record.tasks["B_0"].children ==
-            std::unordered_set<std::string>{"C"});
-    REQUIRE(record.tasks["B_1"].children ==
-            std::unordered_set<std::string>{"C"});
-    REQUIRE(record.tasks["C_0"].children.empty());
-
-    // Ensure they were run in the right order
-    // All A's get run before B's, which run before C's
-    daggy::TimePoint globalStopTime = daggy::Clock::now();
-    std::array<daggy::TimePoint, 3> minTimes;
-    minTimes.fill(globalStartTime);
-    std::array<daggy::TimePoint, 3> maxTimes;
-    maxTimes.fill(globalStopTime);
-
-    for (const auto &[k, v] : record.taskAttempts) {
-      size_t idx = k[0] - 65;
-      auto &startTime = minTimes[idx];
-      auto &stopTime = maxTimes[idx];
-      startTime = std::max(startTime, v.front().startTime);
-      stopTime = std::min(stopTime, v.back().stopTime);
-    }
-
-    for (size_t i = 0; i < 3; ++i) {
-      for (size_t j = i + 1; j < 2; ++j) {
-        REQUIRE(maxTimes[i] < minTimes[j]);
-      }
-    }
-  }
-}
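
For quick reference, the DAGRunner flow exercised by these tests can be driven end to end roughly as follows. This is a minimal sketch using only calls that appear in this patch; error handling is omitted, and since the diff does not show whether dagFromJSON expands parameterized tasks itself, expansion is invoked explicitly:

    #include <sstream>

    #include "daggy/DAGRunner.hpp"
    #include "daggy/executors/task/ForkingTaskExecutor.hpp"
    #include "daggy/loggers/dag_run/OStreamLogger.hpp"

    int main()
    {
      // Local fork-based executor with 10 worker slots, as in the tests.
      daggy::executors::task::ForkingTaskExecutor ex(10);
      std::stringstream ss;
      daggy::loggers::dag_run::OStreamLogger logger(ss);

      // Parse a full spec (tag + tasks), expand parameterized tasks, build the DAG.
      auto dagSpec = daggy::dagFromJSON(
          R"({"tag": "example",
              "tasks": {"A": {"job": {"command": ["/usr/bin/echo", "hi"]}}}})");
      dagSpec.tasks =
          daggy::expandTaskSet(dagSpec.tasks, ex, dagSpec.taskConfig.variables);
      auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);

      // Register the run with a state logger, then drive it to completion.
      auto runID = logger.startDAGRun(dagSpec);
      daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
      auto endDAG = runner.run();
      return endDAG.allVisited() ? 0 : 1;
    }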