From 65ab439848a7d467403831f08e5d5193aa5ae245 Mon Sep 17 00:00:00 2001 From: Ian Roddis Date: Tue, 5 Oct 2021 11:57:55 -0300 Subject: [PATCH] Squashed commit of the following: commit b06b11cbb5d09c6d091551e39767cd3316f88376 Author: Ian Roddis Date: Tue Oct 5 11:57:37 2021 -0300 Fixing failing unit test commit fe2a43a19b2a16a9aedd9e95e71e672935ecaeb1 Author: Ian Roddis Date: Tue Oct 5 11:54:01 2021 -0300 Adding endpoints and updating documentation commit 46e0deeefb8b06291ae5e2d6b8ec4749c5b0ea6f Author: Ian Roddis Date: Tue Oct 5 11:49:43 2021 -0300 Completing unit tests and relevant fixes commit e0569f370624844feee6aae4708bfe683f4156cf Author: Ian Roddis Date: Mon Oct 4 17:30:59 2021 -0300 Adding in gcc tsan for debug builds to help with race conditions, fixing many of those, and fixing really crummy assumption about how futures worked that will speed up task execution by a ton. commit c748a4f592e1ada5546908be5281d04f4749539d Author: Ian Roddis Date: Mon Oct 4 10:14:43 2021 -0300 Checkpointing work that seems to have resolved the race condition commit 7a79f2943e0d50545d976a28b4b379340a90dded Author: Ian Roddis Date: Wed Sep 29 09:27:07 2021 -0300 Completing the rough-in for DAG killing / pausing / resuming commit 4cf8d81d5f6fcf4a7dd83d8fca3e23f153aa8acb Author: Ian Roddis Date: Tue Sep 28 14:53:50 2021 -0300 Adding dagrunner unit tests, adding a resetRunning method to resume commit 54e2c1f9f5e7d5b339d71be024e0e390c4d2bf61 Author: Ian Roddis Date: Tue Sep 28 14:45:57 2021 -0300 Refactoring runDAG into DAGRunner commit 682be7a11e2fae850e1bc3e207628d2335768c2b Author: Ian Roddis Date: Tue Sep 28 14:34:43 2021 -0300 Adding DAGRunner class to replace Utilities::runDAG, making Slurm cancellation rc agree with SIGKILL commit 4171b3a6998791abfc71b04f8de1ae93c4f90a78 Author: Ian Roddis Date: Tue Sep 28 14:14:17 2021 -0300 Adding unit tests for stopping jobs to slurm commit dc0b1ff26a5d98471164132d35bb8a552cc75ff8 Author: Ian Roddis Date: Tue Sep 28 14:04:15 2021 -0300 Adding in stop method for task executors commit e752b44f55113be54392bcbb5c3d6f251d673cfa Author: Ian Roddis Date: Tue Sep 28 12:32:06 2021 -0300 Adding additional tests for loggers commit f0773d5a84a422738fc17c9277a2b735a21a3d04 Author: Ian Roddis Date: Tue Sep 28 12:29:21 2021 -0300 Unit tests pass commit 993ff2810de2d53dc6a59ab53d620fecf152d4a0 Author: Ian Roddis Date: Tue Sep 28 12:24:34 2021 -0300 Adding handling for new routes, still need to add tests for new routes commit 676623b14e45759872a2dbcbc98f6a744e022a71 Author: Ian Roddis Date: Tue Sep 28 12:12:43 2021 -0300 Adding handling for new routes, still need to add tests for new routes commit b9edb6ba291eb064f4c459a308ea6912fba9fa02 Author: Ian Roddis Date: Mon Sep 27 11:59:14 2021 -0300 Defining new endpoints, fixing dag resumption code, adding PAUSED state, refactoring DAGSpec and adding deserializer --- CMakeLists.txt | 13 +- README.md | 13 +- cmake/daggy_features.cmake | 2 +- daggy/include/daggy/DAGRunner.hpp | 55 +++ daggy/include/daggy/Defines.hpp | 19 +- daggy/include/daggy/Serialization.hpp | 5 + daggy/include/daggy/Server.hpp | 44 +- daggy/include/daggy/Utilities.hpp | 4 - .../executors/task/ForkingTaskExecutor.hpp | 15 +- .../daggy/executors/task/NoopTaskExecutor.hpp | 5 +- .../executors/task/SlurmTaskExecutor.hpp | 7 +- .../daggy/executors/task/TaskExecutor.hpp | 6 +- .../daggy/loggers/dag_run/DAGRunLogger.hpp | 16 +- .../include/daggy/loggers/dag_run/Defines.hpp | 5 +- .../daggy/loggers/dag_run/OStreamLogger.hpp | 13 +- daggy/src/CMakeLists.txt | 1 + 
daggy/src/DAGRunner.cpp | 213 +++++++++ daggy/src/Serialization.cpp | 53 ++- daggy/src/Server.cpp | 407 +++++++++++++----- daggy/src/Utilities.cpp | 120 +----- .../executors/task/ForkingTaskExecutor.cpp | 74 +++- daggy/src/executors/task/NoopTaskExecutor.cpp | 8 +- .../src/executors/task/SlurmTaskExecutor.cpp | 48 ++- daggy/src/loggers/dag_run/OStreamLogger.cpp | 68 ++- endpoints.otl | 30 ++ tests/CMakeLists.txt | 3 +- tests/unit_dagrun_loggers.cpp | 69 ++- tests/unit_dagrunner.cpp | 256 +++++++++++ tests/unit_executor_forkingexecutor.cpp | 29 +- tests/unit_executor_slurmexecutor.cpp | 23 +- tests/unit_server.cpp | 296 +++++++++++-- tests/unit_utilities.cpp | 236 ---------- 32 files changed, 1538 insertions(+), 618 deletions(-) create mode 100644 daggy/include/daggy/DAGRunner.hpp create mode 100644 daggy/src/DAGRunner.cpp create mode 100644 endpoints.otl create mode 100644 tests/unit_dagrunner.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 3418100..f2107c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,9 +1,18 @@ cmake_minimum_required(VERSION 3.14) project(overall) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Debug") +endif() + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED True) set(CMAKE_EXPORT_COMPILE_COMMANDS True) -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror") + +if(CMAKE_BUILD_TYPE MATCHES "Debug") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread -fno-omit-frame-pointer") +endif() set(THIRD_PARTY_DIR ${CMAKE_BINARY_DIR}/third_party) @@ -19,6 +28,8 @@ include(cmake/argparse.cmake) include(cmake/Catch2.cmake) include(cmake/daggy_features.cmake) +message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}") + # use, i.e. don't skip the full RPATH for the build tree set(CMAKE_SKIP_BUILD_RPATH FALSE) diff --git a/README.md b/README.md index 4601c13..fa0377b 100644 --- a/README.md +++ b/README.md @@ -28,11 +28,10 @@ graph LR Individual tasks (vertices) are run via a task executor. Daggy supports multiple executors, from local executor (via fork), to distributed work managers like [slurm](https://slurm.schedmd.com/overview.html) -or [kubernetes](https://kubernetes.io/) (both planned). +or [kubernetes](https://kubernetes.io/) (planned). -State is maintained via state loggers. Currently daggy supports an in-memory state manager (OStreamLogger), and a -filesystem logger (FileSystemLogger). Future plans include supporting [redis](https://redis.io) -and [postgres](https://postgresql.org). +State is maintained via state loggers. Currently daggy supports an in-memory state manager (OStreamLogger). +Future plans include supporting [redis](https://redis.io) and [postgres](https://postgresql.org). Building == @@ -43,13 +42,17 @@ Building - cmake >= 3.14 - gcc >= 8 +- libslurm (if needed) + ``` git clone https://gitlab.com/iroddis/daggy cd daggy mkdir build cd build -cmake .. +cmake [-DDAGGY_ENABLE_SLURM=ON] .. 
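+# -DDAGGY_ENABLE_SLURM=ON is optional; it enables the Slurm task executor and requires libslurm to be installed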
make + +tests/tests # for unit tests ``` DAG Run Definition diff --git a/cmake/daggy_features.cmake b/cmake/daggy_features.cmake index 67ca04c..d929413 100644 --- a/cmake/daggy_features.cmake +++ b/cmake/daggy_features.cmake @@ -1,5 +1,5 @@ # SLURM -message("DAGGY_ENABLED_SLURM is set to ${DAGGY_ENABLE_SLURM}") +message("-- DAGGY_ENABLED_SLURM is set to ${DAGGY_ENABLE_SLURM}") if (DAGGY_ENABLE_SLURM) find_library(SLURM_LIB libslurm.so libslurm.a slurm REQUIRED) find_path(SLURM_INCLUDE_DIR "slurm/slurm.h" REQUIRED) diff --git a/daggy/include/daggy/DAGRunner.hpp b/daggy/include/daggy/DAGRunner.hpp new file mode 100644 index 0000000..f821de2 --- /dev/null +++ b/daggy/include/daggy/DAGRunner.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include <sys/types.h> + +#include <atomic> +#include <chrono> +#include <future> +#include <mutex> +#include <string> +#include <unordered_map> + +#include "DAG.hpp" +#include "Defines.hpp" +#include "Serialization.hpp" +#include "Utilities.hpp" +#include "daggy/executors/task/TaskExecutor.hpp" +#include "daggy/loggers/dag_run/DAGRunLogger.hpp" + +using namespace std::chrono_literals; + +namespace daggy { + class DAGRunner + { + public: + DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor, + loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, + const TaskParameters &taskParams); + + ~DAGRunner(); + + TaskDAG run(); + void resetRunning(); + void stop(bool kill = false, bool blocking = false); + + private: + void collectFinished(); + void queuePending(); + void killRunning(); + + DAGRunID runID_; + executors::task::TaskExecutor &executor_; + loggers::dag_run::DAGRunLogger &logger_; + TaskDAG dag_; + const TaskParameters &taskParams_; + std::atomic<bool> running_; + std::atomic<bool> kill_; + + ssize_t nRunningTasks_; + ssize_t nErroredTasks_; + std::unordered_map<std::string, std::future<AttemptRecord>> runningTasks_; + std::unordered_map<std::string, size_t> taskAttemptCounts_; + + std::mutex runGuard_; + }; } // namespace daggy diff --git a/daggy/include/daggy/Defines.hpp b/daggy/include/daggy/Defines.hpp index cbf4422..4e7e1ee 100644 --- a/daggy/include/daggy/Defines.hpp +++ b/daggy/include/daggy/Defines.hpp @@ -22,9 +22,8 @@ namespace daggy { // DAG Runs using DAGRunID = size_t; - BETTER_ENUM(RunState, uint32_t, QUEUED = 1 << 0, RUNNING = 1 << 1, - RETRY = 1 << 2, ERRORED = 1 << 3, KILLED = 1 << 4, - COMPLETED = 1 << 5); + BETTER_ENUM(RunState, uint32_t, QUEUED = 1, RUNNING, RETRY, ERRORED, KILLED, + PAUSED, COMPLETED); struct Task { @@ -50,6 +49,20 @@ namespace daggy { using TaskSet = std::unordered_map<std::string, Task>; + // All the components required to define and run a DAG + struct TaskParameters + { + ConfigValues variables; + ConfigValues jobDefaults; + }; + + struct DAGSpec + { + std::string tag; + TaskSet tasks; + TaskParameters taskConfig; + }; + struct AttemptRecord { TimePoint startTime; diff --git a/daggy/include/daggy/Serialization.hpp b/daggy/include/daggy/Serialization.hpp index 0d4fb36..5259d3a 100644 --- a/daggy/include/daggy/Serialization.hpp +++ b/daggy/include/daggy/Serialization.hpp @@ -8,6 +8,7 @@ #include <rapidjson/document.h> #include "Defines.hpp" +#include "Utilities.hpp" namespace rj = rapidjson; @@ -36,6 +37,10 @@ namespace daggy { std::string tasksToJSON(const TaskSet &tasks); + // Full specs + DAGSpec dagFromJSON(const rj::Value &spec); + DAGSpec dagFromJSON(const std::string &jsonSpec); + // Attempt Records std::string attemptRecordToJSON(const AttemptRecord &attemptRecord); diff --git a/daggy/include/daggy/Server.hpp b/daggy/include/daggy/Server.hpp index f4f4231..2f6dbc8 100644 --- a/daggy/include/daggy/Server.hpp +++ b/daggy/include/daggy/Server.hpp @@ -6,10 +6,15 @@ #include <filesystem> +#include "DAGRunner.hpp" #include "ThreadPool.hpp" #include "executors/task/TaskExecutor.hpp" #include "loggers/dag_run/DAGRunLogger.hpp" +#define DAGGY_REST_HANDLER(func) \ + void func(const Pistache::Rest::Request &request, \ + Pistache::Http::ResponseWriter response); + namespace fs = std::filesystem; namespace daggy { @@ -18,14 +23,8 @@ namespace daggy { public: Server(const Pistache::Address &listenSpec, loggers::dag_run::DAGRunLogger &logger, - executors::task::TaskExecutor &executor, size_t nDAGRunners) - : endpoint_(listenSpec) - , desc_("Daggy API", "0.1") - , logger_(logger) - , executor_(executor) - , runnerPool_(nDAGRunners) - { - } + executors::task::TaskExecutor &executor, size_t nDAGRunners); + ~Server(); Server &setSSLCertificates(const fs::path &cert, const fs::path &key); @@ -39,21 +38,21 @@ namespace daggy { private: void createDescription(); + void queueDAG_(DAGRunID runID, const TaskDAG &dag, + const TaskParameters &taskParameters); - void handleRunDAG(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); + DAGGY_REST_HANDLER(handleReady); // X + DAGGY_REST_HANDLER(handleQueryDAGs); // X + DAGGY_REST_HANDLER(handleRunDAG); // X + DAGGY_REST_HANDLER(handleValidateDAG); // X + DAGGY_REST_HANDLER(handleGetDAGRun); // X + DAGGY_REST_HANDLER(handleGetDAGRunState); // X + DAGGY_REST_HANDLER(handleSetDAGRunState); // X + DAGGY_REST_HANDLER(handleGetTask); // X + DAGGY_REST_HANDLER(handleGetTaskState); // X + DAGGY_REST_HANDLER(handleSetTaskState); // X - void handleGetDAGRuns(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - void handleGetDAGRun(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - void handleReady(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response); - - bool handleAuth(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter &response); + bool handleAuth(const Pistache::Rest::Request &request); Pistache::Http::Endpoint endpoint_; Pistache::Rest::Description desc_; @@ -62,5 +61,8 @@ namespace daggy { loggers::dag_run::DAGRunLogger &logger_; executors::task::TaskExecutor &executor_; ThreadPool runnerPool_; + + std::mutex runnerGuard_; + std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_; }; } // namespace daggy diff --git a/daggy/include/daggy/Utilities.hpp b/daggy/include/daggy/Utilities.hpp index f5e9602..8c2f4dc 100644 --- a/daggy/include/daggy/Utilities.hpp +++ b/daggy/include/daggy/Utilities.hpp @@ -31,9 +31,5 @@ namespace daggy { void updateDAGFromTasks(TaskDAG &dag, const TaskSet &tasks); - TaskDAG runDAG(DAGRunID runID, executors::task::TaskExecutor &executor, - loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, - const ConfigValues job = {}); - std::ostream &operator<<(std::ostream &os, const TimePoint &tp); } // namespace daggy diff --git a/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp b/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp index 860fe04..573c4fe 100644 --- a/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/ForkingTaskExecutor.hpp @@ -10,10 +10,8 @@ namespace daggy::executors::task { public: using Command = std::vector<std::string>; - explicit ForkingTaskExecutor(size_t nThreads) - : tp_(nThreads) - { - } + explicit ForkingTaskExecutor(size_t nThreads); + ~ForkingTaskExecutor() override; // Validates the job to ensure that all required values are set and are of // the right type, @@ -23,11 +21,16 @@ namespace daggy::executors::task { const ConfigValues &job,
const ConfigValues &expansionValues) override; // Runs the task - std::future<AttemptRecord> execute(const std::string &taskName, + std::future<AttemptRecord> execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + bool stop(DAGRunID runID, const std::string &taskName) override; + private: ThreadPool tp_; - AttemptRecord runTask(const Task &task); + std::mutex taskControlsGuard_; + AttemptRecord runTask(const Task &task, std::atomic<bool> &running); + std::unordered_map<std::string, std::atomic<bool>> taskControls_; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp b/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp index 751b255..92730cc 100644 --- a/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/NoopTaskExecutor.hpp @@ -16,7 +16,10 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) override; // Runs the task - std::future<AttemptRecord> execute(const std::string &taskName, + std::future<AttemptRecord> execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + + bool stop(DAGRunID runID, const std::string &taskName) override; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp b/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp index f8f159e..90e3e2e 100644 --- a/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/SlurmTaskExecutor.hpp @@ -19,15 +19,20 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) override; // Runs the task - std::future<AttemptRecord> execute(const std::string &taskName, + std::future<AttemptRecord> execute(DAGRunID runID, + const std::string &taskName, const Task &task) override; + bool stop(DAGRunID runID, const std::string &taskName) override; + private: struct Job { std::promise<AttemptRecord> prom; std::string stdoutFile; std::string stderrFile; + DAGRunID runID; + std::string taskName; }; std::mutex promiseGuard_; diff --git a/daggy/include/daggy/executors/task/TaskExecutor.hpp b/daggy/include/daggy/executors/task/TaskExecutor.hpp index 682bcea..f2c02ec 100644 --- a/daggy/include/daggy/executors/task/TaskExecutor.hpp +++ b/daggy/include/daggy/executors/task/TaskExecutor.hpp @@ -27,7 +27,11 @@ namespace daggy::executors::task { const ConfigValues &job, const ConfigValues &expansionValues) = 0; // Asynchronous execution of a task; the returned future resolves when the // task finishes - virtual std::future<AttemptRecord> execute(const std::string &taskName, + virtual std::future<AttemptRecord> execute(DAGRunID runID, + const std::string &taskName, const Task &task) = 0; + + // Kill a currently executing task. This will resolve the future.
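+ // Implementations should make the pending future resolve promptly (for + // example with a non-zero rc in its AttemptRecord) so that callers such as + // DAGRunner can reap the attempt; the return value reports whether the + // stop request was accepted.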
+ virtual bool stop(DAGRunID runID, const std::string &taskName) = 0; }; } // namespace daggy::executors::task diff --git a/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp b/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp index 0e97036..3232dc5 100644 --- a/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp +++ b/daggy/include/daggy/loggers/dag_run/DAGRunLogger.hpp @@ -17,8 +17,8 @@ namespace daggy::loggers::dag_run { public: virtual ~DAGRunLogger() = default; - // Execution - virtual DAGRunID startDAGRun(std::string name, const TaskSet &tasks) = 0; + // Insertion / Updates + virtual DAGRunID startDAGRun(const DAGSpec &dagSpec) = 0; virtual void addTask(DAGRunID dagRunID, const std::string &taskName, const Task &task) = 0; @@ -35,8 +35,16 @@ namespace daggy::loggers::dag_run { RunState state) = 0; // Querying - virtual std::vector<DAGRunSummary> getDAGs(uint32_t stateMask) = 0; + virtual DAGSpec getDAGSpec(DAGRunID dagRunID) = 0; - virtual DAGRunRecord getDAGRun(DAGRunID dagRunID) = 0; + virtual std::vector<DAGRunSummary> queryDAGRuns(const std::string &tag = "", + bool all = false) = 0; + + virtual RunState getDAGRunState(DAGRunID dagRunID) = 0; + virtual DAGRunRecord getDAGRun(DAGRunID dagRunID) = 0; + + virtual Task &getTask(DAGRunID dagRunID, const std::string &taskName) = 0; + virtual RunState &getTaskState(DAGRunID dagRunID, + const std::string &taskName) = 0; }; } // namespace daggy::loggers::dag_run diff --git a/daggy/include/daggy/loggers/dag_run/Defines.hpp b/daggy/include/daggy/loggers/dag_run/Defines.hpp index eff8010..db2d810 100644 --- a/daggy/include/daggy/loggers/dag_run/Defines.hpp +++ b/daggy/include/daggy/loggers/dag_run/Defines.hpp @@ -25,8 +25,7 @@ namespace daggy::loggers::dag_run { // Pretty heavy weight, but struct DAGRunRecord { - std::string name; - TaskSet tasks; + DAGSpec dagSpec; std::unordered_map<std::string, RunState> taskRunStates; std::unordered_map<std::string, std::vector<AttemptRecord>> taskAttempts; std::vector taskStateChanges; @@ -36,7 +35,7 @@ namespace daggy::loggers::dag_run { struct DAGRunSummary { DAGRunID runID; - std::string name; + std::string tag; RunState runState; TimePoint startTime; TimePoint lastUpdate; diff --git a/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp b/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp index ad0979c..53c7bb8 100644 --- a/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp +++ b/daggy/include/daggy/loggers/dag_run/OStreamLogger.hpp @@ -15,9 +15,10 @@ namespace daggy::loggers::dag_run { { public: explicit OStreamLogger(std::ostream &os); + ~OStreamLogger() override; // Execution - DAGRunID startDAGRun(std::string name, const TaskSet &tasks) override; + DAGRunID startDAGRun(const DAGSpec &dagSpec) override; void addTask(DAGRunID dagRunID, const std::string &taskName, const Task &task) override; @@ -34,10 +35,18 @@ namespace daggy::loggers::dag_run { RunState state) override; // Querying - std::vector<DAGRunSummary> getDAGs(uint32_t stateMask) override; + DAGSpec getDAGSpec(DAGRunID dagRunID) override; + std::vector<DAGRunSummary> queryDAGRuns(const std::string &tag = "", + bool all = false) override; + + RunState getDAGRunState(DAGRunID dagRunID) override; DAGRunRecord getDAGRun(DAGRunID dagRunID) override; + Task &getTask(DAGRunID dagRunID, const std::string &taskName) override; + RunState &getTaskState(DAGRunID dagRunID, + const std::string &taskName) override; + private: std::mutex guard_; std::ostream &os_; diff --git a/daggy/src/CMakeLists.txt b/daggy/src/CMakeLists.txt index 1af1510..7dbe518 100644 --- a/daggy/src/CMakeLists.txt +++ b/daggy/src/CMakeLists.txt @@ -2,6 +2,7 @@
target_sources(${PROJECT_NAME} PRIVATE Serialization.cpp Server.cpp Utilities.cpp + DAGRunner.cpp ) add_subdirectory(executors) diff --git a/daggy/src/DAGRunner.cpp b/daggy/src/DAGRunner.cpp new file mode 100644 index 0000000..164fbc8 --- /dev/null +++ b/daggy/src/DAGRunner.cpp @@ -0,0 +1,213 @@ +#include +#include +#include +#include + +namespace daggy { + DAGRunner::DAGRunner(DAGRunID runID, executors::task::TaskExecutor &executor, + loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, + const TaskParameters &taskParams) + : runID_(runID) + , executor_(executor) + , logger_(logger) + , dag_(dag) + , taskParams_(taskParams) + , running_(true) + , kill_(true) + , nRunningTasks_(0) + , nErroredTasks_(0) + { + } + + DAGRunner::~DAGRunner() + { + std::lock_guard lock(runGuard_); + } + + TaskDAG DAGRunner::run() + { + kill_ = false; + running_ = true; + logger_.updateDAGRunState(runID_, RunState::RUNNING); + + bool allVisited; + { + std::lock_guard lock(runGuard_); + allVisited = dag_.allVisited(); + } + while (!allVisited) { + { + std::lock_guard runLock(runGuard_); + if (!running_ and kill_) { + killRunning(); + } + collectFinished(); + queuePending(); + + if (!running_ and (nRunningTasks_ - nErroredTasks_ <= 0)) { + logger_.updateDAGRunState(runID_, RunState::KILLED); + break; + } + + if (nRunningTasks_ > 0 and nErroredTasks_ == nRunningTasks_) { + logger_.updateDAGRunState(runID_, RunState::ERRORED); + break; + } + } + + std::this_thread::sleep_for(250ms); + { + std::lock_guard lock(runGuard_); + allVisited = dag_.allVisited(); + } + } + + if (dag_.allVisited()) { + logger_.updateDAGRunState(runID_, RunState::COMPLETED); + } + + running_ = false; + return dag_; + } + + void DAGRunner::resetRunning() + { + if (running_) + throw std::runtime_error("Unable to reset while DAG is running."); + + std::lock_guard lock(runGuard_); + nRunningTasks_ = 0; + nErroredTasks_ = 0; + runningTasks_.clear(); + taskAttemptCounts_.clear(); + dag_.resetRunning(); + } + + void DAGRunner::killRunning() + { + for (const auto &[taskName, _] : runningTasks_) { + executor_.stop(runID_, taskName); + } + } + + void DAGRunner::queuePending() + { + if (!running_) + return; + + // Check for any completed tasks + // Add all remaining tasks in a task queue to avoid dominating the thread + // pool + auto t = dag_.visitNext(); + while (t.has_value()) { + // Schedule the task to run + auto &taskName = t.value().first; + auto &task = t.value().second; + taskAttemptCounts_[taskName] = 1; + + logger_.updateTaskState(runID_, taskName, RunState::RUNNING); + runningTasks_.emplace(taskName, + executor_.execute(runID_, taskName, task)); + ++nRunningTasks_; + + auto nextTask = dag_.visitNext(); + if (not nextTask.has_value()) + break; + t.emplace(nextTask.value()); + } + } + + void DAGRunner::collectFinished() + { + for (auto &[taskName, fut] : runningTasks_) { + if (fut.valid() and fut.wait_for(1ms) == std::future_status::ready) { + auto attempt = fut.get(); + logger_.logTaskAttempt(runID_, taskName, attempt); + + // Not a reference, since adding tasks will invalidate references + auto vert = dag_.getVertex(taskName); + auto &task = vert.data; + if (attempt.rc == 0) { + logger_.updateTaskState(runID_, taskName, RunState::COMPLETED); + if (task.isGenerator) { + // Parse the output and update the DAGs + try { + auto parsedTasks = + tasksFromJSON(attempt.outputLog, taskParams_.jobDefaults); + auto newTasks = + expandTaskSet(parsedTasks, executor_, taskParams_.variables); + updateDAGFromTasks(dag_, newTasks); + + // Add in dependencies 
from current task to new tasks + for (const auto &[ntName, ntTask] : newTasks) { + logger_.addTask(runID_, ntName, ntTask); + task.children.insert(ntName); + } + + // Efficiently add new edges from generator task + // to children + std::unordered_set baseNames; + for (const auto &[k, v] : parsedTasks) { + baseNames.insert(v.definedName); + } + dag_.addEdgeIf(taskName, [&](const auto &v) { + return baseNames.count(v.data.definedName) > 0; + }); + + logger_.updateTask(runID_, taskName, task); + } + catch (std::exception &e) { + logger_.logTaskAttempt( + runID_, taskName, + AttemptRecord{ + .executorLog = + std::string{"Failed to parse JSON output: "} + + e.what()}); + logger_.updateTaskState(runID_, taskName, RunState::ERRORED); + ++nErroredTasks_; + } + } + dag_.completeVisit(taskName); + --nRunningTasks_; + } + else { + // RC isn't 0 + if (taskAttemptCounts_[taskName] <= task.maxRetries) { + logger_.updateTaskState(runID_, taskName, RunState::RETRY); + runningTasks_[taskName] = executor_.execute(runID_, taskName, task); + ++taskAttemptCounts_[taskName]; + } + else { + if (logger_.getTaskState(runID_, taskName) == +RunState::RUNNING or + logger_.getTaskState(runID_, taskName) == +RunState::RETRY) { + logger_.updateTaskState(runID_, taskName, RunState::ERRORED); + ++nErroredTasks_; + } + else { + // Task was killed + --nRunningTasks_; + } + } + } + } + } + } + + void DAGRunner::stop(bool kill, bool blocking) + { + kill_ = kill; + running_ = false; + + if (blocking) { + while (true) { + { + std::lock_guard lock(runGuard_); + if (nRunningTasks_ - nErroredTasks_ == 0) + break; + } + std::this_thread::sleep_for(250ms); + } + } + } + +} // namespace daggy diff --git a/daggy/src/Serialization.cpp b/daggy/src/Serialization.cpp index 7540eaa..422670b 100644 --- a/daggy/src/Serialization.cpp +++ b/daggy/src/Serialization.cpp @@ -276,17 +276,56 @@ namespace daggy { std::string timePointToString(const TimePoint &tp) { - std::stringstream ss; - ss << tp; - return ss.str(); + return std::to_string(tp.time_since_epoch().count()); } TimePoint stringToTimePoint(const std::string &timeString) { - std::tm dt{}; - std::stringstream ss{timeString}; - ss >> std::get_time(&dt, "%Y-%m-%d %H:%M:%S %Z"); - return Clock::from_time_t(mktime(&dt)); + using namespace std::chrono; + + size_t nanos = std::stoull(timeString); + nanoseconds dur(nanos); + + return TimePoint(dur); + } + + DAGSpec dagFromJSON(const rj::Value &spec) + { + DAGSpec info; + + if (!spec.IsObject()) { + throw std::runtime_error("Payload is not a dictionary."); + } + if (!spec.HasMember("tag")) { + throw std::runtime_error("DAG Run is missing a name."); + } + if (!spec.HasMember("tasks")) { + throw std::runtime_error("DAG Run has no tasks."); + } + + info.tag = spec["tag"].GetString(); + + // Get parameters if there are any + if (spec.HasMember("parameters")) { + info.taskConfig.variables = configFromJSON(spec["parameters"]); + } + + // Job Defaults + if (spec.HasMember("jobDefaults")) { + info.taskConfig.jobDefaults = configFromJSON(spec["jobDefaults"]); + } + + // Get the tasks + info.tasks = tasksFromJSON(spec["tasks"], info.taskConfig.jobDefaults); + + return info; + } + + DAGSpec dagFromJSON(const std::string &jsonSpec) + { + rj::Document doc; + checkRJParse(doc.Parse(jsonSpec.c_str()), "Parsing config"); + return dagFromJSON(doc); } } // namespace daggy diff --git a/daggy/src/Server.cpp b/daggy/src/Server.cpp index 81cacf8..9062706 100644 --- a/daggy/src/Server.cpp +++ b/daggy/src/Server.cpp @@ -4,12 +4,18 @@ #include #include #include 
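+// for std::quoted, std::stringstream and std::shared_ptr used by the handlers below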
+#include <iomanip> +#include <memory> +#include <mutex> +#include <sstream> +#include <unordered_map> -#define REQ_ERROR(code, msg) \ - response.send(Pistache::Http::Code::code, msg); \ +#define REQ_RESPONSE(code, msg) \ + std::stringstream ss; \ + ss << R"({"message": )" << std::quoted(msg) << "}"; \ + response.send(Pistache::Http::Code::code, ss.str()); \ return; -namespace rj = rapidjson; using namespace Pistache; namespace daggy { @@ -25,6 +31,22 @@ namespace daggy { createDescription(); } + Server::Server(const Pistache::Address &listenSpec, + loggers::dag_run::DAGRunLogger &logger, + executors::task::TaskExecutor &executor, size_t nDAGRunners) + : endpoint_(listenSpec) + , desc_("Daggy API", "0.1") + , logger_(logger) + , executor_(executor) + , runnerPool_(nDAGRunners) + { + } + + Server::~Server() + { + shutdown(); + } + void Server::start() { router_.initFromDescription(desc_); @@ -42,6 +64,7 @@ namespace daggy { void Server::shutdown() { endpoint_.shutdown(); + runnerPool_.shutdown(); } uint16_t Server::getPort() const @@ -55,7 +78,7 @@ namespace daggy { auto backendErrorResponse = desc_.response(Http::Code::Internal_Server_Error, - "An error occurred with the backend"); + R"({"error": "An error occurred with the backend"})"); desc_.schemes(Rest::Scheme::Http) .basePath("/v1") @@ -69,111 +92,131 @@ namespace daggy { auto versionPath = desc_.path("/v1"); - auto dagPath = versionPath.path("/dagrun"); + /* + DAG Run Summaries + */ + auto dagRunsPath = versionPath.path("/dagruns"); - // Run a DAG - dagPath.route(desc_.post("/")) + dagRunsPath.route(desc_.get("/")) + .bind(&Server::handleQueryDAGs, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "List summaries of DAG runs"); + + /* + Individual DAG Run routes + */ + auto dagRunPath = versionPath.path("/dagrun"); + + dagRunPath.route(desc_.post("/")) .bind(&Server::handleRunDAG, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) + .produces(MIME(Application, Json)) .response(Http::Code::Ok, "Run a DAG"); - // List detailed DAG run - dagPath.route(desc_.get("/:runID")) - .bind(&Server::handleGetDAGRun, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) - .response(Http::Code::Ok, "Details of a specific DAG run"); - // List all DAG runs - dagPath.route(desc_.get("/")) - .bind(&Server::handleGetDAGRuns, this) - .produces(MIME(Application, Json), MIME(Application, Xml)) - .response(Http::Code::Ok, "The list of all known DAG Runs"); + dagRunPath.route(desc_.post("/validate")) + .bind(&Server::handleValidateDAG, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Validate a DAG Run Spec"); + + /* + Management of a specific DAG + */ + auto specificDAGRunPath = dagRunPath.path("/:runID"); + + specificDAGRunPath.route(desc_.get("/")) + .bind(&Server::handleGetDAGRun, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Full DAG Run"); + + specificDAGRunPath.route(desc_.get("/state")) + .bind(&Server::handleGetDAGRunState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Current run state of the DAG run"); + + specificDAGRunPath.route(desc_.patch("/state/:state")) + .bind(&Server::handleSetDAGRunState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Change the state of a DAG"); + + /* + Task paths + */ + auto taskPath = specificDAGRunPath.path("/task/:taskName"); + taskPath.route(desc_.get("/")) + .bind(&Server::handleGetTask, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Details of a specific task"); + + /* + Task State
paths + */ + auto taskStatePath = taskPath.path("/state"); + + taskStatePath.route(desc_.get("/")) + .bind(&Server::handleGetTaskState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Get a task state"); + + taskStatePath.route(desc_.patch("/:state")) + .bind(&Server::handleSetTaskState, this) + .produces(MIME(Application, Json)) + .response(Http::Code::Ok, "Set a task state"); } - /* - * { - * "name": "DAG Run Name" - * "job": {...} - * "tasks": {...} - */ void Server::handleRunDAG(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + if (!handleAuth(request)) return; - rj::Document doc; - try { - doc.Parse(request.body().c_str()); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, std::string{"Invalid JSON payload: "} + e.what()); - } - - if (!doc.IsObject()) { - REQ_ERROR(Bad_Request, "Payload is not a dictionary."); - } - if (!doc.HasMember("name")) { - REQ_ERROR(Bad_Request, "DAG Run is missing a name."); - } - if (!doc.HasMember("tasks")) { - REQ_ERROR(Bad_Request, "DAG Run has no tasks."); - } - - std::string runName = doc["name"].GetString(); - - // Get parameters if there are any - ConfigValues parameters; - if (doc.HasMember("parameters")) { - try { - auto parsedParams = configFromJSON(doc["parameters"].GetObject()); - parameters.swap(parsedParams); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } - } - - // Job Defaults - ConfigValues jobDefaults; - if (doc.HasMember("jobDefaults")) { - try { - auto parsedJobDefaults = configFromJSON(doc["jobDefaults"].GetObject()); - jobDefaults.swap(parsedJobDefaults); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } - } - - // Get the tasks - TaskSet tasks; - try { - auto taskTemplates = tasksFromJSON(doc["tasks"], jobDefaults); - auto expandedTasks = expandTaskSet(taskTemplates, executor_, parameters); - tasks.swap(expandedTasks); - } - catch (std::exception &e) { - REQ_ERROR(Bad_Request, e.what()); - } + auto dagSpec = dagFromJSON(request.body()); + dagSpec.tasks = + expandTaskSet(dagSpec.tasks, executor_, dagSpec.taskConfig.variables); // Get a run ID - auto runID = logger_.startDAGRun(runName, tasks); - auto dag = buildDAGFromTasks(tasks); - - runnerPool_.addTask([this, parameters, runID, dag]() { - runDAG(runID, this->executor_, this->logger_, dag, parameters); - }); + DAGRunID runID = logger_.startDAGRun(dagSpec); + auto dag = buildDAGFromTasks(dagSpec.tasks); + queueDAG_(runID, dag, dagSpec.taskConfig); response.send(Pistache::Http::Code::Ok, R"({"runID": )" + std::to_string(runID) + "}"); } - void Server::handleGetDAGRuns(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter response) + void Server::handleValidateDAG(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + try { + dagFromJSON(request.body()); + response.send(Pistache::Http::Code::Ok, R"({"valid": true})"); + } + catch (std::exception &e) { + std::string error = e.what(); + std::stringstream ss; + ss << R"({"valid": false, "error": )" << std::quoted(error) << '}'; + response.send(Pistache::Http::Code::Ok, ss.str()); + } + } + + void Server::handleQueryDAGs(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) return; - auto dagRuns = logger_.getDAGs(0); + + bool all = false; + std::string tag = ""; + + if (request.query().has("tag")) { + tag = request.query().get("tag").value(); + } + + if (request.query().has("all"))
{ + auto val = request.query().get("all").value(); + if (val == "true" or val == "1") { + all = true; + } + } + + auto dagRuns = logger_.queryDAGRuns(tag, all); std::stringstream ss; ss << '['; @@ -187,8 +230,8 @@ } ss << " {" - << R"("runID": )" << run.runID << ',' << R"("name": )" - << std::quoted(run.name) << "," + << R"("runID": )" << run.runID << ',' << R"("tag": )" + << std::quoted(run.tag) << "," << R"("startTime": )" << std::quoted(timePointToString(run.startTime)) << ',' << R"("lastUpdate": )" << std::quoted(timePointToString(run.lastUpdate)) << ',' @@ -214,10 +257,10 @@ void Server::handleGetDAGRun(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request, response)) + if (!handleAuth(request)) return; if (!request.hasParam(":runID")) { - REQ_ERROR(Not_Found, "No runID provided in URL"); + REQ_RESPONSE(Not_Found, "No runID provided in URL"); } auto runID = request.param(":runID").as<size_t>(); auto run = logger_.getDAGRun(runID); @@ -225,9 +268,9 @@ bool first = true; std::stringstream ss; ss << "{" - << R"("runID": )" << runID << ',' << R"("name": )" - << std::quoted(run.name) << ',' << R"("tasks": )" - << tasksToJSON(run.tasks) << ','; + << R"("runID": )" << runID << ',' << R"("tag": )" + << std::quoted(run.dagSpec.tag) << ',' << R"("tasks": )" + << tasksToJSON(run.dagSpec.tasks) << ','; // task run states ss << R"("taskStates": { )"; @@ -295,21 +338,179 @@ response.send(Pistache::Http::Code::Ok, ss.str()); } + void Server::handleGetDAGRunState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + DAGRunID runID = request.param(":runID").as<size_t>(); + RunState state = RunState::QUEUED; + try { + state = logger_.getDAGRunState(runID); + std::stringstream ss; + ss << R"({ "runID": )" << runID << R"(, "state": )" + << std::quoted(state._to_string()) << '}'; + response.send(Pistache::Http::Code::Ok, ss.str()); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::queueDAG_(DAGRunID runID, const TaskDAG &dag, + const TaskParameters &taskParameters) + { + std::lock_guard lock(runnerGuard_); + /* + auto it = runners_.emplace( + std::piecewise_construct, std::forward_as_tuple(runID), + std::forward_as_tuple(runID, executor_, logger_, dag, + taskParameters)); + */ + auto it = runners_.emplace( + runID, std::make_shared<DAGRunner>(runID, executor_, logger_, dag, + taskParameters)); + + if (!it.second) + throw std::runtime_error("A DAGRun with the same ID is already running"); + auto runner = it.first->second; + runnerPool_.addTask([runner, runID, this]() { + runner->run(); + std::lock_guard lock(this->runnerGuard_); + this->runners_.extract(runID); + }); + } + + void Server::handleSetDAGRunState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + // TODO handle state transition + DAGRunID runID = request.param(":runID").as<size_t>(); + RunState newState = RunState::_from_string( + request.param(":state").as<std::string>().c_str()); + + std::shared_ptr<DAGRunner> runner{nullptr}; + { + std::lock_guard lock(runnerGuard_); + auto it = runners_.find(runID); + if (it != runners_.end()) { + runner = it->second; + } + } + + if (runner) { + switch (newState) { + case RunState::PAUSED: + case RunState::KILLED: { + runner->stop(true, true); + logger_.updateDAGRunState(runID, newState); + break; + } + default: { +
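// Only PAUSED and KILLED are accepted while a runner is still active. +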
REQ_RESPONSE(Method_Not_Allowed, + std::string{"Cannot transition to state "} + + newState._to_string()); + } + } + } + else { + switch (newState) { + case RunState::QUEUED: { + auto dagRun = logger_.getDAGRun(runID); + auto dag = + buildDAGFromTasks(dagRun.dagSpec.tasks, dagRun.taskStateChanges); + dag.resetRunning(); + queueDAG_(runID, dag, dagRun.dagSpec.taskConfig); + break; + } + default: + REQ_RESPONSE( + Method_Not_Allowed, + std::string{"DAG not running, cannot transition to state "} + + newState._to_string()); + } + } + REQ_RESPONSE(Ok, ""); + } + + void Server::handleGetTask(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + auto runID = request.param(":runID").as<size_t>(); + auto taskName = request.param(":taskName").as<std::string>(); + + try { + auto task = logger_.getTask(runID, taskName); + response.send(Pistache::Http::Code::Ok, taskToJSON(task)); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::handleGetTaskState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + auto runID = request.param(":runID").as<size_t>(); + auto taskName = request.param(":taskName").as<std::string>(); + + try { + auto state = logger_.getTaskState(runID, taskName); + std::stringstream ss; + ss << R"({ "runID": )" << runID << R"(, "taskName": )" + << std::quoted(taskName) << R"(, "state": )" + << std::quoted(state._to_string()) << '}'; + response.send(Pistache::Http::Code::Ok, ss.str()); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + + void Server::handleSetTaskState(const Pistache::Rest::Request &request, + Pistache::Http::ResponseWriter response) + { + if (!handleAuth(request)) + return; + + // TODO implement handling of task state + auto runID = request.param(":runID").as<size_t>(); + auto taskName = request.param(":taskName").as<std::string>(); + RunState state = RunState::_from_string( + request.param(":state").as<std::string>().c_str()); + + try { + logger_.updateTaskState(runID, taskName, state); + response.send(Pistache::Http::Code::Ok, ""); + } + catch (std::exception &e) { + REQ_RESPONSE(Not_Found, e.what()); + } + } + void Server::handleReady(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - response.send(Pistache::Http::Code::Ok, "Ya like DAGs?"); + response.send(Pistache::Http::Code::Ok, R"({ "msg": "Ya like DAGs?"})"); } /* - * handleAuth will check any auth methods and handle any responses in the case - * of failed auth. If it returns false, callers should cease handling the - * response + * handleAuth will check any auth methods and handle any responses in the + * case of failed auth.
If it returns false, callers should cease handling + * the response */ - bool Server::handleAuth(const Pistache::Rest::Request &request, - Pistache::Http::ResponseWriter &response) + bool Server::handleAuth(const Pistache::Rest::Request &request) { - (void)response; return true; } } // namespace daggy diff --git a/daggy/src/Utilities.cpp b/daggy/src/Utilities.cpp index 29ecfff..99258d0 100644 --- a/daggy/src/Utilities.cpp +++ b/daggy/src/Utilities.cpp @@ -92,8 +92,9 @@ namespace daggy { } // Add edges - for (const auto &[name, task] : tasks) { - dag.addEdgeIf(name, [&](const auto &v) { + for (const auto &[name, t] : tasks) { + const auto &task = t; + dag.addEdgeIf(name, [&task](const auto &v) { return task.children.count(v.data.definedName) > 0; }); } @@ -115,10 +116,10 @@ namespace daggy { switch (update.newState) { case RunState::RUNNING: case RunState::RETRY: + case RunState::PAUSED: case RunState::ERRORED: case RunState::KILLED: dag.setVertexState(update.taskName, RunState::RUNNING); - dag.setVertexState(update.taskName, RunState::COMPLETED); break; case RunState::COMPLETED: case RunState::QUEUED: @@ -129,120 +130,9 @@ namespace daggy { return dag; } - TaskDAG runDAG(DAGRunID runID, executors::task::TaskExecutor &executor, - loggers::dag_run::DAGRunLogger &logger, TaskDAG dag, - const ConfigValues parameters) - { - logger.updateDAGRunState(runID, RunState::RUNNING); - - std::unordered_map> runningTasks; - std::unordered_map taskAttemptCounts; - - size_t running = 0; - size_t errored = 0; - while (!dag.allVisited()) { - // Check for any completed tasks - for (auto &[taskName, fut] : runningTasks) { - if (fut.valid()) { - auto attempt = fut.get(); - logger.logTaskAttempt(runID, taskName, attempt); - - // Not a reference, since adding tasks will invalidate references - auto vert = dag.getVertex(taskName); - auto &task = vert.data; - if (attempt.rc == 0) { - logger.updateTaskState(runID, taskName, RunState::COMPLETED); - if (task.isGenerator) { - // Parse the output and update the DAGs - try { - auto parsedTasks = tasksFromJSON(attempt.outputLog); - auto newTasks = - expandTaskSet(parsedTasks, executor, parameters); - updateDAGFromTasks(dag, newTasks); - - // Add in dependencies from current task to new tasks - for (const auto &[ntName, ntTask] : newTasks) { - logger.addTask(runID, ntName, ntTask); - task.children.insert(ntName); - } - - // Efficiently add new edges from generator task - // to children - std::unordered_set baseNames; - for (const auto &[k, v] : parsedTasks) { - baseNames.insert(v.definedName); - } - dag.addEdgeIf(taskName, [&](const auto &v) { - return baseNames.count(v.data.definedName) > 0; - }); - - logger.updateTask(runID, taskName, task); - } - catch (std::exception &e) { - logger.logTaskAttempt( - runID, taskName, - AttemptRecord{ - .executorLog = - std::string{"Failed to parse JSON output: "} + - e.what()}); - logger.updateTaskState(runID, taskName, RunState::ERRORED); - ++errored; - } - } - dag.completeVisit(taskName); - --running; - } - else { - // RC isn't 0 - if (taskAttemptCounts[taskName] <= task.maxRetries) { - logger.updateTaskState(runID, taskName, RunState::RETRY); - runningTasks[taskName] = executor.execute(taskName, task); - ++taskAttemptCounts[taskName]; - } - else { - logger.updateTaskState(runID, taskName, RunState::ERRORED); - ++errored; - } - } - } - } - - // Add all remaining tasks in a task queue to avoid dominating the thread - // pool - auto t = dag.visitNext(); - while (t.has_value()) { - // Schedule the task to run - auto &taskName = 
t.value().first; - auto &task = t.value().second; - taskAttemptCounts[taskName] = 1; - - logger.updateTaskState(runID, taskName, RunState::RUNNING); - runningTasks.emplace(taskName, executor.execute(taskName, task)); - ++running; - - auto nextTask = dag.visitNext(); - if (not nextTask.has_value()) - break; - t.emplace(nextTask.value()); - } - if (running > 0 and errored == running) { - logger.updateDAGRunState(runID, RunState::ERRORED); - break; - } - std::this_thread::sleep_for(250ms); - } - - if (dag.allVisited()) { - logger.updateDAGRunState(runID, RunState::COMPLETED); - } - - return dag; - } - std::ostream &operator<<(std::ostream &os, const TimePoint &tp) { - auto t_c = Clock::to_time_t(tp); - os << std::put_time(std::localtime(&t_c), "%Y-%m-%d %H:%M:%S %Z"); + os << tp.time_since_epoch().count() << std::endl; return os; } } // namespace daggy diff --git a/daggy/src/executors/task/ForkingTaskExecutor.cpp b/daggy/src/executors/task/ForkingTaskExecutor.cpp index a3df342..8f1a5a0 100644 --- a/daggy/src/executors/task/ForkingTaskExecutor.cpp +++ b/daggy/src/executors/task/ForkingTaskExecutor.cpp @@ -36,13 +36,45 @@ std::string slurp(int fd) return result; } -std::future ForkingTaskExecutor::execute( - const std::string &taskName, const Task &task) +ForkingTaskExecutor::ForkingTaskExecutor(size_t nThreads) + : tp_(nThreads) { - return tp_.addTask([this, task]() { return this->runTask(task); }); } -daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task) +ForkingTaskExecutor::~ForkingTaskExecutor() +{ + std::lock_guard lock(taskControlsGuard_); + taskControls_.clear(); +} + +bool ForkingTaskExecutor::stop(DAGRunID runID, const std::string &taskName) +{ + std::string key = std::to_string(runID) + "_" + taskName; + std::lock_guard lock(taskControlsGuard_); + auto it = taskControls_.find(key); + if (it == taskControls_.end()) + return true; + it->second = false; + return true; +} + +std::future ForkingTaskExecutor::execute( + DAGRunID runID, const std::string &taskName, const Task &task) +{ + std::string key = std::to_string(runID) + "_" + taskName; + std::lock_guard lock(taskControlsGuard_); + auto [it, ins] = taskControls_.emplace(key, true); + auto &running = it->second; + return tp_.addTask([this, task, &running, key]() { + auto ret = this->runTask(task, running); + std::lock_guard lock(this->taskControlsGuard_); + this->taskControls_.extract(key); + return ret; + }); +} + +daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task, + std::atomic &running) { AttemptRecord rec; @@ -81,23 +113,41 @@ daggy::AttemptRecord ForkingTaskExecutor::runTask(const Task &task) exit(-1); } - std::atomic running = true; + std::atomic reading = true; std::thread stdoutReader([&]() { - while (running) + while (reading) rec.outputLog.append(slurp(stdoutPipe[0])); }); std::thread stderrReader([&]() { - while (running) + while (reading) rec.errorLog.append(slurp(stderrPipe[0])); }); - int rc = 0; - waitpid(child, &rc, 0); - running = false; + siginfo_t childInfo; + while (running) { + childInfo.si_pid = 0; + waitid(P_PID, child, &childInfo, WEXITED | WNOHANG); + if (childInfo.si_pid > 0) { + break; + } + std::this_thread::sleep_for(250ms); + } + + if (!running) { + rec.executorLog = "Killed"; + // Send the kills until pid is dead + while (kill(child, SIGKILL) != -1) { + // Need to collect the child to avoid a zombie process + waitid(P_PID, child, &childInfo, WEXITED | WNOHANG); + std::this_thread::sleep_for(50ms); + } + } + + reading = false; rec.stopTime = Clock::now(); - if 
(WIFEXITED(rc)) { - rec.rc = WEXITSTATUS(rc); + if (childInfo.si_pid > 0) { + rec.rc = childInfo.si_status; } else { rec.rc = -1; diff --git a/daggy/src/executors/task/NoopTaskExecutor.cpp b/daggy/src/executors/task/NoopTaskExecutor.cpp index 9239b8b..9e875f4 100644 --- a/daggy/src/executors/task/NoopTaskExecutor.cpp +++ b/daggy/src/executors/task/NoopTaskExecutor.cpp @@ -3,7 +3,7 @@ namespace daggy::executors::task { std::future NoopTaskExecutor::execute( - const std::string &taskName, const Task &task) + DAGRunID runID, const std::string &taskName, const Task &task) { std::promise promise; auto ts = Clock::now(); @@ -42,4 +42,10 @@ namespace daggy::executors::task { return newValues; } + + bool NoopTaskExecutor::stop(DAGRunID runID, const std::string &taskName) + { + return true; + } + } // namespace daggy::executors::task diff --git a/daggy/src/executors/task/SlurmTaskExecutor.cpp b/daggy/src/executors/task/SlurmTaskExecutor.cpp index e53e121..fb07635 100644 --- a/daggy/src/executors/task/SlurmTaskExecutor.cpp +++ b/daggy/src/executors/task/SlurmTaskExecutor.cpp @@ -1,4 +1,5 @@ #include +#include #include #ifdef DAGGY_ENABLE_SLURM #include @@ -6,6 +7,7 @@ #include #include +#include #include #include #include @@ -115,7 +117,7 @@ namespace daggy::executors::task { } std::future SlurmTaskExecutor::execute( - const std::string &taskName, const Task &task) + DAGRunID runID, const std::string &taskName, const Task &task) { std::stringstream executorLog; @@ -191,13 +193,39 @@ namespace daggy::executors::task { slurm_free_submit_response_response_msg(resp_msg); std::lock_guard lock(promiseGuard_); - Job newJob{.prom{}, .stdoutFile = stdoutFile, .stderrFile = stderrFile}; + Job newJob{.prom{}, + .stdoutFile = stdoutFile, + .stderrFile = stderrFile, + .runID = runID, + .taskName = taskName}; auto fut = newJob.prom.get_future(); runningJobs_.emplace(jobID, std::move(newJob)); return fut; } + bool SlurmTaskExecutor::stop(DAGRunID runID, const std::string &taskName) + { + // Hopefully this isn't a common thing, so just scrap the current jobs and + // kill them + size_t jobID = 0; + { + std::lock_guard lock(promiseGuard_); + for (const auto &[k, v] : runningJobs_) { + if (v.runID == runID and v.taskName == taskName) { + jobID = k; + break; + } + } + if (jobID == 0) + return true; + } + + // Send the kill message to slurm + slurm_kill_job(jobID, SIGKILL, KILL_HURRY); + return true; + } + void SlurmTaskExecutor::monitor() { std::unordered_set resolvedJobs; @@ -225,32 +253,40 @@ namespace daggy::executors::task { // Job has finished case JOB_COMPLETE: /* completed execution successfully */ case JOB_FAILED: /* completed execution unsuccessfully */ + record.rc = jobInfo.exit_code; record.executorLog = "Script errored.\n"; break; - case JOB_CANCELLED: /* cancelled by user */ + case JOB_CANCELLED: /* cancelled by user */ + record.rc = 9; // matches SIGKILL record.executorLog = "Job cancelled by user.\n"; break; case JOB_TIMEOUT: /* terminated on reaching time limit */ + record.rc = jobInfo.exit_code; record.executorLog = "Job exceeded time limit.\n"; break; case JOB_NODE_FAIL: /* terminated on node failure */ + record.rc = jobInfo.exit_code; record.executorLog = "Node failed during execution\n"; break; case JOB_PREEMPTED: /* terminated due to preemption */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to pre-emption.\n"; break; case JOB_BOOT_FAIL: /* terminated due to node boot failure */ + record.rc = jobInfo.exit_code; record.executorLog = - "Job failed to run due to failure 
of compute node to boot.\n"; + "Job failed to run due to failure of compute node to " + "boot.\n"; break; case JOB_DEADLINE: /* terminated on deadline */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to deadline.\n"; break; case JOB_OOM: /* experienced out of memory error */ + record.rc = jobInfo.exit_code; record.executorLog = "Job terminated due to out-of-memory.\n"; break; } - record.rc = jobInfo.exit_code; slurm_free_job_info_msg(jobStatus); readAndClean(job.stdoutFile, record.outputLog); @@ -265,7 +301,7 @@ namespace daggy::executors::task { } } - std::this_thread::sleep_for(std::chrono::milliseconds(250)); + std::this_thread::sleep_for(std::chrono::seconds(1)); } } } // namespace daggy::executors::task diff --git a/daggy/src/loggers/dag_run/OStreamLogger.cpp b/daggy/src/loggers/dag_run/OStreamLogger.cpp index f21a09a..1fc9bb3 100644 --- a/daggy/src/loggers/dag_run/OStreamLogger.cpp +++ b/daggy/src/loggers/dag_run/OStreamLogger.cpp @@ -11,20 +11,26 @@ namespace daggy { namespace loggers { namespace dag_run { { } + OStreamLogger::~OStreamLogger() + { + std::lock_guard lock(guard_); + dagRuns_.clear(); + } + // Execution - DAGRunID OStreamLogger::startDAGRun(std::string name, const TaskSet &tasks) + DAGRunID OStreamLogger::startDAGRun(const DAGSpec &dagSpec) { std::lock_guard lock(guard_); size_t runID = dagRuns_.size(); - dagRuns_.push_back({.name = name, .tasks = tasks}); - for (const auto &[name, _] : tasks) { + dagRuns_.emplace_back(DAGRunRecord{.dagSpec = dagSpec}); + for (const auto &[name, _] : dagSpec.tasks) { _updateTaskState(runID, name, RunState::QUEUED); } _updateDAGRunState(runID, RunState::QUEUED); - os_ << "Starting new DAGRun named " << name << " with ID " << runID - << " and " << tasks.size() << " tasks" << std::endl; - for (const auto &[name, task] : tasks) { + os_ << "Starting new DAGRun tagged " << dagSpec.tag << " with ID " << runID + << " and " << dagSpec.tasks.size() << " tasks" << std::endl; + for (const auto &[name, task] : dagSpec.tasks) { os_ << "TASK (" << name << "): " << configToJSON(task.job); os_ << std::endl; } @@ -35,8 +41,8 @@ namespace daggy { namespace loggers { namespace dag_run { const Task &task) { std::lock_guard lock(guard_); - auto &dagRun = dagRuns_[dagRunID]; - dagRun.tasks[taskName] = task; + auto &dagRun = dagRuns_[dagRunID]; + dagRun.dagSpec.tasks[taskName] = task; _updateTaskState(dagRunID, taskName, RunState::QUEUED); } @@ -44,8 +50,8 @@ namespace daggy { namespace loggers { namespace dag_run { const Task &task) { std::lock_guard lock(guard_); - auto &dagRun = dagRuns_[dagRunID]; - dagRun.tasks[taskName] = task; + auto &dagRun = dagRuns_[dagRunID]; + dagRun.dagSpec.tasks[taskName] = task; } void OStreamLogger::updateDAGRunState(DAGRunID dagRunID, RunState state) @@ -101,15 +107,29 @@ namespace daggy { namespace loggers { namespace dag_run { } // Querying - std::vector OStreamLogger::getDAGs(uint32_t stateMask) + DAGSpec OStreamLogger::getDAGSpec(DAGRunID dagRunID) + { + std::lock_guard lock(guard_); + return dagRuns_.at(dagRunID).dagSpec; + }; + + std::vector OStreamLogger::queryDAGRuns(const std::string &tag, + bool all) { std::vector summaries; std::lock_guard lock(guard_); size_t i = 0; for (const auto &run : dagRuns_) { + if ((!all) && + (run.dagStateChanges.back().newState == +RunState::COMPLETED)) { + continue; + } + if (!tag.empty() and tag != run.dagSpec.tag) + continue; + DAGRunSummary summary{ .runID = i, - .name = run.name, + .tag = run.dagSpec.tag, .runState = run.dagStateChanges.back().newState, 
.startTime = run.dagStateChanges.front().time, .lastUpdate = std::max(run.taskStateChanges.back().time, @@ -126,10 +146,26 @@ namespace daggy { namespace loggers { namespace dag_run { DAGRunRecord OStreamLogger::getDAGRun(DAGRunID dagRunID) { - if (dagRunID >= dagRuns_.size()) { - throw std::runtime_error("No such DAGRun ID"); - } std::lock_guard lock(guard_); - return dagRuns_[dagRunID]; + return dagRuns_.at(dagRunID); } + + RunState OStreamLogger::getDAGRunState(DAGRunID dagRunID) + { + std::lock_guard lock(guard_); + return dagRuns_.at(dagRunID).dagStateChanges.back().newState; + } + + Task &OStreamLogger::getTask(DAGRunID dagRunID, const std::string &taskName) + { + std::lock_guard lock(guard_); + return dagRuns_.at(dagRunID).dagSpec.tasks.at(taskName); + } + RunState &OStreamLogger::getTaskState(DAGRunID dagRunID, + const std::string &taskName) + { + std::lock_guard lock(guard_); + return dagRuns_.at(dagRunID).taskRunStates.at(taskName); + } + }}} // namespace daggy::loggers::dag_run diff --git a/endpoints.otl b/endpoints.otl new file mode 100644 index 0000000..8308751 --- /dev/null +++ b/endpoints.otl @@ -0,0 +1,30 @@ +ready [handleReady] +v1 + dagruns + :GET - Summary of running DAGs [handleListDAGs] + + {tag} + + dagrun + :POST - submit a dag run [handleRunDAG] + + validate + :POST - Ensure a submitted DAG run is valid [handleValidateDAG + + {runID} + :GET - Full DAG run information [handleGetDAG] + + summary + :GET - RunState of DAG, and task RunState counts + + state + :PATCH - Change the state of a DAG (paused, killed) + :GET - Summary of dag run (structure + runstates) + + tasks + {taskName} + :GET - Full task definition and output + + state + :GET -- Get the task state + :PATCH -- Set the task state diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 52d96fb..83618e0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,6 +2,7 @@ project(tests) add_executable(tests main.cpp # unit tests unit_dag.cpp + unit_dagrunner.cpp unit_dagrun_loggers.cpp unit_executor_forkingexecutor.cpp unit_executor_slurmexecutor.cpp @@ -14,4 +15,4 @@ add_executable(tests main.cpp # Performance checks perf_dag.cpp ) -target_link_libraries(tests libdaggy stdc++fs Catch2::Catch2) +target_link_libraries(tests libdaggy stdc++fs Catch2::Catch2 curl) diff --git a/tests/unit_dagrun_loggers.cpp b/tests/unit_dagrun_loggers.cpp index c8bf95b..ed480e9 100644 --- a/tests/unit_dagrun_loggers.cpp +++ b/tests/unit_dagrun_loggers.cpp @@ -2,11 +2,10 @@ #include #include #include +#include #include "daggy/loggers/dag_run/OStreamLogger.hpp" -namespace fs = std::filesystem; - using namespace daggy; using namespace daggy::loggers::dag_run; @@ -20,28 +19,68 @@ const TaskSet SAMPLE_TASKS{ {"work_c", Task{.job{{"command", std::vector{"/bin/echo", "c"}}}}}}; -inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &name, +inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &tag, const TaskSet &tasks) { - auto runID = logger.startDAGRun(name, tasks); - auto dagRun = logger.getDAGRun(runID); + auto runID = logger.startDAGRun(DAGSpec{.tag = tag, .tasks = tasks}); - REQUIRE(dagRun.tasks == tasks); + // Verify run shows up in the list + { + auto runs = logger.queryDAGRuns(); + REQUIRE(!runs.empty()); + auto it = std::find_if(runs.begin(), runs.end(), + [runID](const auto &r) { return r.runID == runID; }); + REQUIRE(it != runs.end()); + REQUIRE(it->tag == tag); + REQUIRE(it->runState == +RunState::QUEUED); + } - REQUIRE(dagRun.taskRunStates.size() == tasks.size()); - auto 
diff --git a/tests/unit_dagrun_loggers.cpp b/tests/unit_dagrun_loggers.cpp
index c8bf95b..ed480e9 100644
--- a/tests/unit_dagrun_loggers.cpp
+++ b/tests/unit_dagrun_loggers.cpp
@@ -2,11 +2,10 @@
 #include
 #include
 #include
+#include
 #include "daggy/loggers/dag_run/OStreamLogger.hpp"
 
-namespace fs = std::filesystem;
-
 using namespace daggy;
 using namespace daggy::loggers::dag_run;
@@ -20,28 +19,68 @@ const TaskSet SAMPLE_TASKS{
   {"work_c", Task{.job{{"command", std::vector<std::string>{"/bin/echo", "c"}}}}}};
 
-inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &name,
+inline DAGRunID testDAGRunInit(DAGRunLogger &logger, const std::string &tag,
                                const TaskSet &tasks)
 {
-  auto runID = logger.startDAGRun(name, tasks);
-  auto dagRun = logger.getDAGRun(runID);
+  auto runID = logger.startDAGRun(DAGSpec{.tag = tag, .tasks = tasks});
 
-  REQUIRE(dagRun.tasks == tasks);
+  // Verify run shows up in the list
+  {
+    auto runs = logger.queryDAGRuns();
+    REQUIRE(!runs.empty());
+    auto it = std::find_if(runs.begin(), runs.end(),
+                           [runID](const auto &r) { return r.runID == runID; });
+    REQUIRE(it != runs.end());
+    REQUIRE(it->tag == tag);
+    REQUIRE(it->runState == +RunState::QUEUED);
+  }
 
-  REQUIRE(dagRun.taskRunStates.size() == tasks.size());
-  auto nonQueuedTask =
-      std::find_if(dagRun.taskRunStates.begin(), dagRun.taskRunStates.end(),
-                   [](const auto &a) { return a.second != +RunState::QUEUED; });
-  REQUIRE(nonQueuedTask == dagRun.taskRunStates.end());
+  // Verify states
+  {
+    REQUIRE(logger.getDAGRunState(runID) == +RunState::QUEUED);
+    for (const auto &[k, _] : tasks) {
+      REQUIRE(logger.getTaskState(runID, k) == +RunState::QUEUED);
+    }
+  }
+
+  // Verify integrity of run
+  {
+    auto dagRun = logger.getDAGRun(runID);
+
+    REQUIRE(dagRun.dagSpec.tag == tag);
+    REQUIRE(dagRun.dagSpec.tasks == tasks);
+
+    REQUIRE(dagRun.taskRunStates.size() == tasks.size());
+    auto nonQueuedTask = std::find_if(
+        dagRun.taskRunStates.begin(), dagRun.taskRunStates.end(),
+        [](const auto &a) { return a.second != +RunState::QUEUED; });
+    REQUIRE(nonQueuedTask == dagRun.taskRunStates.end());
+    REQUIRE(dagRun.dagStateChanges.size() == 1);
+    REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::QUEUED);
+  }
+
+  // Update DAG state and ensure that it's updated
+  {
+    logger.updateDAGRunState(runID, RunState::RUNNING);
+    auto dagRun = logger.getDAGRun(runID);
+    REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::RUNNING);
+  }
+
+  // Update a task state
+  {
+    for (const auto &[k, v] : tasks)
+      logger.updateTaskState(runID, k, RunState::RUNNING);
+    auto dagRun = logger.getDAGRun(runID);
+    for (const auto &[k, v] : tasks) {
+      REQUIRE(dagRun.taskRunStates.at(k) == +RunState::RUNNING);
+    }
+  }
 
-  REQUIRE(dagRun.dagStateChanges.size() == 1);
-  REQUIRE(dagRun.dagStateChanges.back().newState == +RunState::QUEUED);
   return runID;
 }
 
 TEST_CASE("ostream_logger", "[ostream_logger]")
 {
-  // cleanup();
   std::stringstream ss;
   daggy::loggers::dag_run::OStreamLogger logger(ss);
@@ -49,6 +88,4 @@ TEST_CASE("ostream_logger", "[ostream_logger]")
   {
     testDAGRunInit(logger, "init_test", SAMPLE_TASKS);
   }
-
-  // cleanup();
 }
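The new unit_dagrunner.cpp below exercises `DAGRunner`, which replaces the old free-standing `runDAG` utility. Distilled from those tests, the lifecycle looks roughly like this; a sketch only, with an illustrative one-task JSON and an include set approximated from the test files:

```
#include <sstream>

#include "daggy/DAGRunner.hpp"
#include "daggy/Serialization.hpp"
#include "daggy/executors/task/ForkingTaskExecutor.hpp"
#include "daggy/loggers/dag_run/OStreamLogger.hpp"

int main()
{
    daggy::executors::task::ForkingTaskExecutor ex(4);
    std::stringstream ss;
    daggy::loggers::dag_run::OStreamLogger logger(ss);

    daggy::DAGSpec spec;
    spec.tasks = daggy::expandTaskSet(
        daggy::tasksFromJSON(R"({"A": {"job": {"command": ["/bin/echo", "hi"]}}})"),
        ex);

    auto dag = daggy::buildDAGFromTasks(spec.tasks);
    auto runID = logger.startDAGRun(spec);

    daggy::DAGRunner runner(runID, ex, logger, dag, spec.taskConfig);
    auto endDAG = runner.run();   // blocks until the DAG settles

    if (!endDAG.allVisited()) {   // some task errored...
        runner.resetRunning();    // ...requeue it and try again
        endDAG = runner.run();
    }
    return endDAG.allVisited() ? 0 : 1;
}
```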
diff --git a/tests/unit_dagrunner.cpp b/tests/unit_dagrunner.cpp
new file mode 100644
index 0000000..b5d7b8a
--- /dev/null
+++ b/tests/unit_dagrunner.cpp
@@ -0,0 +1,256 @@
+#include
+#include
+#include
+
+#include "daggy/DAGRunner.hpp"
+#include "daggy/executors/task/ForkingTaskExecutor.hpp"
+#include "daggy/executors/task/NoopTaskExecutor.hpp"
+#include "daggy/loggers/dag_run/OStreamLogger.hpp"
+
+namespace fs = std::filesystem;
+
+TEST_CASE("dagrunner", "[dagrunner_order_preservation]")
+{
+  daggy::executors::task::NoopTaskExecutor ex;
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+
+  daggy::TimePoint globalStartTime = daggy::Clock::now();
+
+  daggy::DAGSpec dagSpec;
+
+  std::string testParams{
+      R"({"DATE": ["2021-05-06", "2021-05-07", "2021-05-08", "2021-05-09" ]})"};
+  dagSpec.taskConfig.variables = daggy::configFromJSON(testParams);
+
+  std::string taskJSON = R"({
+    "A": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "B","D" ]},
+    "B": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "C","D","E" ]},
+    "C": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "D"]},
+    "D": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "E"]},
+    "E": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}}
+  })";
+
+  dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex,
+                                dagSpec.taskConfig.variables);
+
+  REQUIRE(dagSpec.tasks.size() == 20);
+
+  auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+  auto runID = logger.startDAGRun(dagSpec);
+
+  daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+
+  auto endDAG = runner.run();
+
+  REQUIRE(endDAG.allVisited());
+
+  // Ensure the run order
+  auto rec = logger.getDAGRun(runID);
+
+  daggy::TimePoint globalStopTime = daggy::Clock::now();
+  std::array<daggy::TimePoint, 5> minTimes;
+  minTimes.fill(globalStartTime);
+  std::array<daggy::TimePoint, 5> maxTimes;
+  maxTimes.fill(globalStopTime);
+
+  for (const auto &[k, v] : rec.taskAttempts) {
+    size_t idx = k[0] - 65;
+    auto &startTime = minTimes[idx];
+    auto &stopTime = maxTimes[idx];
+    startTime = std::max(startTime, v.front().startTime);
+    stopTime = std::min(stopTime, v.back().stopTime);
+  }
+
+  for (size_t i = 0; i < 5; ++i) {
+    for (size_t j = i + 1; j < 5; ++j) {
+      REQUIRE(maxTimes[i] < minTimes[j]);
+    }
+  }
+}
+
+TEST_CASE("DAGRunner simple execution", "[dagrunner_simple]")
+{
+  daggy::executors::task::ForkingTaskExecutor ex(10);
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+
+  daggy::DAGSpec dagSpec;
+
+  SECTION("Simple execution")
+  {
+    std::string prefix = (fs::current_path() / "asdlk").string();
+    std::unordered_map<std::string, std::string> files{
+        {"A", prefix + "_A"}, {"B", prefix + "_B"}, {"C", prefix + "_C"}};
+    std::string taskJSON =
+        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + files.at("A") +
+        R"("]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
+        files.at("B") +
+        R"("]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
+        files.at("C") + R"("]}}})";
+    dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
+    auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+    auto runID = logger.startDAGRun(dagSpec);
+    daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+    auto endDAG = runner.run();
+    REQUIRE(endDAG.allVisited());
+
+    for (const auto &[_, file] : files) {
+      REQUIRE(fs::exists(file));
+      fs::remove(file);
+    }
+
+    // Get the DAG Run Attempts
+    auto record = logger.getDAGRun(runID);
+    for (const auto &[_, attempts] : record.taskAttempts) {
+      REQUIRE(attempts.size() == 1);
+      REQUIRE(attempts.front().rc == 0);
+    }
+  }
+}
+
+TEST_CASE("DAG Runner Restart old DAG", "[dagrunner_restart]")
+{
+  daggy::executors::task::ForkingTaskExecutor ex(10);
+  std::stringstream ss;
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+  daggy::DAGSpec dagSpec;
+
+  SECTION("Recovery from Error")
+  {
+    auto cleanup = []() {
+      // Cleanup
+      std::vector<std::string> paths{"rec_error_A", "noexist"};
+      for (const auto &pth : paths) {
+        if (fs::exists(pth))
+          fs::remove_all(pth);
+      }
+    };
+
+    cleanup();
+
+    std::string goodPrefix = "rec_error_";
+    std::string badPrefix = "noexist/rec_error_";
+    std::string taskJSON =
+        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + goodPrefix +
+        R"(A"]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
+        badPrefix +
+        R"(B"]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
+        badPrefix + R"(C"]}}})";
+    dagSpec.tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
+    auto dag = daggy::buildDAGFromTasks(dagSpec.tasks);
+
+    auto runID = logger.startDAGRun(dagSpec);
+
+    daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig);
+    auto tryDAG = runner.run();
+
+    REQUIRE(!tryDAG.allVisited());
+
+    // Create the missing dir, then continue to run the DAG
+    fs::create_directory("noexist");
+    runner.resetRunning();
+    auto endDAG = runner.run();
+
+    REQUIRE(endDAG.allVisited());
+
+    // Get the DAG Run Attempts
+    auto record = logger.getDAGRun(runID);
+    REQUIRE(record.taskAttempts["A_0"].size() == 1); // A ran fine
+    REQUIRE(record.taskAttempts["B_0"].size() ==
+            2); // B errored and had to be retried
REQUIRE(record.taskAttempts["C_0"].size() == + 1); // C wasn't run because B errored + + cleanup(); + } +} + +TEST_CASE("DAG Runner Generator Tasks", "[dagrunner_generator]") +{ + daggy::executors::task::ForkingTaskExecutor ex(10); + std::stringstream ss; + daggy::loggers::dag_run::OStreamLogger logger(ss); + daggy::DAGSpec dagSpec; + + SECTION("Generator tasks") + { + std::string testParams{R"({"DATE": ["2021-05-06", "2021-05-07" ]})"}; + dagSpec.taskConfig.variables = daggy::configFromJSON(testParams); + + std::string generatorOutput = + R"({"B": {"job": {"command": ["/usr/bin/echo", "-e", "{{DATE}}"]}, "children": ["C"]}})"; + fs::path ofn = fs::current_path() / "generator_test_output.json"; + std::ofstream ofh{ofn}; + ofh << generatorOutput << std::endl; + ofh.close(); + + daggy::TimePoint globalStartTime = daggy::Clock::now(); + std::stringstream jsonTasks; + jsonTasks + << R"({ "A": { "job": {"command": [ "/usr/bin/cat", )" + << std::quoted(ofn.string()) + << R"(]}, "children": ["C"], "isGenerator": true},)" + << R"("C": { "job": {"command": [ "/usr/bin/echo", "hello!"]} } })"; + + dagSpec.tasks = daggy::tasksFromJSON(jsonTasks.str()); + REQUIRE(dagSpec.tasks.size() == 2); + REQUIRE(dagSpec.tasks["A"].children == + std::unordered_set{"C"}); + dagSpec.tasks = + daggy::expandTaskSet(dagSpec.tasks, ex, dagSpec.taskConfig.variables); + REQUIRE(dagSpec.tasks.size() == 2); + REQUIRE(dagSpec.tasks["A_0"].children == + std::unordered_set{"C"}); + auto dag = daggy::buildDAGFromTasks(dagSpec.tasks); + REQUIRE(dag.size() == 2); + + auto runID = logger.startDAGRun(dagSpec); + daggy::DAGRunner runner(runID, ex, logger, dag, dagSpec.taskConfig); + auto finalDAG = runner.run(); + + REQUIRE(finalDAG.allVisited()); + REQUIRE(finalDAG.size() == 4); + + // Check the logger + auto record = logger.getDAGRun(runID); + + REQUIRE(record.dagSpec.tasks.size() == 4); + REQUIRE(record.taskRunStates.size() == 4); + for (const auto &[taskName, attempts] : record.taskAttempts) { + REQUIRE(attempts.size() == 1); + REQUIRE(attempts.back().rc == 0); + } + + // Ensure that children were updated properly + REQUIRE(record.dagSpec.tasks["A_0"].children == + std::unordered_set{"B_0", "B_1", "C"}); + REQUIRE(record.dagSpec.tasks["B_0"].children == + std::unordered_set{"C"}); + REQUIRE(record.dagSpec.tasks["B_1"].children == + std::unordered_set{"C"}); + REQUIRE(record.dagSpec.tasks["C_0"].children.empty()); + + // Ensure they were run in the right order + // All A's get run before B's, which run before C's + daggy::TimePoint globalStopTime = daggy::Clock::now(); + std::array minTimes; + minTimes.fill(globalStartTime); + std::array maxTimes; + maxTimes.fill(globalStopTime); + + for (const auto &[k, v] : record.taskAttempts) { + size_t idx = k[0] - 65; + auto &startTime = minTimes[idx]; + auto &stopTime = maxTimes[idx]; + startTime = std::max(startTime, v.front().startTime); + stopTime = std::min(stopTime, v.back().stopTime); + } + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = i + 1; j < 2; ++j) { + REQUIRE(maxTimes[i] < minTimes[j]); + } + } + } +} diff --git a/tests/unit_executor_forkingexecutor.cpp b/tests/unit_executor_forkingexecutor.cpp index 42e393d..276ca47 100644 --- a/tests/unit_executor_forkingexecutor.cpp +++ b/tests/unit_executor_forkingexecutor.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "daggy/Serialization.hpp" #include "daggy/Utilities.hpp" @@ -18,7 +19,7 @@ TEST_CASE("forking_executor", "[forking_executor]") REQUIRE(ex.validateTaskParameters(task.job)); - auto recFuture = 
ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc == 0); @@ -32,7 +33,7 @@ TEST_CASE("forking_executor", "[forking_executor]") .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{ "/usr/bin/expr", "1", "+", "+"}}}}; - auto recFuture = ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc == 2); @@ -40,6 +41,28 @@ TEST_CASE("forking_executor", "[forking_executor]") REQUIRE(rec.outputLog.empty()); } + SECTION("Killing a long task") + { + daggy::Task task{ + .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{ + "/usr/bin/sleep", "30"}}}}; + + auto start = daggy::Clock::now(); + auto recFuture = ex.execute(0, "command", task); + std::this_thread::sleep_for(1s); + ex.stop(0, "command"); + auto rec = recFuture.get(); + auto stop = daggy::Clock::now(); + + REQUIRE(rec.rc == 9); + REQUIRE(rec.errorLog.empty()); + REQUIRE(rec.outputLog.empty()); + REQUIRE(rec.executorLog == "Killed"); + REQUIRE( + std::chrono::duration_cast(stop - start).count() < + 20); + } + SECTION("Large Output") { const std::vector BIG_FILES{"/usr/share/dict/linux.words", @@ -54,7 +77,7 @@ TEST_CASE("forking_executor", "[forking_executor]") .job{{"command", daggy::executors::task::ForkingTaskExecutor::Command{ "/usr/bin/cat", bigFile}}}}; - auto recFuture = ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc == 0); diff --git a/tests/unit_executor_slurmexecutor.cpp b/tests/unit_executor_slurmexecutor.cpp index 2b63a72..3dc8778 100644 --- a/tests/unit_executor_slurmexecutor.cpp +++ b/tests/unit_executor_slurmexecutor.cpp @@ -34,7 +34,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]") REQUIRE(ex.validateTaskParameters(task.job)); - auto recFuture = ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc == 0); @@ -49,7 +49,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]") "/usr/bin/expr", "1", "+", "+"}}}}; task.job.merge(defaultJobValues); - auto recFuture = ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc != 0); @@ -57,6 +57,23 @@ TEST_CASE("slurm_execution", "[slurm_executor]") REQUIRE(rec.outputLog.empty()); } + SECTION("Killing a long task") + { + daggy::Task task{ + .job{{"command", daggy::executors::task::SlurmTaskExecutor::Command{ + "/usr/bin/sleep", "30"}}}}; + task.job.merge(defaultJobValues); + + auto recFuture = ex.execute(0, "command", task); + ex.stop(0, "command"); + auto rec = recFuture.get(); + + REQUIRE(rec.rc == 9); + REQUIRE(rec.errorLog.empty()); + REQUIRE(rec.outputLog.empty()); + REQUIRE(rec.executorLog == "Job cancelled by user.\n"); + } + SECTION("Large Output") { const std::vector BIG_FILES{"/usr/share/dict/linux.words", @@ -72,7 +89,7 @@ TEST_CASE("slurm_execution", "[slurm_executor]") "/usr/bin/cat", bigFile}}}}; task.job.merge(defaultJobValues); - auto recFuture = ex.execute("command", task); + auto recFuture = ex.execute(0, "command", task); auto rec = recFuture.get(); REQUIRE(rec.rc == 0); diff --git a/tests/unit_server.cpp b/tests/unit_server.cpp index bb7f444..2d1bf01 100644 --- a/tests/unit_server.cpp +++ b/tests/unit_server.cpp @@ -1,51 +1,131 @@ +#include #include #include +#include #include #include #include #include +#include #include #include #include +#include namespace rj = rapidjson; 
diff --git a/tests/unit_server.cpp b/tests/unit_server.cpp
index bb7f444..2d1bf01 100644
--- a/tests/unit_server.cpp
+++ b/tests/unit_server.cpp
@@ -1,51 +1,131 @@
+#include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 
 namespace rj = rapidjson;
 
-Pistache::Http::Response REQUEST(const std::string &url,
-                                 const std::string &payload = "")
-{
-  Pistache::Http::Experimental::Client client;
-  client.init();
-  Pistache::Http::Response response;
-  auto reqSpec = (payload.empty() ? client.get(url) : client.post(url));
-  reqSpec.timeout(std::chrono::seconds(2));
-  if (!payload.empty()) {
-    reqSpec.body(payload);
-  }
-  auto request = reqSpec.send();
-  bool ok = false, error = false;
-  std::string msg;
-  request.then(
-      [&](Pistache::Http::Response rsp) {
-        ok = true;
-        response = std::move(rsp);
-      },
-      [&](std::exception_ptr ptr) {
-        error = true;
-        try {
-          std::rethrow_exception(std::move(ptr));
-        }
-        catch (std::exception &e) {
-          msg = e.what();
-        }
-      });
-
-  Pistache::Async::Barrier barrier(request);
-  barrier.wait_for(std::chrono::seconds(2));
-  client.shutdown();
-  if (error) {
-    throw std::runtime_error(msg);
-  }
-
-  return response;
-}
+using namespace daggy;
+
+#ifdef DEBUG_HTTP
+static int my_trace(CURL *handle, curl_infotype type, char *data, size_t size,
+                    void *userp)
+{
+  const char *text;
+  (void)handle; /* prevent compiler warning */
+  (void)userp;
+
+  switch (type) {
+  case CURLINFO_TEXT:
+    fprintf(stderr, "== Info: %s", data);
+  default: /* in case a new one is introduced to shock us */
+    return 0;
+
+  case CURLINFO_HEADER_OUT:
+    text = "=> Send header";
+    break;
+  case CURLINFO_DATA_OUT:
+    text = "=> Send data";
+    break;
+  case CURLINFO_SSL_DATA_OUT:
+    text = "=> Send SSL data";
+    break;
+  case CURLINFO_HEADER_IN:
+    text = "<= Recv header";
+    break;
+  case CURLINFO_DATA_IN:
+    text = "<= Recv data";
+    break;
+  case CURLINFO_SSL_DATA_IN:
+    text = "<= Recv SSL data";
+    break;
+  }
+
+  std::cerr << "\n================== " << text
+            << " ==================" << std::endl
+            << data << std::endl;
+  return 0;
+}
+#endif
+
+enum HTTPCode : long
+{
+  Ok = 200,
+  Not_Found = 404
+};
+
+struct HTTPResponse
+{
+  HTTPCode code;
+  std::string body;
+};
+
+size_t curlWriter(char *in, size_t size, size_t nmemb, std::stringstream *out)
+{
+  size_t r = size * nmemb;
+  out->write(in, r);
+  return r;
+}
+
+HTTPResponse REQUEST(const std::string &url, const std::string &payload = "",
+                     const std::string &method = "GET")
+{
+  HTTPResponse response;
+
+  CURL *curl;
+  CURLcode res;
+  struct curl_slist *headers = NULL;
+
+  curl_global_init(CURL_GLOBAL_ALL);
+
+  curl = curl_easy_init();
+  if (curl) {
+    std::stringstream buffer;
+
+#ifdef DEBUG_HTTP
+    curl_easy_setopt(curl, CURLOPT_DEBUGFUNCTION, my_trace);
+    curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
+#endif
+
+    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curlWriter);
+    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer);
+
+    if (!payload.empty()) {
+      curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, payload.size());
+      curl_easy_setopt(curl, CURLOPT_POSTFIELDS, payload.c_str());
+      headers = curl_slist_append(headers, "Content-Type: application/json");
+    }
+    curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, method.c_str());
+    headers = curl_slist_append(headers, "Expect:");
+    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+
+    res = curl_easy_perform(curl);
+    curl_slist_free_all(headers);
+
+    if (res != CURLE_OK) {
+      curl_easy_cleanup(curl);
+      throw std::runtime_error(std::string{"CURL Failed: "} +
+                               curl_easy_strerror(res));
+    }
+
+    // Read the response code and body before the handle is cleaned up.
+    curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response.code);
+    response.body = buffer.str();
+    curl_easy_cleanup(curl);
+  }
+
+  curl_global_cleanup();
+  return response;
+}
REQUEST(baseURL + "/ready"); - REQUIRE(response.code() == Pistache::Http::Code::Ok); + REQUIRE(response.code == HTTPCode::Ok); } SECTION("Querying a non-existent dagrunid should fail ") { auto response = REQUEST(baseURL + "/v1/dagrun/100"); - REQUIRE(response.code() != Pistache::Http::Code::Ok); + REQUIRE(response.code != HTTPCode::Ok); } SECTION("Simple DAGRun Submission") { std::string dagRun = R"({ - "name": "unit_server", + "tag": "unit_server", "parameters": { "FILE": [ "A", "B" ] }, "tasks": { "touch": { "job": { "command": [ "/usr/bin/touch", "dagrun_{{FILE}}" ]} }, @@ -90,14 +170,16 @@ TEST_CASE("rest_endpoint", "[server_basic]") } })"; + auto dagSpec = daggy::dagFromJSON(dagRun); + // Submit, and get the runID daggy::DAGRunID runID = 0; { - auto response = REQUEST(baseURL + "/v1/dagrun/", dagRun); - REQUIRE(response.code() == Pistache::Http::Code::Ok); + auto response = REQUEST(baseURL + "/v1/dagrun/", dagRun, "POST"); + REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; - daggy::checkRJParse(doc.Parse(response.body().c_str())); + daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsObject()); REQUIRE(doc.HasMember("runID")); @@ -106,11 +188,11 @@ TEST_CASE("rest_endpoint", "[server_basic]") // Ensure our runID shows up in the list of running DAGs { - auto response = REQUEST(baseURL + "/v1/dagrun/"); - REQUIRE(response.code() == Pistache::Http::Code::Ok); + auto response = REQUEST(baseURL + "/v1/dagruns?all=1"); + REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; - daggy::checkRJParse(doc.Parse(response.body().c_str())); + daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsArray()); REQUIRE(doc.Size() >= 1); @@ -120,10 +202,10 @@ TEST_CASE("rest_endpoint", "[server_basic]") for (size_t i = 0; i < runs.Size(); ++i) { const auto &run = runs[i]; REQUIRE(run.IsObject()); - REQUIRE(run.HasMember("name")); + REQUIRE(run.HasMember("tag")); REQUIRE(run.HasMember("runID")); - std::string runName = run["name"].GetString(); + std::string runName = run["tag"].GetString(); if (runName == "unit_server") { REQUIRE(run["runID"].GetUint64() == runID); found = true; @@ -133,13 +215,28 @@ TEST_CASE("rest_endpoint", "[server_basic]") REQUIRE(found); } + // Ensure we can get one of our tasks + { + auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID) + + "/task/cat_0"); + REQUIRE(response.code == HTTPCode::Ok); + + rj::Document doc; + daggy::checkRJParse(doc.Parse(response.body.c_str())); + + REQUIRE_NOTHROW(daggy::taskFromJSON("cat", doc)); + auto task = daggy::taskFromJSON("cat", doc); + + REQUIRE(task == dagSpec.tasks.at("cat")); + } + // Wait until our DAG is complete bool complete = true; for (auto i = 0; i < 10; ++i) { auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID)); - REQUIRE(response.code() == Pistache::Http::Code::Ok); + REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; - daggy::checkRJParse(doc.Parse(response.body().c_str())); + daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsObject()); REQUIRE(doc.HasMember("taskStates")); @@ -173,6 +270,113 @@ TEST_CASE("rest_endpoint", "[server_basic]") fs::remove(pth); } } +} + +TEST_CASE("Server cancels and resumes execution", "[server_resume]") +{ + std::stringstream ss; + daggy::executors::task::ForkingTaskExecutor executor(10); + daggy::loggers::dag_run::OStreamLogger logger(ss); + Pistache::Address listenSpec("localhost", Pistache::Port(0)); + + const size_t nDAGRunners = 10, nWebThreads = 10; + + daggy::Server server(listenSpec, 
+TEST_CASE("Server cancels and resumes execution", "[server_resume]")
+{
+  std::stringstream ss;
+  daggy::executors::task::ForkingTaskExecutor executor(10);
+  daggy::loggers::dag_run::OStreamLogger logger(ss);
+  Pistache::Address listenSpec("localhost", Pistache::Port(0));
+
+  const size_t nDAGRunners = 10, nWebThreads = 10;
+
+  daggy::Server server(listenSpec, logger, executor, nDAGRunners);
+  server.init(nWebThreads);
+  server.start();
+
+  const std::string host = "localhost:";
+  const std::string baseURL = host + std::to_string(server.getPort());
+
+  SECTION("Cancel / Resume DAGRun")
+  {
+    std::string dagRunJSON = R"({
+      "tag": "unit_server",
+      "tasks": {
+        "touch_A": { "job": { "command": [ "/usr/bin/touch", "resume_touch_a" ]}, "children": ["touch_C"] },
+        "sleep_B": { "job": { "command": [ "/usr/bin/sleep", "3" ]}, "children": ["touch_C"] },
+        "touch_C": { "job": { "command": [ "/usr/bin/touch", "resume_touch_c" ]} }
+      }
+    })";
+
+    auto dagSpec = daggy::dagFromJSON(dagRunJSON);
+
+    // Submit, and get the runID
+    daggy::DAGRunID runID;
+    {
+      auto response = REQUEST(baseURL + "/v1/dagrun/", dagRunJSON, "POST");
+      REQUIRE(response.code == HTTPCode::Ok);
+
+      rj::Document doc;
+      daggy::checkRJParse(doc.Parse(response.body.c_str()));
+      REQUIRE(doc.IsObject());
+      REQUIRE(doc.HasMember("runID"));
+
+      runID = doc["runID"].GetUint64();
+    }
+
+    std::this_thread::sleep_for(1s);
+
+    // Stop the current run
+    {
+      auto response = REQUEST(
+          baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/KILLED", "",
+          "PATCH");
+      REQUIRE(response.code == HTTPCode::Ok);
+      REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::KILLED);
+    }
+
+    // Verify that the run still exists
+    {
+      auto dagRun = logger.getDAGRun(runID);
+      REQUIRE(dagRun.taskRunStates.at("touch_A_0") ==
+              +daggy::RunState::COMPLETED);
+      REQUIRE(fs::exists("resume_touch_a"));
+
+      REQUIRE(dagRun.taskRunStates.at("sleep_B_0") ==
+              +daggy::RunState::ERRORED);
+      REQUIRE(dagRun.taskRunStates.at("touch_C_0") == +daggy::RunState::QUEUED);
+    }
+
+    // Set the errored task state
+    {
+      auto url = baseURL + "/v1/dagrun/" + std::to_string(runID) +
+                 "/task/sleep_B_0/state/QUEUED";
+      auto response = REQUEST(url, "", "PATCH");
+      REQUIRE(response.code == HTTPCode::Ok);
+      REQUIRE(logger.getTaskState(runID, "sleep_B_0") ==
+              +daggy::RunState::QUEUED);
+    }
+
+    // Resume
+    {
+      struct stat s;
+
+      lstat("resume_touch_a", &s);
+      auto preMTime = s.st_mtim.tv_sec;
+
+      auto response = REQUEST(
+          baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/QUEUED", "",
+          "PATCH");
+      REQUIRE(response.code == HTTPCode::Ok);
+
+      // Wait for run to complete
+      std::this_thread::sleep_for(5s);
+      REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::COMPLETED);
+
+      REQUIRE(fs::exists("resume_touch_c"));
+      REQUIRE(fs::exists("resume_touch_a"));
+
+      for (const auto &[taskName, task] : dagSpec.tasks) {
+        REQUIRE(logger.getTaskState(runID, taskName + "_0") ==
+                +daggy::RunState::COMPLETED);
+      }
+
+      // Ensure "touch_A" wasn't run again
+      lstat("resume_touch_a", &s);
+      auto postMTime = s.st_mtim.tv_sec;
+      REQUIRE(preMTime == postMTime);
+    }
+  }
 
   server.shutdown();
 }
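One fragility worth noting in the resume test above: the fixed `sleep_for(5s)` is load-dependent. A poll-with-timeout helper along these lines would make the test more robust; this is a sketch, and `waitForDAGState` is hypothetical, not part of daggy:

```
#include <chrono>
#include <thread>

#include "daggy/loggers/dag_run/OStreamLogger.hpp"

// Poll the logger until the run reaches the wanted state or the timeout
// expires. Works with any logger exposing getDAGRunState().
template <typename Logger>
bool waitForDAGState(Logger &logger, daggy::DAGRunID runID,
                     daggy::RunState want, std::chrono::seconds timeout)
{
    auto deadline = daggy::Clock::now() + timeout;
    while (daggy::Clock::now() < deadline) {
        if (logger.getDAGRunState(runID) == want)
            return true;
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    return false;
}

// Usage in the test would then be:
//   REQUIRE(waitForDAGState(logger, runID, +daggy::RunState::COMPLETED,
//                           std::chrono::seconds(10)));
```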
diff --git a/tests/unit_utilities.cpp b/tests/unit_utilities.cpp
index 1c6a6f0..ddb6ae5 100644
--- a/tests/unit_utilities.cpp
+++ b/tests/unit_utilities.cpp
@@ -8,11 +8,6 @@
 #include "daggy/Serialization.hpp"
 #include "daggy/Utilities.hpp"
 
-#include "daggy/executors/task/ForkingTaskExecutor.hpp"
-#include "daggy/executors/task/NoopTaskExecutor.hpp"
-#include "daggy/loggers/dag_run/OStreamLogger.hpp"
-
-namespace fs = std::filesystem;
 
 TEST_CASE("string_utilities", "[utilities_string]")
 {
@@ -59,234 +54,3 @@ TEST_CASE("string_expansion", "[utilities_parameter_expansion]")
     REQUIRE(result.size() == 4);
   }
 }
-
-TEST_CASE("dag_runner_order", "[dagrun_order]")
-{
-  daggy::executors::task::NoopTaskExecutor ex;
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  daggy::TimePoint globalStartTime = daggy::Clock::now();
-
-  std::string testParams{
-      R"({"DATE": ["2021-05-06", "2021-05-07", "2021-05-08", "2021-05-09" ]})"};
-  auto params = daggy::configFromJSON(testParams);
-
-  std::string taskJSON = R"({
-    "A": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "B","D" ]},
-    "B": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "C","D","E" ]},
-    "C": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "D"]},
-    "D": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}, "children": [ "E"]},
-    "E": {"job": {"command": ["/usr/bin/touch", "{{DATE}}"]}}
-  })";
-
-  auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex, params);
-
-  REQUIRE(tasks.size() == 20);
-
-  auto dag = daggy::buildDAGFromTasks(tasks);
-  auto runID = logger.startDAGRun("test_run", tasks);
-  auto endDAG = daggy::runDAG(runID, ex, logger, dag);
-
-  REQUIRE(endDAG.allVisited());
-
-  // Ensure the run order
-  auto rec = logger.getDAGRun(runID);
-
-  daggy::TimePoint globalStopTime = daggy::Clock::now();
-  std::array<daggy::TimePoint, 5> minTimes;
-  minTimes.fill(globalStartTime);
-  std::array<daggy::TimePoint, 5> maxTimes;
-  maxTimes.fill(globalStopTime);
-
-  for (const auto &[k, v] : rec.taskAttempts) {
-    size_t idx = k[0] - 65;
-    auto &startTime = minTimes[idx];
-    auto &stopTime = maxTimes[idx];
-    startTime = std::max(startTime, v.front().startTime);
-    stopTime = std::min(stopTime, v.back().stopTime);
-  }
-
-  for (size_t i = 0; i < 5; ++i) {
-    for (size_t j = i + 1; j < 4; ++j) {
-      REQUIRE(maxTimes[i] < minTimes[j]);
-    }
-  }
-}
-
-TEST_CASE("dag_runner", "[utilities_dag_runner]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Simple execution")
-  {
-    std::string prefix = (fs::current_path() / "asdlk").string();
-    std::unordered_map<std::string, std::string> files{
-        {"A", prefix + "_A"}, {"B", prefix + "_B"}, {"C", prefix + "_C"}};
-    std::string taskJSON =
-        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + files.at("A") +
-        R"("]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
-        files.at("B") +
-        R"("]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
-        files.at("C") + R"("]}}})";
-    auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
-    auto dag = daggy::buildDAGFromTasks(tasks);
-    auto runID = logger.startDAGRun("test_run", tasks);
-    auto endDAG = daggy::runDAG(runID, ex, logger, dag);
-
-    REQUIRE(endDAG.allVisited());
-
-    for (const auto &[_, file] : files) {
-      REQUIRE(fs::exists(file));
-      fs::remove(file);
-    }
-
-    // Get the DAG Run Attempts
-    auto record = logger.getDAGRun(runID);
-    for (const auto &[_, attempts] : record.taskAttempts) {
-      REQUIRE(attempts.size() == 1);
-      REQUIRE(attempts.front().rc == 0);
-    }
-  }
-}
-
-TEST_CASE("runDAG_recovery", "[runDAG]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Recovery from Error")
-  {
-    auto cleanup = []() {
-      // Cleanup
-      std::vector<std::string> paths{"rec_error_A", "noexist"};
-      for (const auto &pth : paths) {
-        if (fs::exists(pth))
-          fs::remove_all(pth);
-      }
-    };
-
-    cleanup();
-
-    std::string goodPrefix = "rec_error_";
-    std::string badPrefix = "noexist/rec_error_";
-    std::string taskJSON =
-        R"({"A": {"job": {"command": ["/usr/bin/touch", ")" + goodPrefix +
-        R"(A"]}, "children": ["C"]}, "B": {"job": {"command": ["/usr/bin/touch", ")" +
-        badPrefix +
-        R"(B"]}, "children": ["C"]}, "C": {"job": {"command": ["/usr/bin/touch", ")" +
-        badPrefix + R"(C"]}}})";
-    auto tasks = expandTaskSet(daggy::tasksFromJSON(taskJSON), ex);
-    auto dag = daggy::buildDAGFromTasks(tasks);
-
-    auto runID = logger.startDAGRun("test_run", tasks);
-
-    auto tryDAG = daggy::runDAG(runID, ex, logger, dag);
-
-    REQUIRE(!tryDAG.allVisited());
-
-    // Create the missing dir, then continue to run the DAG
-    fs::create_directory("noexist");
-    tryDAG.resetRunning();
-    auto endDAG = daggy::runDAG(runID, ex, logger, tryDAG);
-
-    REQUIRE(endDAG.allVisited());
-
-    // Get the DAG Run Attempts
-    auto record = logger.getDAGRun(runID);
-    REQUIRE(record.taskAttempts["A_0"].size() == 1); // A ran fine
-    REQUIRE(record.taskAttempts["B_0"].size() ==
-            2); // B errored and had to be retried
-    REQUIRE(record.taskAttempts["C_0"].size() ==
-            1); // C wasn't run because B errored
-
-    cleanup();
-  }
-}
-
-TEST_CASE("runDAG_generator", "[runDAG_generator]")
-{
-  daggy::executors::task::ForkingTaskExecutor ex(10);
-  std::stringstream ss;
-  daggy::loggers::dag_run::OStreamLogger logger(ss);
-
-  SECTION("Generator tasks")
-  {
-    std::string testParams{R"({"DATE": ["2021-05-06", "2021-05-07" ]})"};
-    auto params = daggy::configFromJSON(testParams);
-
-    std::string generatorOutput =
-        R"({"B": {"job": {"command": ["/usr/bin/echo", "-e", "{{DATE}}"]}, "children": ["C"]}})";
-    fs::path ofn = fs::current_path() / "generator_test_output.json";
-    std::ofstream ofh{ofn};
-    ofh << generatorOutput << std::endl;
-    ofh.close();
-
-    daggy::TimePoint globalStartTime = daggy::Clock::now();
-    std::stringstream jsonTasks;
-    jsonTasks
-        << R"({ "A": { "job": {"command": [ "/usr/bin/cat", )"
-        << std::quoted(ofn.string())
-        << R"(]}, "children": ["C"], "isGenerator": true},)"
-        << R"("C": { "job": {"command": [ "/usr/bin/echo", "hello!"]} } })";
-
-    auto baseTasks = daggy::tasksFromJSON(jsonTasks.str());
-    REQUIRE(baseTasks.size() == 2);
-    REQUIRE(baseTasks["A"].children == std::unordered_set<std::string>{"C"});
-    auto tasks = daggy::expandTaskSet(baseTasks, ex, params);
-    REQUIRE(tasks.size() == 2);
-    REQUIRE(tasks["A_0"].children == std::unordered_set<std::string>{"C"});
-    auto dag = daggy::buildDAGFromTasks(tasks);
-    REQUIRE(dag.size() == 2);
-
-    auto runID = logger.startDAGRun("generator_run", tasks);
-    auto finalDAG = daggy::runDAG(runID, ex, logger, dag, params);
-
-    REQUIRE(finalDAG.allVisited());
-    REQUIRE(finalDAG.size() == 4);
-
-    // Check the logger
-    auto record = logger.getDAGRun(runID);
-
-    REQUIRE(record.tasks.size() == 4);
-    REQUIRE(record.taskRunStates.size() == 4);
-    for (const auto &[taskName, attempts] : record.taskAttempts) {
-      REQUIRE(attempts.size() == 1);
-      REQUIRE(attempts.back().rc == 0);
-    }
-
-    // Ensure that children were updated properly
-    REQUIRE(record.tasks["A_0"].children ==
-            std::unordered_set<std::string>{"B_0", "B_1", "C"});
-    REQUIRE(record.tasks["B_0"].children ==
-            std::unordered_set<std::string>{"C"});
-    REQUIRE(record.tasks["B_1"].children ==
-            std::unordered_set<std::string>{"C"});
-    REQUIRE(record.tasks["C_0"].children.empty());
-
-    // Ensure they were run in the right order
-    // All A's get run before B's, which run before C's
-    daggy::TimePoint globalStopTime = daggy::Clock::now();
-    std::array<daggy::TimePoint, 3> minTimes;
-    minTimes.fill(globalStartTime);
-    std::array<daggy::TimePoint, 3> maxTimes;
-    maxTimes.fill(globalStopTime);
-
-    for (const auto &[k, v] : record.taskAttempts) {
-      size_t idx = k[0] - 65;
-      auto &startTime = minTimes[idx];
-      auto &stopTime = maxTimes[idx];
-      startTime = std::max(startTime, v.front().startTime);
-      stopTime = std::min(stopTime, v.back().stopTime);
-    }
-
-    for (size_t i = 0; i < 3; ++i) {
-      for (size_t j = i + 1; j < 2; ++j) {
-        REQUIRE(maxTimes[i] < minTimes[j]);
-      }
-    }
-  }
-}
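For completeness, here is the whole new surface from a client's point of view: POST a DAG run and read back the assigned runID, complementing the state-PATCH sketch earlier. This is a standalone sketch; the listen address is an assumption, and the route shape and response field follow endpoints.otl and the server tests above:

```
#include <curl/curl.h>

#include <iostream>
#include <sstream>
#include <string>

// Collect the response body into a stringstream.
static size_t writeBody(char *in, size_t size, size_t nmemb, void *out)
{
    static_cast<std::stringstream *>(out)->write(in, size * nmemb);
    return size * nmemb;
}

int main()
{
    const std::string base = "http://localhost:8080"; // assumed address

    const std::string dagRun = R"({
      "tag": "example",
      "tasks": { "hello": { "job": { "command": ["/bin/echo", "hi"] } } }
    })";

    curl_global_init(CURL_GLOBAL_ALL);
    CURL *curl = curl_easy_init();
    if (!curl)
        return 1;

    std::stringstream body;
    curl_easy_setopt(curl, CURLOPT_URL, (base + "/v1/dagrun/").c_str());
    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, dagRun.c_str());
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeBody);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &body);

    CURLcode res = curl_easy_perform(curl);
    if (res == CURLE_OK)
        std::cout << body.str() << std::endl; // e.g. {"runID": 0}

    curl_easy_cleanup(curl);
    curl_global_cleanup();
    return res == CURLE_OK ? 0 : 1;
}
```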