diff --git a/daggyd/daggyd/daggyd.cpp b/daggyd/daggyd/daggyd.cpp
index ed78a2d..64ee6e0 100644
--- a/daggyd/daggyd/daggyd.cpp
+++ b/daggyd/daggyd/daggyd.cpp
@@ -12,6 +12,7 @@
 // Add executors here
 #include <daggy/executors/task/DaggyRunnerTaskExecutor.hpp>
 #include <daggy/executors/task/ForkingTaskExecutor.hpp>
+#include <daggy/executors/task/SSHTaskExecutor.hpp>
 
 #ifdef DAGGY_ENABLE_SLURM
 #include <daggy/executors/task/SlurmTaskExecutor.hpp>
@@ -208,7 +209,60 @@ std::unique_ptr executorFactory(const rj::Value &config)
               "DaggyRunnerExecutor runners must be an array of urls");
         exe->addRunner(runners[i].GetString());
       }
+      // Return here: the brace below closes this branch, and the
+      // `return exe;` that follows belongs to the SSHTaskExecutor branch.
+      return exe;
+    }
+    else if (name == "SSHTaskExecutor") {
+      // Expected shape: {"hosts": {"<hostname>": {"cores": N, "memoryMB": M}}}
+      if (!execConfig.HasMember("hosts"))
+        throw std::runtime_error(
+            "SSHTaskExecutor config needs at least one host");
+      std::unordered_map<std::string,
+                         daggy::executors::task::SSHTaskExecutor::RemoteHost>
+          remoteHosts;
+
+      const auto &hosts = execConfig["hosts"];
+      if (!hosts.IsObject())
+        throw std::runtime_error(
+            "SSHTaskExecutor hosts must be a dictionary of host => {cores, "
+            "memoryMB}");
+      // An empty dictionary would leave every task waiting forever for a host.
+      if (hosts.MemberCount() == 0)
+        throw std::runtime_error(
+            "SSHTaskExecutor config needs at least one host");
+
+      for (auto it = hosts.MemberBegin(); it != hosts.MemberEnd(); ++it) {
+        if (!it->name.IsString())
+          throw std::runtime_error("Hostnames must be a string.");
+        if (!it->value.IsObject())
+          throw std::runtime_error("Hostname definitions must be an object.");
+        const std::string hostName = it->name.GetString();
+
+        const auto &caps = it->value.GetObject();
+
+        if (!caps.HasMember("cores"))
+          throw std::runtime_error("Host " + hostName +
+                                   " is missing cores count.");
+        if (!caps.HasMember("memoryMB"))
+          throw std::runtime_error("Host " + hostName +
+                                   " is missing memoryMB size.");
+        // GetUint64() is only valid on unsigned integers; checking first also
+        // rejects negative sizes that would wrap when stored in a size_t.
+        if (!caps["cores"].IsUint64() or !caps["memoryMB"].IsUint64())
+          throw std::runtime_error(
+              "Host " + hostName +
+              " cores and memoryMB must be non-negative integers.");
+
+        const size_t cores = caps["cores"].GetUint64();
+        const size_t mem = caps["memoryMB"].GetUint64();
+        remoteHosts.emplace(
+            hostName, daggy::executors::task::SSHTaskExecutor::RemoteHost{
+                          .cores = cores, .memoryMB = mem});
+      }
 
+      auto exe = std::make_unique<daggy::executors::task::SSHTaskExecutor>(
+          remoteHosts);
       return exe;
     }
 
