From ddca9d3b721a57c5815c9657c94dc817a473c19d Mon Sep 17 00:00:00 2001 From: Ian Roddis <31021769+iroddis@users.noreply.github.com> Date: Sat, 31 May 2025 10:13:09 -0300 Subject: [PATCH] Adding JWT token auth --- .gitignore | 2 + CMakeLists.txt | 1 + daggyd/daggyd/daggyd.cpp | 7 ++ daggyd/libdaggyd/CMakeLists.txt | 2 +- daggyd/libdaggyd/include/daggyd/Server.hpp | 3 + daggyd/libdaggyd/src/Server.cpp | 83 ++++++++++++++++------ daggyd/tests/CMakeLists.txt | 1 + 7 files changed, 77 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 2ac2e63..fb45b38 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ build .cache cmake-build-debug/ .idea + +**/.claude/settings.local.json diff --git a/CMakeLists.txt b/CMakeLists.txt index a594844..a11a31f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ include(cmake/pistache.cmake) include(cmake/better-enums.cmake) include(cmake/argparse.cmake) include(cmake/Catch2.cmake) +include(cmake/jwt-cpp.cmake) include(cmake/daggy_features.cmake) message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}") diff --git a/daggyd/daggyd/daggyd.cpp b/daggyd/daggyd/daggyd.cpp index 12b8723..eeb17c7 100644 --- a/daggyd/daggyd/daggyd.cpp +++ b/daggyd/daggyd/daggyd.cpp @@ -269,6 +269,7 @@ int main(int argc, char **argv) args.add_argument("--ip").default_value(std::string{"127.0.0.1"}); args.add_argument("--port").default_value(2503u).action( [](const std::string &value) -> unsigned { return std::stoul(value); }); + args.add_argument("--jwt-secret").default_value(std::string{}); try { args.parse_args(argc, argv); @@ -285,6 +286,7 @@ int main(int argc, char **argv) auto staticAssetsDir = args.get("--assets-dir"); std::string listenIP = args.get("--ip"); auto listenPort = args.get("--port"); + auto jwtSecret = args.get("--jwt-secret"); size_t webThreads = 50; size_t dagThreads = 50; @@ -311,6 +313,8 @@ int main(int argc, char **argv) webThreads = doc["web-threads"].GetInt64(); if (doc.HasMember("dag-threads")) dagThreads = doc["dag-threads"].GetInt64(); + if (doc.HasMember("jwt-secret")) + jwtSecret = doc["jwt-secret"].GetString(); } else { doc.SetObject(); @@ -337,6 +341,9 @@ int main(int argc, char **argv) daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads, staticAssetsDir); + if (!jwtSecret.empty()) { + server.setJWTSecret(jwtSecret); + } server.init(webThreads); server.start(); diff --git a/daggyd/libdaggyd/CMakeLists.txt b/daggyd/libdaggyd/CMakeLists.txt index d6fbcf6..fef2c63 100644 --- a/daggyd/libdaggyd/CMakeLists.txt +++ b/daggyd/libdaggyd/CMakeLists.txt @@ -3,6 +3,6 @@ project(libdaggyd) add_library(${PROJECT_NAME} STATIC) target_include_directories(${PROJECT_NAME} PUBLIC include) -target_link_libraries(${PROJECT_NAME} libdaggy stdc++fs) +target_link_libraries(${PROJECT_NAME} libdaggy jwt stdc++fs) add_subdirectory(src) diff --git a/daggyd/libdaggyd/include/daggyd/Server.hpp b/daggyd/libdaggyd/include/daggyd/Server.hpp index b16d182..39421b7 100644 --- a/daggyd/libdaggyd/include/daggyd/Server.hpp +++ b/daggyd/libdaggyd/include/daggyd/Server.hpp @@ -27,6 +27,7 @@ namespace daggy::daggyd { ~Server(); Server &setSSLCertificates(const fs::path &cert, const fs::path &key); + Server &setJWTSecret(const std::string &secret); void init(size_t threads = 1); @@ -69,5 +70,7 @@ namespace daggy::daggyd { std::mutex runnerGuard_; std::unordered_map> runners_; + + std::string jwtSecret_; }; } // namespace daggy::daggyd diff --git a/daggyd/libdaggyd/src/Server.cpp b/daggyd/libdaggyd/src/Server.cpp index 9eb9786..afa09ce 100644 --- a/daggyd/libdaggyd/src/Server.cpp +++ b/daggyd/libdaggyd/src/Server.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #define REQ_RESPONSE(code, msg) \ std::stringstream ss; \ @@ -81,6 +82,12 @@ namespace daggy::daggyd { return *this; } + Server &Server::setJWTSecret(const std::string &secret) + { + jwtSecret_ = secret; + return *this; + } + void Server::shutdown() { endpoint_.shutdown(); @@ -229,9 +236,10 @@ namespace daggy::daggyd { void Server::handleRunDAG(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request)) - return; addResponseHeaders(response); + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } DAGRunID runID = 0; try { @@ -272,8 +280,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } bool all = false; std::string tag = ""; @@ -332,8 +341,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } if (!request.hasParam(":runID")) { REQ_RESPONSE(Not_Found, "No runID provided in URL"); } @@ -413,8 +423,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } if (!request.hasParam(":runID")) { REQ_RESPONSE(Not_Found, "No runID provided in URL"); } @@ -434,8 +445,10 @@ namespace daggy::daggyd { void Server::handleGetDAGRunState(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request)) - return; + addResponseHeaders(response); + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } DAGRunID runID = request.param(":runID").as(); RunState state = RunState::QUEUED; @@ -478,9 +491,10 @@ namespace daggy::daggyd { void Server::handleSetDAGRunState(const Pistache::Rest::Request &request, Pistache::Http::ResponseWriter response) { - if (!handleAuth(request)) - return; addResponseHeaders(response); + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } // TODO handle state transition DAGRunID runID = request.param(":runID").as(); @@ -535,8 +549,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } auto runID = request.param(":runID").as(); auto taskName = request.param(":taskName").as(); @@ -556,8 +571,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } auto runID = request.param(":runID").as(); auto taskName = request.param(":taskName").as(); @@ -581,8 +597,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } auto runID = request.param(":runID").as(); auto taskName = request.param(":taskName").as(); @@ -604,8 +621,9 @@ namespace daggy::daggyd { Pistache::Http::ResponseWriter response) { addResponseHeaders(response); - if (!handleAuth(request)) - return; + if (!handleAuth(request)) { + REQ_RESPONSE(Unauthorized, "Authentication required"); + } // TODO implement handling of task state auto runID = request.param(":runID").as(); @@ -686,6 +704,29 @@ namespace daggy::daggyd { */ bool Server::handleAuth(const Pistache::Rest::Request &request) { - return true; + if (jwtSecret_.empty()) { + return true; + } + + auto authHeader = request.headers().tryGet(); + if (!authHeader) { + return false; + } + + std::string authValue = authHeader->value(); + if (authValue.length() < 7 || authValue.substr(0, 7) != "Bearer ") { + return false; + } + + std::string token = authValue.substr(7); + + try { + auto verifier = jwt::verify() + .allow_algorithm(jwt::algorithm::hs256{jwtSecret_}); + verifier.verify(jwt::decode(token)); + return true; + } catch (const std::exception&) { + return false; + } } } // namespace daggy::daggyd diff --git a/daggyd/tests/CMakeLists.txt b/daggyd/tests/CMakeLists.txt index 1de4f48..14dea21 100644 --- a/daggyd/tests/CMakeLists.txt +++ b/daggyd/tests/CMakeLists.txt @@ -3,6 +3,7 @@ project(daggyd_tests) add_executable(${PROJECT_NAME} main.cpp # unit tests unit_server.cpp + unit_jwt_auth.cpp ) target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl)