Adding JWT token auth

This commit is contained in:
Ian Roddis
2025-05-31 10:13:09 -03:00
parent 539dcebbe1
commit ddca9d3b72
7 changed files with 77 additions and 22 deletions

2
.gitignore vendored
View File

@@ -2,3 +2,5 @@ build
.cache .cache
cmake-build-debug/ cmake-build-debug/
.idea .idea
**/.claude/settings.local.json

View File

@@ -32,6 +32,7 @@ include(cmake/pistache.cmake)
include(cmake/better-enums.cmake) include(cmake/better-enums.cmake)
include(cmake/argparse.cmake) include(cmake/argparse.cmake)
include(cmake/Catch2.cmake) include(cmake/Catch2.cmake)
include(cmake/jwt-cpp.cmake)
include(cmake/daggy_features.cmake) include(cmake/daggy_features.cmake)
message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}") message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}")

View File

@@ -269,6 +269,7 @@ int main(int argc, char **argv)
args.add_argument("--ip").default_value(std::string{"127.0.0.1"}); args.add_argument("--ip").default_value(std::string{"127.0.0.1"});
args.add_argument("--port").default_value(2503u).action( args.add_argument("--port").default_value(2503u).action(
[](const std::string &value) -> unsigned { return std::stoul(value); }); [](const std::string &value) -> unsigned { return std::stoul(value); });
args.add_argument("--jwt-secret").default_value(std::string{});
try { try {
args.parse_args(argc, argv); args.parse_args(argc, argv);
@@ -285,6 +286,7 @@ int main(int argc, char **argv)
auto staticAssetsDir = args.get<std::string>("--assets-dir"); auto staticAssetsDir = args.get<std::string>("--assets-dir");
std::string listenIP = args.get<std::string>("--ip"); std::string listenIP = args.get<std::string>("--ip");
auto listenPort = args.get<unsigned>("--port"); auto listenPort = args.get<unsigned>("--port");
auto jwtSecret = args.get<std::string>("--jwt-secret");
size_t webThreads = 50; size_t webThreads = 50;
size_t dagThreads = 50; size_t dagThreads = 50;
@@ -311,6 +313,8 @@ int main(int argc, char **argv)
webThreads = doc["web-threads"].GetInt64(); webThreads = doc["web-threads"].GetInt64();
if (doc.HasMember("dag-threads")) if (doc.HasMember("dag-threads"))
dagThreads = doc["dag-threads"].GetInt64(); dagThreads = doc["dag-threads"].GetInt64();
if (doc.HasMember("jwt-secret"))
jwtSecret = doc["jwt-secret"].GetString();
} }
else { else {
doc.SetObject(); doc.SetObject();
@@ -337,6 +341,9 @@ int main(int argc, char **argv)
daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads, daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads,
staticAssetsDir); staticAssetsDir);
if (!jwtSecret.empty()) {
server.setJWTSecret(jwtSecret);
}
server.init(webThreads); server.init(webThreads);
server.start(); server.start();

View File

@@ -3,6 +3,6 @@ project(libdaggyd)
add_library(${PROJECT_NAME} STATIC) add_library(${PROJECT_NAME} STATIC)
target_include_directories(${PROJECT_NAME} PUBLIC include) target_include_directories(${PROJECT_NAME} PUBLIC include)
target_link_libraries(${PROJECT_NAME} libdaggy stdc++fs) target_link_libraries(${PROJECT_NAME} libdaggy jwt stdc++fs)
add_subdirectory(src) add_subdirectory(src)

View File

@@ -27,6 +27,7 @@ namespace daggy::daggyd {
~Server(); ~Server();
Server &setSSLCertificates(const fs::path &cert, const fs::path &key); Server &setSSLCertificates(const fs::path &cert, const fs::path &key);
Server &setJWTSecret(const std::string &secret);
void init(size_t threads = 1); void init(size_t threads = 1);
@@ -69,5 +70,7 @@ namespace daggy::daggyd {
std::mutex runnerGuard_; std::mutex runnerGuard_;
std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_; std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_;
std::string jwtSecret_;
}; };
} // namespace daggy::daggyd } // namespace daggy::daggyd

View File

@@ -10,6 +10,7 @@
#include <stdexcept> #include <stdexcept>
#include <thread> #include <thread>
#include <utility> #include <utility>
#include <jwt-cpp/jwt.h>
#define REQ_RESPONSE(code, msg) \ #define REQ_RESPONSE(code, msg) \
std::stringstream ss; \ std::stringstream ss; \
@@ -81,6 +82,12 @@ namespace daggy::daggyd {
return *this; return *this;
} }
Server &Server::setJWTSecret(const std::string &secret)
{
jwtSecret_ = secret;
return *this;
}
void Server::shutdown() void Server::shutdown()
{ {
endpoint_.shutdown(); endpoint_.shutdown();
@@ -229,9 +236,10 @@ namespace daggy::daggyd {
void Server::handleRunDAG(const Pistache::Rest::Request &request, void Server::handleRunDAG(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
if (!handleAuth(request))
return;
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
DAGRunID runID = 0; DAGRunID runID = 0;
try { try {
@@ -272,8 +280,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
bool all = false; bool all = false;
std::string tag = ""; std::string tag = "";
@@ -332,8 +341,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
if (!request.hasParam(":runID")) { if (!request.hasParam(":runID")) {
REQ_RESPONSE(Not_Found, "No runID provided in URL"); REQ_RESPONSE(Not_Found, "No runID provided in URL");
} }
@@ -413,8 +423,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
if (!request.hasParam(":runID")) { if (!request.hasParam(":runID")) {
REQ_RESPONSE(Not_Found, "No runID provided in URL"); REQ_RESPONSE(Not_Found, "No runID provided in URL");
} }
@@ -434,8 +445,10 @@ namespace daggy::daggyd {
void Server::handleGetDAGRunState(const Pistache::Rest::Request &request, void Server::handleGetDAGRunState(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
if (!handleAuth(request)) addResponseHeaders(response);
return; if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
DAGRunID runID = request.param(":runID").as<DAGRunID>(); DAGRunID runID = request.param(":runID").as<DAGRunID>();
RunState state = RunState::QUEUED; RunState state = RunState::QUEUED;
@@ -478,9 +491,10 @@ namespace daggy::daggyd {
void Server::handleSetDAGRunState(const Pistache::Rest::Request &request, void Server::handleSetDAGRunState(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
if (!handleAuth(request))
return;
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
// TODO handle state transition // TODO handle state transition
DAGRunID runID = request.param(":runID").as<DAGRunID>(); DAGRunID runID = request.param(":runID").as<DAGRunID>();
@@ -535,8 +549,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>(); auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>(); auto taskName = request.param(":taskName").as<std::string>();
@@ -556,8 +571,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>(); auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>(); auto taskName = request.param(":taskName").as<std::string>();
@@ -581,8 +597,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>(); auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>(); auto taskName = request.param(":taskName").as<std::string>();
@@ -604,8 +621,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response) Pistache::Http::ResponseWriter response)
{ {
addResponseHeaders(response); addResponseHeaders(response);
if (!handleAuth(request)) if (!handleAuth(request)) {
return; REQ_RESPONSE(Unauthorized, "Authentication required");
}
// TODO implement handling of task state // TODO implement handling of task state
auto runID = request.param(":runID").as<DAGRunID>(); auto runID = request.param(":runID").as<DAGRunID>();
@@ -686,6 +704,29 @@ namespace daggy::daggyd {
*/ */
bool Server::handleAuth(const Pistache::Rest::Request &request) bool Server::handleAuth(const Pistache::Rest::Request &request)
{ {
return true; if (jwtSecret_.empty()) {
return true;
}
auto authHeader = request.headers().tryGet<Pistache::Http::Header::Authorization>();
if (!authHeader) {
return false;
}
std::string authValue = authHeader->value();
if (authValue.length() < 7 || authValue.substr(0, 7) != "Bearer ") {
return false;
}
std::string token = authValue.substr(7);
try {
auto verifier = jwt::verify()
.allow_algorithm(jwt::algorithm::hs256{jwtSecret_});
verifier.verify(jwt::decode(token));
return true;
} catch (const std::exception&) {
return false;
}
} }
} // namespace daggy::daggyd } // namespace daggy::daggyd

View File

@@ -3,6 +3,7 @@ project(daggyd_tests)
add_executable(${PROJECT_NAME} main.cpp add_executable(${PROJECT_NAME} main.cpp
# unit tests # unit tests
unit_server.cpp unit_server.cpp
unit_jwt_auth.cpp
) )
target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl) target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl)