#include #include #include #include #include #include #include #include #include #include #include #include #include namespace rj = rapidjson; using namespace daggy; #ifdef DEBUG_HTTP static int my_trace(CURL *handle, curl_infotype type, char *data, size_t size, void *userp) { const char *text; (void)handle; /* prevent compiler warning */ (void)userp; switch (type) { case CURLINFO_TEXT: fprintf(stderr, "== Info: %s", data); default: /* in case a new one is introduced to shock us */ return 0; case CURLINFO_HEADER_OUT: text = "=> Send header"; break; case CURLINFO_DATA_OUT: text = "=> Send data"; break; case CURLINFO_SSL_DATA_OUT: text = "=> Send SSL data"; break; case CURLINFO_HEADER_IN: text = "<= Recv header"; break; case CURLINFO_DATA_IN: text = "<= Recv data"; break; case CURLINFO_SSL_DATA_IN: text = "<= Recv SSL data"; break; } std::cerr << "\n================== " << text << " ==================" << std::endl << data << std::endl; return 0; } #endif enum HTTPCode : long { Ok = 200, Not_Found = 404 }; struct HTTPResponse { HTTPCode code; std::string body; }; uint curlWriter(char *in, uint size, uint nmemb, std::stringstream *out) { uint r; r = size * nmemb; out->write(in, r); return r; } HTTPResponse REQUEST(const std::string &url, const std::string &payload = "", const std::string &method = "GET") { HTTPResponse response; CURL *curl; CURLcode res; struct curl_slist *headers = NULL; curl_global_init(CURL_GLOBAL_ALL); curl = curl_easy_init(); if (curl) { std::stringstream buffer; #ifdef DEBUG_HTTP curl_easy_setopt(curl, CURLOPT_DEBUGFUNCTION, my_trace); curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L); #endif curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curlWriter); curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer); if (!payload.empty()) { curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, payload.size()); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, payload.c_str()); headers = curl_slist_append(headers, "Content-Type: Application/Json"); } curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, method.c_str()); headers = curl_slist_append(headers, "Expect:"); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); res = curl_easy_perform(curl); if (res != CURLE_OK) { curl_easy_cleanup(curl); throw std::runtime_error(std::string{"CURL Failed: "} + curl_easy_strerror(res)); } curl_easy_cleanup(curl); curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response.code); response.body = buffer.str(); } curl_global_cleanup(); return response; } TEST_CASE("rest_endpoint", "[server_basic]") { std::stringstream ss; daggy::executors::task::ForkingTaskExecutor executor(10); daggy::loggers::dag_run::OStreamLogger logger(ss); Pistache::Address listenSpec("localhost", Pistache::Port(0)); const size_t nDAGRunners = 10, nWebThreads = 10; daggy::Server server(listenSpec, logger, executor, nDAGRunners); server.init(nWebThreads); server.start(); const std::string host = "localhost:"; const std::string baseURL = host + std::to_string(server.getPort()); SECTION("Ready Endpoint") { auto response = REQUEST(baseURL + "/ready"); REQUIRE(response.code == HTTPCode::Ok); } SECTION("Querying a non-existent dagrunid should fail ") { auto response = REQUEST(baseURL + "/v1/dagrun/100"); REQUIRE(response.code != HTTPCode::Ok); } SECTION("Simple DAGRun Submission") { std::string dagRun = R"({ "tag": "unit_server", "parameters": { "FILE": [ "A", "B" ] }, "tasks": { "touch": { "job": { "command": [ "/usr/bin/touch", "dagrun_{{FILE}}" ]} }, "cat": { "job": { "command": [ "/usr/bin/cat", "dagrun_A", "dagrun_B" ]}, "parents": [ "touch" ] } } })"; auto dagSpec = daggy::dagFromJSON(dagRun); // Submit, and get the runID daggy::DAGRunID runID = 0; { auto response = REQUEST(baseURL + "/v1/dagrun/", dagRun, "POST"); REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsObject()); REQUIRE(doc.HasMember("runID")); runID = doc["runID"].GetUint64(); } // Ensure our runID shows up in the list of running DAGs { auto response = REQUEST(baseURL + "/v1/dagruns?all=1"); REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsArray()); REQUIRE(doc.Size() >= 1); // Ensure that our DAG is in the list and matches our given DAGRunID bool found = false; const auto &runs = doc.GetArray(); for (size_t i = 0; i < runs.Size(); ++i) { const auto &run = runs[i]; REQUIRE(run.IsObject()); REQUIRE(run.HasMember("tag")); REQUIRE(run.HasMember("runID")); std::string runName = run["tag"].GetString(); if (runName == "unit_server") { REQUIRE(run["runID"].GetUint64() == runID); found = true; break; } } REQUIRE(found); } // Ensure we can get one of our tasks { auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID) + "/task/cat_0"); REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE_NOTHROW(daggy::taskFromJSON("cat", doc)); auto task = daggy::taskFromJSON("cat", doc); REQUIRE(task == dagSpec.tasks.at("cat")); } // Wait until our DAG is complete bool complete = true; for (auto i = 0; i < 10; ++i) { auto response = REQUEST(baseURL + "/v1/dagrun/" + std::to_string(runID)); REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsObject()); REQUIRE(doc.HasMember("taskStates")); const auto &taskStates = doc["taskStates"].GetObject(); size_t nStates = 0; for (auto it = taskStates.MemberBegin(); it != taskStates.MemberEnd(); ++it) { nStates++; } REQUIRE(nStates == 3); complete = true; for (auto it = taskStates.MemberBegin(); it != taskStates.MemberEnd(); ++it) { std::string state = it->value.GetString(); if (state != "COMPLETED") { complete = false; break; } } if (complete) break; std::this_thread::sleep_for(std::chrono::seconds(1)); } REQUIRE(complete); std::this_thread::sleep_for(std::chrono::seconds(2)); for (const auto &pth : std::vector{"dagrun_A", "dagrun_B"}) { REQUIRE(fs::exists(pth)); fs::remove(pth); } } } TEST_CASE("Server cancels and resumes execution", "[server_resume]") { std::stringstream ss; daggy::executors::task::ForkingTaskExecutor executor(10); daggy::loggers::dag_run::OStreamLogger logger(ss); Pistache::Address listenSpec("localhost", Pistache::Port(0)); const size_t nDAGRunners = 10, nWebThreads = 10; daggy::Server server(listenSpec, logger, executor, nDAGRunners); server.init(nWebThreads); server.start(); const std::string host = "localhost:"; const std::string baseURL = host + std::to_string(server.getPort()); SECTION("Cancel / Resume DAGRun") { std::string dagRunJSON = R"({ "tag": "unit_server", "tasks": { "touch_A": { "job": { "command": [ "/usr/bin/touch", "resume_touch_a" ]}, "children": ["touch_C"] }, "sleep_B": { "job": { "command": [ "/usr/bin/sleep", "3" ]}, "children": ["touch_C"] }, "touch_C": { "job": { "command": [ "/usr/bin/touch", "resume_touch_c" ]} } } })"; auto dagSpec = daggy::dagFromJSON(dagRunJSON); // Submit, and get the runID daggy::DAGRunID runID; { auto response = REQUEST(baseURL + "/v1/dagrun/", dagRunJSON, "POST"); REQUIRE(response.code == HTTPCode::Ok); rj::Document doc; daggy::checkRJParse(doc.Parse(response.body.c_str())); REQUIRE(doc.IsObject()); REQUIRE(doc.HasMember("runID")); runID = doc["runID"].GetUint64(); } std::this_thread::sleep_for(1s); // Stop the current run { auto response = REQUEST( baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/KILLED", "", "PATCH"); REQUIRE(response.code == HTTPCode::Ok); REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::KILLED); } // Verify that the run still exists { auto dagRun = logger.getDAGRun(runID); REQUIRE(dagRun.taskRunStates.at("touch_A_0") == +daggy::RunState::COMPLETED); REQUIRE(fs::exists("resume_touch_a")); REQUIRE(dagRun.taskRunStates.at("sleep_B_0") == +daggy::RunState::ERRORED); REQUIRE(dagRun.taskRunStates.at("touch_C_0") == +daggy::RunState::QUEUED); } // Set the errored task state { auto url = baseURL + "/v1/dagrun/" + std::to_string(runID) + "/task/sleep_B_0/state/QUEUED"; auto response = REQUEST(url, "", "PATCH"); REQUIRE(response.code == HTTPCode::Ok); REQUIRE(logger.getTaskState(runID, "sleep_B_0") == +daggy::RunState::QUEUED); } // Resume { struct stat s; lstat("resume_touch_A", &s); auto preMTime = s.st_mtim.tv_sec; auto response = REQUEST( baseURL + "/v1/dagrun/" + std::to_string(runID) + "/state/QUEUED", "", "PATCH"); // Wait for run to complete std::this_thread::sleep_for(5s); REQUIRE(logger.getDAGRunState(runID) == +daggy::RunState::COMPLETED); REQUIRE(fs::exists("resume_touch_c")); REQUIRE(fs::exists("resume_touch_a")); for (const auto &[taskName, task] : dagSpec.tasks) { REQUIRE(logger.getTaskState(runID, taskName + "_0") == +daggy::RunState::COMPLETED); } // Ensure "touch_A" wasn't run again lstat("resume_touch_A", &s); auto postMTime = s.st_mtim.tv_sec; REQUIRE(preMTime == postMTime); } } server.shutdown(); }