Adding JWT token auth

This commit is contained in:
Ian Roddis
2025-05-31 10:13:09 -03:00
parent 539dcebbe1
commit ddca9d3b72
7 changed files with 77 additions and 22 deletions

2
.gitignore vendored
View File

@@ -2,3 +2,5 @@ build
.cache
cmake-build-debug/
.idea
**/.claude/settings.local.json

View File

@@ -32,6 +32,7 @@ include(cmake/pistache.cmake)
include(cmake/better-enums.cmake)
include(cmake/argparse.cmake)
include(cmake/Catch2.cmake)
include(cmake/jwt-cpp.cmake)
include(cmake/daggy_features.cmake)
message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}")

View File

@@ -269,6 +269,7 @@ int main(int argc, char **argv)
args.add_argument("--ip").default_value(std::string{"127.0.0.1"});
args.add_argument("--port").default_value(2503u).action(
[](const std::string &value) -> unsigned { return std::stoul(value); });
args.add_argument("--jwt-secret").default_value(std::string{});
try {
args.parse_args(argc, argv);
@@ -285,6 +286,7 @@ int main(int argc, char **argv)
auto staticAssetsDir = args.get<std::string>("--assets-dir");
std::string listenIP = args.get<std::string>("--ip");
auto listenPort = args.get<unsigned>("--port");
auto jwtSecret = args.get<std::string>("--jwt-secret");
size_t webThreads = 50;
size_t dagThreads = 50;
@@ -311,6 +313,8 @@ int main(int argc, char **argv)
webThreads = doc["web-threads"].GetInt64();
if (doc.HasMember("dag-threads"))
dagThreads = doc["dag-threads"].GetInt64();
if (doc.HasMember("jwt-secret"))
jwtSecret = doc["jwt-secret"].GetString();
}
else {
doc.SetObject();
@@ -337,6 +341,9 @@ int main(int argc, char **argv)
daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads,
staticAssetsDir);
if (!jwtSecret.empty()) {
server.setJWTSecret(jwtSecret);
}
server.init(webThreads);
server.start();

View File

@@ -3,6 +3,6 @@ project(libdaggyd)
add_library(${PROJECT_NAME} STATIC)
target_include_directories(${PROJECT_NAME} PUBLIC include)
target_link_libraries(${PROJECT_NAME} libdaggy stdc++fs)
target_link_libraries(${PROJECT_NAME} libdaggy jwt stdc++fs)
add_subdirectory(src)

View File

@@ -27,6 +27,7 @@ namespace daggy::daggyd {
~Server();
Server &setSSLCertificates(const fs::path &cert, const fs::path &key);
Server &setJWTSecret(const std::string &secret);
void init(size_t threads = 1);
@@ -69,5 +70,7 @@ namespace daggy::daggyd {
std::mutex runnerGuard_;
std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_;
std::string jwtSecret_;
};
} // namespace daggy::daggyd

View File

@@ -10,6 +10,7 @@
#include <stdexcept>
#include <thread>
#include <utility>
#include <jwt-cpp/jwt.h>
#define REQ_RESPONSE(code, msg) \
std::stringstream ss; \
@@ -81,6 +82,12 @@ namespace daggy::daggyd {
return *this;
}
Server &Server::setJWTSecret(const std::string &secret)
{
jwtSecret_ = secret;
return *this;
}
void Server::shutdown()
{
endpoint_.shutdown();
@@ -229,9 +236,10 @@ namespace daggy::daggyd {
void Server::handleRunDAG(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response)
{
if (!handleAuth(request))
return;
addResponseHeaders(response);
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
DAGRunID runID = 0;
try {
@@ -272,8 +280,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
bool all = false;
std::string tag = "";
@@ -332,8 +341,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
if (!request.hasParam(":runID")) {
REQ_RESPONSE(Not_Found, "No runID provided in URL");
}
@@ -413,8 +423,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
if (!request.hasParam(":runID")) {
REQ_RESPONSE(Not_Found, "No runID provided in URL");
}
@@ -434,8 +445,10 @@ namespace daggy::daggyd {
void Server::handleGetDAGRunState(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response)
{
if (!handleAuth(request))
return;
addResponseHeaders(response);
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
DAGRunID runID = request.param(":runID").as<DAGRunID>();
RunState state = RunState::QUEUED;
@@ -478,9 +491,10 @@ namespace daggy::daggyd {
void Server::handleSetDAGRunState(const Pistache::Rest::Request &request,
Pistache::Http::ResponseWriter response)
{
if (!handleAuth(request))
return;
addResponseHeaders(response);
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
// TODO handle state transition
DAGRunID runID = request.param(":runID").as<DAGRunID>();
@@ -535,8 +549,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>();
@@ -556,8 +571,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>();
@@ -581,8 +597,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
auto runID = request.param(":runID").as<DAGRunID>();
auto taskName = request.param(":taskName").as<std::string>();
@@ -604,8 +621,9 @@ namespace daggy::daggyd {
Pistache::Http::ResponseWriter response)
{
addResponseHeaders(response);
if (!handleAuth(request))
return;
if (!handleAuth(request)) {
REQ_RESPONSE(Unauthorized, "Authentication required");
}
// TODO implement handling of task state
auto runID = request.param(":runID").as<DAGRunID>();
@@ -686,6 +704,29 @@ namespace daggy::daggyd {
*/
bool Server::handleAuth(const Pistache::Rest::Request &request)
{
if (jwtSecret_.empty()) {
return true;
}
auto authHeader = request.headers().tryGet<Pistache::Http::Header::Authorization>();
if (!authHeader) {
return false;
}
std::string authValue = authHeader->value();
if (authValue.length() < 7 || authValue.substr(0, 7) != "Bearer ") {
return false;
}
std::string token = authValue.substr(7);
try {
auto verifier = jwt::verify()
.allow_algorithm(jwt::algorithm::hs256{jwtSecret_});
verifier.verify(jwt::decode(token));
return true;
} catch (const std::exception&) {
return false;
}
}
} // namespace daggy::daggyd

View File

@@ -3,6 +3,7 @@ project(daggyd_tests)
add_executable(${PROJECT_NAME} main.cpp
# unit tests
unit_server.cpp
unit_jwt_auth.cpp
)
target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl)