diff --git a/daggy/include/daggy/DAG.hpp b/daggy/include/daggy/DAG.hpp index 1a84621..cdb61ae 100644 --- a/daggy/include/daggy/DAG.hpp +++ b/daggy/include/daggy/DAG.hpp @@ -42,7 +42,7 @@ namespace daggy { void addEdgeIf(const K &src, std::function &v)> predicate); - bool hasPath(const K &from, const K &to) const; + bool isValid() const; bool hasVertex(const K &from); @@ -76,7 +76,9 @@ namespace daggy { private: std::unordered_map> vertices_; std::unordered_set readyVertices_; + + std::optional findCycle_(const K & node, std::unordered_set & seen) const; }; } -#include "DAG.impl.hxx" \ No newline at end of file +#include "DAG.impl.hxx" diff --git a/daggy/include/daggy/DAG.impl.hxx b/daggy/include/daggy/DAG.impl.hxx index 423b6ed..7237e97 100644 --- a/daggy/include/daggy/DAG.impl.hxx +++ b/daggy/include/daggy/DAG.impl.hxx @@ -25,8 +25,6 @@ namespace daggy { void DAG::addEdge(const K &from, const K &to) { if (vertices_.find(from) == vertices_.end()) throw std::runtime_error("No such vertex"); if (vertices_.find(to) == vertices_.end()) throw std::runtime_error("No such vertex"); - if (hasPath(to, from)) - throw std::runtime_error("Adding edge would result in a cycle"); vertices_.at(from).children.insert(to); vertices_.at(to).depCount++; } @@ -40,15 +38,35 @@ namespace daggy { } template - bool DAG::hasPath(const K &from, const K &to) const { - if (vertices_.find(from) == vertices_.end()) throw std::runtime_error("No such vertex"); - if (vertices_.find(to) == vertices_.end()) throw std::runtime_error("No such vertex"); - for (const auto &child: vertices_.at(from).children) { - if (child == to) return true; - if (hasPath(child, to)) return true; + std::optional DAG::findCycle_(const K & node, std::unordered_set & seen) const { + if (seen.count(node) > 0) return node; + seen.insert(node); + std::optional ret; + for (const auto & child : vertices_.at(node).children) { + auto res = findCycle_(child, seen); + if (res.has_value()) { + ret.swap(res); + break; + } } + seen.extract(node); + return ret; + } - return false; + template + bool DAG::isValid() const { + std::unordered_set seen; + for (const auto & [k, v] : vertices_) { + seen.clear(); + if (v.depCount != 0) continue; + auto res = findCycle_(k, seen); + if (res.has_value()) { + std::stringstream ss; + ss << "DAG contains a cycle between " << k << " and " << res.value() << std::endl; + throw std::runtime_error(ss.str()); + } + } + return true; } template diff --git a/daggy/src/Utilities.cpp b/daggy/src/Utilities.cpp index 209a93a..c0a680f 100644 --- a/daggy/src/Utilities.cpp +++ b/daggy/src/Utilities.cpp @@ -108,6 +108,7 @@ namespace daggy { TaskDAG dag; updateDAGFromTasks(dag, tasks); dag.reset(); + dag.isValid(); // Replay any updates for (const auto &update: updates) { diff --git a/tests/unit_dag.cpp b/tests/unit_dag.cpp index f38a79d..6a3c9df 100644 --- a/tests/unit_dag.cpp +++ b/tests/unit_dag.cpp @@ -22,15 +22,13 @@ TEST_CASE("dag_construction", "[dag]") { REQUIRE(!dag.empty()); // Cannot add an edge that would result in a cycle - REQUIRE_THROWS(dag.addEdge(9, 5)); + dag.addEdge(9, 5); + REQUIRE_THROWS(dag.isValid()); // Bounds checking SECTION("addEdge Bounds Checking") { REQUIRE_THROWS(dag.addEdge(20, 0)); REQUIRE_THROWS(dag.addEdge(0, 20)); - }SECTION("hasPath Bounds Checking") { - REQUIRE_THROWS(dag.hasPath(20, 0)); - REQUIRE_THROWS(dag.hasPath(0, 20)); } } diff --git a/tests/unit_utilities.cpp b/tests/unit_utilities.cpp index f1ba9db..20d2a5a 100644 --- a/tests/unit_utilities.cpp +++ b/tests/unit_utilities.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -180,3 +181,52 @@ TEST_CASE("dag_runner", "[utilities_dag_runner]") { REQUIRE(record.tasks["C_0"].children.empty()); } } + +TEST_CASE("dag_runner_stress", "[utilities_dag_runner_stress]") { + daggy::executors::task::ForkingTaskExecutor ex(10); + std::stringstream ss; + daggy::loggers::dag_run::OStreamLogger logger(ss); + + + SECTION("Stress-test") { + static std::random_device dev; + static std::mt19937 rng(dev()); + std::uniform_int_distribution nDepDist(0, 10); + + const size_t N_NODES = 100; + daggy::TaskSet tasks; + std::vector fileNames; + std::vector taskNames; + + for (size_t i = 0; i < N_NODES; ++i) { + std::string taskName = std::to_string(i); + std::uniform_int_distribution depDist(i+1, N_NODES-1); + std::unordered_set deps; + size_t nChildren = nDepDist(rng); + for (size_t c = 0; c < nChildren; ++c) { + deps.insert(std::to_string(depDist(rng))); + } + tasks.emplace(taskName, daggy::Task{ + .definedName = taskName, + .job = { { "command", std::vector{"/usr/bin/echo", taskName}}}, + .children = deps + }); + } + + auto dag = daggy::buildDAGFromTasks(tasks); + + auto runID = logger.startDAGRun("test_run", tasks); + + auto tryDAG = daggy::runDAG(runID, ex, logger, dag); + + REQUIRE(tryDAG.allVisited()); + + // Get the DAG Run Attempts + auto record = logger.getDAGRun(runID); + for (const auto & [k, attempts] : record.taskAttempts) { + REQUIRE(attempts.size() == 1); + } + } + + +}