Adding JWT token auth
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,3 +2,5 @@ build
|
|||||||
.cache
|
.cache
|
||||||
cmake-build-debug/
|
cmake-build-debug/
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
|
**/.claude/settings.local.json
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ include(cmake/pistache.cmake)
|
|||||||
include(cmake/better-enums.cmake)
|
include(cmake/better-enums.cmake)
|
||||||
include(cmake/argparse.cmake)
|
include(cmake/argparse.cmake)
|
||||||
include(cmake/Catch2.cmake)
|
include(cmake/Catch2.cmake)
|
||||||
|
include(cmake/jwt-cpp.cmake)
|
||||||
include(cmake/daggy_features.cmake)
|
include(cmake/daggy_features.cmake)
|
||||||
|
|
||||||
message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}")
|
message("-- CMAKE Build Type is ${CMAKE_BUILD_TYPE}")
|
||||||
|
|||||||
@@ -269,6 +269,7 @@ int main(int argc, char **argv)
|
|||||||
args.add_argument("--ip").default_value(std::string{"127.0.0.1"});
|
args.add_argument("--ip").default_value(std::string{"127.0.0.1"});
|
||||||
args.add_argument("--port").default_value(2503u).action(
|
args.add_argument("--port").default_value(2503u).action(
|
||||||
[](const std::string &value) -> unsigned { return std::stoul(value); });
|
[](const std::string &value) -> unsigned { return std::stoul(value); });
|
||||||
|
args.add_argument("--jwt-secret").default_value(std::string{});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
args.parse_args(argc, argv);
|
args.parse_args(argc, argv);
|
||||||
@@ -285,6 +286,7 @@ int main(int argc, char **argv)
|
|||||||
auto staticAssetsDir = args.get<std::string>("--assets-dir");
|
auto staticAssetsDir = args.get<std::string>("--assets-dir");
|
||||||
std::string listenIP = args.get<std::string>("--ip");
|
std::string listenIP = args.get<std::string>("--ip");
|
||||||
auto listenPort = args.get<unsigned>("--port");
|
auto listenPort = args.get<unsigned>("--port");
|
||||||
|
auto jwtSecret = args.get<std::string>("--jwt-secret");
|
||||||
size_t webThreads = 50;
|
size_t webThreads = 50;
|
||||||
size_t dagThreads = 50;
|
size_t dagThreads = 50;
|
||||||
|
|
||||||
@@ -311,6 +313,8 @@ int main(int argc, char **argv)
|
|||||||
webThreads = doc["web-threads"].GetInt64();
|
webThreads = doc["web-threads"].GetInt64();
|
||||||
if (doc.HasMember("dag-threads"))
|
if (doc.HasMember("dag-threads"))
|
||||||
dagThreads = doc["dag-threads"].GetInt64();
|
dagThreads = doc["dag-threads"].GetInt64();
|
||||||
|
if (doc.HasMember("jwt-secret"))
|
||||||
|
jwtSecret = doc["jwt-secret"].GetString();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
doc.SetObject();
|
doc.SetObject();
|
||||||
@@ -337,6 +341,9 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads,
|
daggy::daggyd::Server server(listenSpec, *logger, *executor, dagThreads,
|
||||||
staticAssetsDir);
|
staticAssetsDir);
|
||||||
|
if (!jwtSecret.empty()) {
|
||||||
|
server.setJWTSecret(jwtSecret);
|
||||||
|
}
|
||||||
server.init(webThreads);
|
server.init(webThreads);
|
||||||
server.start();
|
server.start();
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,6 @@ project(libdaggyd)
|
|||||||
add_library(${PROJECT_NAME} STATIC)
|
add_library(${PROJECT_NAME} STATIC)
|
||||||
|
|
||||||
target_include_directories(${PROJECT_NAME} PUBLIC include)
|
target_include_directories(${PROJECT_NAME} PUBLIC include)
|
||||||
target_link_libraries(${PROJECT_NAME} libdaggy stdc++fs)
|
target_link_libraries(${PROJECT_NAME} libdaggy jwt stdc++fs)
|
||||||
|
|
||||||
add_subdirectory(src)
|
add_subdirectory(src)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ namespace daggy::daggyd {
|
|||||||
~Server();
|
~Server();
|
||||||
|
|
||||||
Server &setSSLCertificates(const fs::path &cert, const fs::path &key);
|
Server &setSSLCertificates(const fs::path &cert, const fs::path &key);
|
||||||
|
Server &setJWTSecret(const std::string &secret);
|
||||||
|
|
||||||
void init(size_t threads = 1);
|
void init(size_t threads = 1);
|
||||||
|
|
||||||
@@ -69,5 +70,7 @@ namespace daggy::daggyd {
|
|||||||
|
|
||||||
std::mutex runnerGuard_;
|
std::mutex runnerGuard_;
|
||||||
std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_;
|
std::unordered_map<DAGRunID, std::shared_ptr<DAGRunner>> runners_;
|
||||||
|
|
||||||
|
std::string jwtSecret_;
|
||||||
};
|
};
|
||||||
} // namespace daggy::daggyd
|
} // namespace daggy::daggyd
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <jwt-cpp/jwt.h>
|
||||||
|
|
||||||
#define REQ_RESPONSE(code, msg) \
|
#define REQ_RESPONSE(code, msg) \
|
||||||
std::stringstream ss; \
|
std::stringstream ss; \
|
||||||
@@ -81,6 +82,12 @@ namespace daggy::daggyd {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Server &Server::setJWTSecret(const std::string &secret)
|
||||||
|
{
|
||||||
|
jwtSecret_ = secret;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
void Server::shutdown()
|
void Server::shutdown()
|
||||||
{
|
{
|
||||||
endpoint_.shutdown();
|
endpoint_.shutdown();
|
||||||
@@ -229,9 +236,10 @@ namespace daggy::daggyd {
|
|||||||
void Server::handleRunDAG(const Pistache::Rest::Request &request,
|
void Server::handleRunDAG(const Pistache::Rest::Request &request,
|
||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
if (!handleAuth(request))
|
|
||||||
return;
|
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
|
if (!handleAuth(request)) {
|
||||||
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
DAGRunID runID = 0;
|
DAGRunID runID = 0;
|
||||||
try {
|
try {
|
||||||
@@ -272,8 +280,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
bool all = false;
|
bool all = false;
|
||||||
std::string tag = "";
|
std::string tag = "";
|
||||||
@@ -332,8 +341,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
if (!request.hasParam(":runID")) {
|
if (!request.hasParam(":runID")) {
|
||||||
REQ_RESPONSE(Not_Found, "No runID provided in URL");
|
REQ_RESPONSE(Not_Found, "No runID provided in URL");
|
||||||
}
|
}
|
||||||
@@ -413,8 +423,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
if (!request.hasParam(":runID")) {
|
if (!request.hasParam(":runID")) {
|
||||||
REQ_RESPONSE(Not_Found, "No runID provided in URL");
|
REQ_RESPONSE(Not_Found, "No runID provided in URL");
|
||||||
}
|
}
|
||||||
@@ -434,8 +445,10 @@ namespace daggy::daggyd {
|
|||||||
void Server::handleGetDAGRunState(const Pistache::Rest::Request &request,
|
void Server::handleGetDAGRunState(const Pistache::Rest::Request &request,
|
||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
if (!handleAuth(request))
|
addResponseHeaders(response);
|
||||||
return;
|
if (!handleAuth(request)) {
|
||||||
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
DAGRunID runID = request.param(":runID").as<DAGRunID>();
|
DAGRunID runID = request.param(":runID").as<DAGRunID>();
|
||||||
RunState state = RunState::QUEUED;
|
RunState state = RunState::QUEUED;
|
||||||
@@ -478,9 +491,10 @@ namespace daggy::daggyd {
|
|||||||
void Server::handleSetDAGRunState(const Pistache::Rest::Request &request,
|
void Server::handleSetDAGRunState(const Pistache::Rest::Request &request,
|
||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
if (!handleAuth(request))
|
|
||||||
return;
|
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
|
if (!handleAuth(request)) {
|
||||||
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
// TODO handle state transition
|
// TODO handle state transition
|
||||||
DAGRunID runID = request.param(":runID").as<DAGRunID>();
|
DAGRunID runID = request.param(":runID").as<DAGRunID>();
|
||||||
@@ -535,8 +549,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
auto runID = request.param(":runID").as<DAGRunID>();
|
auto runID = request.param(":runID").as<DAGRunID>();
|
||||||
auto taskName = request.param(":taskName").as<std::string>();
|
auto taskName = request.param(":taskName").as<std::string>();
|
||||||
@@ -556,8 +571,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
auto runID = request.param(":runID").as<DAGRunID>();
|
auto runID = request.param(":runID").as<DAGRunID>();
|
||||||
auto taskName = request.param(":taskName").as<std::string>();
|
auto taskName = request.param(":taskName").as<std::string>();
|
||||||
@@ -581,8 +597,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
auto runID = request.param(":runID").as<DAGRunID>();
|
auto runID = request.param(":runID").as<DAGRunID>();
|
||||||
auto taskName = request.param(":taskName").as<std::string>();
|
auto taskName = request.param(":taskName").as<std::string>();
|
||||||
@@ -604,8 +621,9 @@ namespace daggy::daggyd {
|
|||||||
Pistache::Http::ResponseWriter response)
|
Pistache::Http::ResponseWriter response)
|
||||||
{
|
{
|
||||||
addResponseHeaders(response);
|
addResponseHeaders(response);
|
||||||
if (!handleAuth(request))
|
if (!handleAuth(request)) {
|
||||||
return;
|
REQ_RESPONSE(Unauthorized, "Authentication required");
|
||||||
|
}
|
||||||
|
|
||||||
// TODO implement handling of task state
|
// TODO implement handling of task state
|
||||||
auto runID = request.param(":runID").as<DAGRunID>();
|
auto runID = request.param(":runID").as<DAGRunID>();
|
||||||
@@ -686,6 +704,29 @@ namespace daggy::daggyd {
|
|||||||
*/
|
*/
|
||||||
bool Server::handleAuth(const Pistache::Rest::Request &request)
|
bool Server::handleAuth(const Pistache::Rest::Request &request)
|
||||||
{
|
{
|
||||||
return true;
|
if (jwtSecret_.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto authHeader = request.headers().tryGet<Pistache::Http::Header::Authorization>();
|
||||||
|
if (!authHeader) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string authValue = authHeader->value();
|
||||||
|
if (authValue.length() < 7 || authValue.substr(0, 7) != "Bearer ") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token = authValue.substr(7);
|
||||||
|
|
||||||
|
try {
|
||||||
|
auto verifier = jwt::verify()
|
||||||
|
.allow_algorithm(jwt::algorithm::hs256{jwtSecret_});
|
||||||
|
verifier.verify(jwt::decode(token));
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace daggy::daggyd
|
} // namespace daggy::daggyd
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ project(daggyd_tests)
|
|||||||
add_executable(${PROJECT_NAME} main.cpp
|
add_executable(${PROJECT_NAME} main.cpp
|
||||||
# unit tests
|
# unit tests
|
||||||
unit_server.cpp
|
unit_server.cpp
|
||||||
|
unit_jwt_auth.cpp
|
||||||
)
|
)
|
||||||
target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl)
|
target_link_libraries(${PROJECT_NAME} libdaggyd libdaggy stdc++fs Catch2::Catch2 curl)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user