Adapting the agent
This commit is contained in:
parent
0d6cea4152
commit
ca9a32c032
+15
-14
@@ -7,8 +7,11 @@ use serde::Serialize;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use config::*;
|
||||
use waterfall::executors::agent_executor::TaskSubmission;
|
||||
use waterfall::prelude::*;
|
||||
|
||||
type TaskDetails = serde_json::Value;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SimpleError {
|
||||
error: String,
|
||||
@@ -19,30 +22,28 @@ async fn get_resources(data: web::Data<GlobalConfig>) -> impl Responder {
|
||||
}
|
||||
|
||||
async fn submit_task(
|
||||
details: web::Json<TaskDetails>,
|
||||
details: web::Json<TaskSubmission>,
|
||||
data: web::Data<GlobalConfig>,
|
||||
) -> impl Responder {
|
||||
let (response, mut rx) = mpsc::unbounded_channel();
|
||||
let (response, rx) = oneshot::channel();
|
||||
|
||||
let submission = data.into_inner();
|
||||
|
||||
let trx = data.tracker.clone();
|
||||
|
||||
data.executor
|
||||
.send(ExecutorMessage::ExecuteTask {
|
||||
details: details.into_inner(),
|
||||
output_options: TaskOutputOptions::default(),
|
||||
tracker: trx,
|
||||
details: submission.details,
|
||||
output_options: submission.output_options,
|
||||
varmap: submission.varmap,
|
||||
response,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
match rx.recv().await.unwrap() {
|
||||
RunnerMessage::ExecutionReport { attempt, .. } => HttpResponse::Ok().json(attempt),
|
||||
other => HttpResponse::BadRequest().json(SimpleError {
|
||||
error: format!("Unexpected message {:?}", other),
|
||||
}),
|
||||
}
|
||||
HttpResponse::Ok().json(rx.await.unwrap())
|
||||
}
|
||||
|
||||
/*
|
||||
async fn stop_task(
|
||||
path: web::Path<(RunID, TaskID)>,
|
||||
data: web::Data<GlobalConfig>,
|
||||
@@ -61,6 +62,7 @@ async fn stop_task(
|
||||
rx.await.unwrap();
|
||||
HttpResponse::Ok()
|
||||
}
|
||||
*/
|
||||
|
||||
async fn ready() -> impl Responder {
|
||||
HttpResponse::Ok()
|
||||
@@ -148,8 +150,7 @@ async fn main() -> std::io::Result<()> {
|
||||
.service(
|
||||
web::scope("/api/v1")
|
||||
.route("/resources", web::get().to(get_resources))
|
||||
.route("/{run_id}/{task_id}", web::post().to(submit_task))
|
||||
.route("/{run_id}/{task_id}", web::delete().to(stop_task)),
|
||||
.route("/run", web::post().to(submit_task)),
|
||||
)
|
||||
})
|
||||
.bind(config.listen_spec())?
|
||||
@@ -157,7 +158,7 @@ async fn main() -> std::io::Result<()> {
|
||||
.await;
|
||||
|
||||
config.executor.send(ExecutorMessage::Stop {}).unwrap();
|
||||
config.tracker.send(TrackerMessage::Stop {}).unwrap();
|
||||
config.storage.send(StorageMessage::Stop {}).unwrap();
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
@@ -102,6 +102,14 @@ fn validate_task(details: &TaskDetails, max_capacities: &[TaskResources]) -> Res
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains specifics on how to run a local task
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct TaskSubmission {
|
||||
details: TaskDetails,
|
||||
varmap: VarMap,
|
||||
output_options: TaskOutputOptions,
|
||||
}
|
||||
|
||||
async fn submit_task(
|
||||
base_url: String,
|
||||
details: TaskDetails,
|
||||
@@ -111,7 +119,12 @@ async fn submit_task(
|
||||
) -> TaskAttempt {
|
||||
let submit_url = format!("{}/run", base_url);
|
||||
let mut attempt = TaskAttempt::new();
|
||||
match client.post(submit_url).json(&details).send().await {
|
||||
let submission = TaskSubmission {
|
||||
details,
|
||||
varmap,
|
||||
output_options,
|
||||
};
|
||||
match client.post(submit_url).json(&submission).send().await {
|
||||
Ok(result) => {
|
||||
if result.status() == reqwest::StatusCode::OK {
|
||||
attempt = result.json().await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user