diff --git a/libdaggy/include/daggy/executors/task/SSHTaskExecutor.hpp b/libdaggy/include/daggy/executors/task/SSHTaskExecutor.hpp
new file mode 100644
index 0000000..2b96676
--- /dev/null
+++ b/libdaggy/include/daggy/executors/task/SSHTaskExecutor.hpp
@@ -0,0 +1,84 @@
+#pragma once
+
+#include <atomic>
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#include "ForkingTaskExecutor.hpp"
+#include "TaskExecutor.hpp"
+
+namespace daggy::executors::task {
+  // Runs each task's command on a remote host by prefixing it with `ssh` and
+  // handing it to a local ForkingTaskExecutor. Every host has a fixed budget
+  // of cores and memory; a task holds its share of one host until it ends.
+  class SSHTaskExecutor : public TaskExecutor
+  {
+  public:
+    // Capacity of one remote host, reduced while tasks run on it.
+    struct RemoteHost
+    {
+      size_t cores;
+      size_t memoryMB;
+    };
+
+    using Command = std::vector<std::string>;
+
+    // hosts maps hostname => capacity. Starts the monitor thread.
+    explicit SSHTaskExecutor(std::unordered_map<std::string, RemoteHost> hosts);
+
+    // Stops the monitor loop and joins its thread. Destroying a std::thread
+    // that is still joinable calls std::terminate, so this must run first.
+    ~SSHTaskExecutor()
+    {
+      running_ = false;
+      if (monitorWorker_.joinable())
+        monitorWorker_.join();
+    }
+
+    // Validates the job to ensure that all required values are set and are of
+    // the right type.
+    bool validateTaskParameters(const ConfigValues &job) override;
+
+    std::vector<ConfigValues> expandTaskParameters(
+        const ConfigValues &job, const ConfigValues &expansionValues) override;
+
+    // Runs the task
+    TaskFuture execute(DAGRunID runID, const std::string &taskName,
+                       const Task &task) override;
+
+    bool stop(DAGRunID runID, const std::string &taskName) override;
+
+    std::string description() const override;
+
+  private:
+    // Bookkeeping for one submitted task.
+    struct RunningTask
+    {
+      std::string host;     // host the resources were taken from
+      size_t cores;         // cores to give back on completion
+      size_t memoryMB;      // memory to give back on completion
+      TaskFuture fut;       // future returned to the caller of execute()
+      TaskFuture feFuture;  // future of the current local ssh attempt
+      int sshRetries;       // remaining attempts when ssh exits with 255
+      DAGRunID runID;
+      std::string taskName;
+      Task task;            // task with the ssh-wrapped command
+    };
+
+    std::unordered_map<std::string, RemoteHost> hosts_;
+    ForkingTaskExecutor fe_;
+    std::mutex hostGuard_;  // guards hosts_ and runningTasks_
+    std::condition_variable hostCV_;
+    std::deque<RunningTask> runningTasks_;
+
+    void monitor();
+    std::atomic<bool> running_;
+    std::thread monitorWorker_;
+  };
+} // namespace daggy::executors::task
+
diff --git a/libdaggy/src/executors/task/CMakeLists.txt b/libdaggy/src/executors/task/CMakeLists.txt
index 96a8ada..6bfb756 100644
--- a/libdaggy/src/executors/task/CMakeLists.txt
+++ b/libdaggy/src/executors/task/CMakeLists.txt
@@ -3,4 +3,5 @@ target_sources(${PROJECT_NAME} PRIVATE
         NoopTaskExecutor.cpp
         ForkingTaskExecutor.cpp
         DaggyRunnerTaskExecutor.cpp
+        SSHTaskExecutor.cpp
         )
diff --git a/libdaggy/src/executors/task/SSHTaskExecutor.cpp b/libdaggy/src/executors/task/SSHTaskExecutor.cpp
new file mode 100644
index 0000000..86958e1
--- /dev/null
+++ b/libdaggy/src/executors/task/SSHTaskExecutor.cpp
@@ -0,0 +1,142 @@
+#include <daggy/executors/task/SSHTaskExecutor.hpp>
+#include <chrono>
+#include <numeric>
+#include <sstream>
+
+using namespace daggy::executors::task;
+
+// Takes the host => capacity table, sizes the local forking executor to the
+// sum of all host cores, and starts the monitor thread.
+SSHTaskExecutor::SSHTaskExecutor(
+    std::unordered_map<std::string, RemoteHost> hosts)
+  : hosts_(hosts)
+  , fe_(std::accumulate(
+        hosts_.begin(), hosts_.end(), 0UL,
+        [](size_t t, const auto &a) { return t + a.second.cores; }))
+  , running_(true)
+  , monitorWorker_(&SSHTaskExecutor::monitor, this)
+{
+}
+
+// NOTE(review): a core count looks to be missing before "total cores".
+std::string SSHTaskExecutor::description() const
+{
+  std::stringstream ss;
+  ss << "SSHTaskExecutor with total cores on " << hosts_.size() << " hosts";
+  return ss.str();
+}
+
+bool SSHTaskExecutor::stop(DAGRunID runID, const std::string &taskName)
+{
+  return fe_.stop(runID, taskName);
+}
+
+// Prefixes the task's command with `ssh [-p port] [user@]host`, reserves
+// cores/memoryMB on the chosen host and submits it to the local executor.
+// Blocks the caller until some host has enough free capacity.
+// NOTE(review): a request larger than every host never becomes satisfiable.
+TaskFuture SSHTaskExecutor::execute(DAGRunID runID, const std::string &taskName,
+                                    const Task &task)
+{
+  const auto &job = task.job;
+  std::vector<std::string> newCommand{"ssh"};
+  std::string user;
+  // count() on a unique-key map is 0 or 1; the previous `> 1` test could
+  // never pass, so "user" and "port" were silently ignored.
+  if (job.count("user") > 0)
+    user = std::get<std::string>(job.at("user")) + "@";
+  if (job.count("port") > 0) {
+    newCommand.push_back("-p");
+    newCommand.push_back(std::get<std::string>(job.at("port")));
+  }
+
+  size_t coresNeeded = std::stoull(std::get<std::string>(job.at("cores")));
+  size_t memoryMBNeeded =
+      std::stoull(std::get<std::string>(job.at("memoryMB")));
+  const auto &oldCommand =
+      std::get<std::vector<std::string>>(job.at("command"));
+  RemoteHost *host = nullptr;
+  std::string hostname;
+
+  std::unique_lock lock(hostGuard_);
+  hostCV_.wait(lock, [&] {
+    for (auto &[name, caps] : hosts_) {
+      if (caps.cores >= coresNeeded and caps.memoryMB >= memoryMBNeeded) {
+        host = &caps;
+        hostname = name;
+        return true;
+      }
+    }
+    return false;
+  });
+  newCommand.push_back(user + hostname);
+  newCommand.insert(newCommand.end(), oldCommand.begin(), oldCommand.end());
+  Task sshTask{task};
+  sshTask.job["command"] = newCommand;
+
+  RunningTask rt{
+      .host = hostname,
+      .cores = coresNeeded,
+      .memoryMB = memoryMBNeeded,
+      .fut = std::make_shared<TaskFuture::element_type>(),
+      .feFuture = fe_.execute(runID, taskName, sshTask),
+      .sshRetries = 3,
+      .runID = runID,
+      .taskName = taskName,
+      .task = sshTask,
+  };
+
+  // Reserve only after the submission above succeeded, so an exception
+  // cannot leave capacity checked out with nothing to return it.
+  host->cores -= coresNeeded;
+  host->memoryMB -= memoryMBNeeded;
+
+  auto fut = rt.fut;
+  runningTasks_.emplace_back(std::move(rt));
+  return fut;
+}
+
+// Loop run by monitorWorker_ while running_ is set: every 500ms it drops
+// finished tasks, resubmits rc-255 attempts, publishes results, frees hosts.
+void SSHTaskExecutor::monitor()
+{
+  while (running_) {
+    {
+      std::lock_guard lock(hostGuard_);
+      while (!runningTasks_.empty() and runningTasks_.front().fut->ready())
+        runningTasks_.pop_front();
+
+      for (auto &rt : runningTasks_) {
+        if (rt.fut->ready() or !rt.feFuture->ready())
+          continue;
+        auto attempt = rt.feFuture->get();
+        // SSH is a bit flakey, but will error with 255 if it doesn't work
+        if (attempt.rc == 255 and --rt.sshRetries > 0) {
+          rt.feFuture = fe_.execute(rt.runID, rt.taskName, rt.task);
+          continue;
+        }
+        rt.fut->set(rt.feFuture->get());
+        auto &host = hosts_[rt.host];
+        host.cores += rt.cores;
+        host.memoryMB += rt.memoryMB;
+        // Waiters need different amounts, so wake them all; notify_one
+        // could pick a waiter that still does not fit.
+        hostCV_.notify_all();
+      }
+    }
+    std::this_thread::sleep_for(std::chrono::milliseconds(500));
+  }
+}
+
+// Delegates validation to the local forking executor.
+// TODO add in requirement for memory, cores, users, privkey, port
+bool SSHTaskExecutor::validateTaskParameters(const ConfigValues &job)
+{
+  return fe_.validateTaskParameters(job);
+}
+
+std::vector<ConfigValues> SSHTaskExecutor::expandTaskParameters(
+    const ConfigValues &job, const ConfigValues &expansionValues)
+{
+  return fe_.expandTaskParameters(job, expansionValues);
+}