From 779852022ac2044553c250ed548069024ada73fa Mon Sep 17 00:00:00 2001
From: Kinesin Data Technologies Incorporated <93931750+kinesintech@users.noreply.github.com>
Date: Wed, 5 Oct 2022 08:36:16 -0300
Subject: [PATCH] Adding agent executor

---
Notes (ignored by `git am`):

* This patch text was re-flowed from a copy whose line breaks, indentation and
  generic type arguments had been lost. The restored arguments
  (`Vec<AgentTarget>`, `Vec<String>`, `HashMap<String, String>`,
  `Result<AgentTaskDetail, serde_json::Error>`, `Vec<TaskResources>`,
  `UnboundedReceiver<ExecutorMessage>`, `HashMap<String, i64>`) are inferred
  from usage -- confirm against the original commit before applying. The
  commented-out `select_target` stub is left exactly as it appeared.
* NOTE(review): `default_as_true` is defined but never referenced, while
  `AgentTarget::enabled` uses a plain `#[serde(default)]`; it looks like
  `#[serde(default = "default_as_true")]` was intended -- confirm.
* NOTE(review): `refresh_resources` and `submit_task` call `.unwrap()` on the
  agent's HTTP response body, so a malformed reply panics the task that
  awaits it.

 Cargo.toml                      |   1 +
 src/bin/wf/main.rs              |   8 +-
 src/executors/agent_executor.rs | 289 ++++++++++++++++++++++++++++++++
 src/executors/mod.rs            |   1 +
 src/task.rs                     |  59 +++++++
 5 files changed, 357 insertions(+), 1 deletion(-)
 create mode 100644 src/executors/agent_executor.rs

diff --git a/Cargo.toml b/Cargo.toml
index c64e877..8a3b07c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,3 +20,4 @@ sysinfo = "0.23"
 redis = { version = "*", features = ["aio", "tokio-comp"] }
 clap = { version = "3.1", features = ["derive"] }
 env_logger = "0.9"
+log = "0.4"
diff --git a/src/bin/wf/main.rs b/src/bin/wf/main.rs
index 18ef9bb..f8f8098 100644
--- a/src/bin/wf/main.rs
+++ b/src/bin/wf/main.rs
@@ -29,7 +29,12 @@ impl StorageConfig {
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(rename_all = "snake_case", deny_unknown_fields, tag = "type")]
 enum ExecutorConfig {
-    Local { workers: usize },
+    Local {
+        workers: usize,
+    },
+    Agent {
+        targets: Vec<AgentTarget>,
+    },
 }
 
 impl ExecutorConfig {
@@ -42,6 +47,7 @@ impl ExecutorConfig {
         let (tx, rx) = mpsc::unbounded_channel();
         match self {
             ExecutorConfig::Local { workers } => (tx, local_executor::start(*workers, rx)),
+            ExecutorConfig::Agent { targets } => (tx, agent_executor::start(targets.clone(), rx)),
         }
     }
 }
diff --git a/src/executors/agent_executor.rs b/src/executors/agent_executor.rs
new file mode 100644
index 0000000..bcf78b5
--- /dev/null
+++ b/src/executors/agent_executor.rs
@@ -0,0 +1,289 @@
+//! The Agent executor is essentially a wrapped version of the local executor.
+//! It dispatches tasks to remote hosts
+
+use super::*;
+use futures::stream::futures_unordered::FuturesUnordered;
+use log::{info, warn};
+use serde::{Deserialize, Serialize};
+use tokio::sync::{mpsc, oneshot};
+
+use futures::StreamExt;
+
+fn default_as_true() -> bool {
+    true
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct AgentTarget {
+    pub base_url: String,
+
+    #[serde(default)]
+    pub resources: TaskResources,
+
+    #[serde(default)]
+    pub current_resources: TaskResources,
+
+    #[serde(default)]
+    pub enabled: bool,
+}
+
+impl AgentTarget {
+    fn new(base_url: String, resources: TaskResources) -> Self {
+        AgentTarget {
+            base_url,
+            resources: resources.clone(),
+            current_resources: resources,
+            enabled: true,
+        }
+    }
+
+    async fn refresh_resources(&mut self, client: &reqwest::Client) {
+        let resource_url = format!("{}/resources", self.base_url);
+        let disabled = match client.get(resource_url).send().await {
+            Ok(result) => {
+                if result.status() == reqwest::StatusCode::OK {
+                    self.resources = result.json().await.unwrap();
+                    self.current_resources = self.resources.clone();
+                    false
+                } else {
+                    true
+                }
+            }
+            Err(_) => true,
+        };
+        if self.enabled && disabled {
+            warn!("Disabling {}: unable to refresh resources", self.base_url);
+        }
+        self.enabled = !disabled;
+    }
+
+    async fn ping(&mut self, client: &reqwest::Client) -> Result<()> {
+        let resource_url = format!("{}/ready", self.base_url);
+        let result = client.get(resource_url).send().await?;
+        self.enabled = result.status() == reqwest::StatusCode::OK;
+        Ok(())
+    }
+}
+
+/// Contains specifics on how to run a local task
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct AgentTaskDetail {
+    /// The command and all arguments to run
+    #[serde(default)]
+    command: Vec<String>,
+
+    /// Environment variables to set
+    #[serde(default)]
+    environment: HashMap<String, String>,
+
+    /// Timeout in seconds
+    #[serde(default)]
+    timeout: i64,
+
+    /// resources required by the task
+    resources: TaskResources,
+}
+
+fn extract_details(details: &TaskDetails) -> Result<AgentTaskDetail, serde_json::Error> {
+    serde_json::from_value::<AgentTaskDetail>(details.clone())
+}
+
+fn validate_task(details: &TaskDetails, max_capacities: &[TaskResources]) -> Result<()> {
+    let parsed = extract_details(details)?;
+    if max_capacities.is_empty()
+        || max_capacities.iter().all(|x| x.values().all(|x| *x == 0))
+        || max_capacities
+            .iter()
+            .any(|x| x.can_satisfy(&parsed.resources))
+    {
+        Ok(())
+    } else {
+        Err(anyhow!("No Agent target satisfies the required resources"))
+    }
+}
+
+async fn submit_task(
+    base_url: String,
+    details: TaskDetails,
+    output_options: TaskOutputOptions,
+    client: reqwest::Client,
+    varmap: VarMap,
+) -> TaskAttempt {
+    let submit_url = format!("{}/run", base_url);
+    let mut attempt = TaskAttempt::new();
+    match client.post(submit_url).json(&details).send().await {
+        Ok(result) => {
+            if result.status() == reqwest::StatusCode::OK {
+                attempt = result.json().await.unwrap();
+                attempt
+                    .executor
+                    .push(format!("Executed on agent at {}", base_url));
+            } else {
+                attempt.succeeded = false;
+                attempt.infra_failure = true;
+                attempt.executor.push(format!(
+                    "Unable to dispatch to agent at {}: {:?}",
+                    base_url,
+                    result.text().await.unwrap()
+                ));
+            }
+        }
+        Err(e) => {
+            attempt.succeeded = false;
+            attempt.infra_failure = true;
+            attempt.executor.push(format!(
+                "Unable to dispatch to agent at {}: {:?}",
+                base_url, e
+            ));
+        }
+    }
+
+    attempt
+}
+
+// async fn select_target() -> Option {}
+
+struct RunningTask {
+    resources: TaskResources,
+    target_id: usize,
+}
+
+/// The mpsc channel can be sized to fit max parallelism
+async fn start_agent_executor(
+    mut targets: Vec<AgentTarget>,
+    mut exe_msgs: mpsc::UnboundedReceiver<ExecutorMessage>,
+) {
+    let client = reqwest::Client::new();
+
+    for target in &mut targets {
+        target.refresh_resources(&client).await;
+    }
+    let mut max_caps: Vec<TaskResources> = targets.iter().map(|x| x.resources.clone()).collect();
+
+    // Set up the local executor
+    let (le_tx, le_rx) = mpsc::unbounded_channel();
+    local_executor::start(1, le_rx);
+
+    // Tasks waiting to release resources
+    let mut running = FuturesUnordered::new();
+
+    while let Some(msg) = exe_msgs.recv().await {
+        use ExecutorMessage::*;
+        match msg {
+            ValidateTask { details, response } => {
+                let ltx = le_tx.clone();
+                let caps = max_caps.clone();
+                tokio::spawn(async move {
+                    let result = validate_task(&details, &caps);
+                    if result.is_err() {
+                        response.send(result).unwrap_or(());
+                    } else {
+                        ltx.send(ValidateTask { details, response }).unwrap_or(());
+                    }
+                });
+            }
+            ExecuteTask {
+                task_name,
+                interval,
+                details,
+                varmap,
+                output_options,
+                storage,
+                response,
+                kill,
+            } => {
+                let task = extract_details(&details).unwrap();
+                let resources = task.resources.clone();
+
+                loop {
+                    match targets.iter_mut().enumerate().find(|(_, x)| {
+                        x.enabled && x.current_resources.can_satisfy(&task.resources)
+                    }) {
+                        // There is a remote agent with capacity
+                        Some((tid, target)) => {
+                            target.current_resources.sub(&resources).unwrap();
+                            let base_url = target.base_url.clone();
+                            let submit_client = client.clone();
+                            running.push(tokio::spawn(async move {
+                                let attempt = submit_task(
+                                    base_url,
+                                    details,
+                                    output_options,
+                                    submit_client,
+                                    varmap,
+                                )
+                                .await;
+                                let rc = attempt.succeeded;
+                                storage
+                                    .send(StorageMessage::StoreAttempt {
+                                        task_name,
+                                        interval,
+                                        attempt,
+                                    })
+                                    .unwrap();
+                                response.send(rc).unwrap();
+                                (tid, resources, rc)
+                            }));
+                            break;
+                        }
+                        // No agent has capacity
+                        None => {
+                            // Give the outstanding tasks a chance to complete or agents
+                            // recover
+                            tokio::time::sleep(tokio::time::Duration::from_millis(250)).await;
+                            info!("Waiting to run message");
+
+                            // Refresh any disabled targets
+                            for (tid, target) in targets.iter_mut().enumerate() {
+                                if target.enabled {
+                                    info!("Skipping {} as it is enabled", target.base_url);
+                                    continue;
+                                }
+                                target.refresh_resources(&client).await;
+                                if target.enabled {
+                                    max_caps[tid] = target.resources.clone();
+                                    info!("{} is now enabled.", target.base_url);
+                                }
+                            }
+
+                            // Wait for the next item
+                            if !running.is_empty() {
+                                let result: Result<
+                                    (usize, TaskResources, bool),
+                                    tokio::task::JoinError,
+                                > = running.next().await.unwrap();
+
+                                let (tid, resources, submit_ok) = result.unwrap();
+                                if !submit_ok {
+                                    warn!(
+                                        "Disabling agent at {} due to incomplete submission.",
+                                        targets[tid].base_url
+                                    );
+                                    targets[tid].enabled = false;
+                                }
+                                targets[tid].current_resources.add(&resources);
+                            }
+                        }
+                    }
+                }
+            }
+            /*
+            msg @ StopTask { .. } => {
+                le_tx.send(msg).unwrap_or(());
+            }
+            */
+            Stop {} => {
+                break;
+            }
+        }
+    }
+}
+
+pub fn start(
+    targets: Vec<AgentTarget>,
+    msgs: mpsc::UnboundedReceiver<ExecutorMessage>,
+) -> tokio::task::JoinHandle<()> {
+    tokio::spawn(async move {
+        start_agent_executor(targets, msgs).await;
+    })
+}
diff --git a/src/executors/mod.rs b/src/executors/mod.rs
index 5058392..17ceb0d 100644
--- a/src/executors/mod.rs
+++ b/src/executors/mod.rs
@@ -1,4 +1,5 @@
 use super::*;
+pub mod agent_executor;
 pub mod local_executor;
 
 /// Messages for interacting with an Executor
diff --git a/src/task.rs b/src/task.rs
index 8628445..035c11d 100644
--- a/src/task.rs
+++ b/src/task.rs
@@ -1,4 +1,63 @@
 use super::*;
+use std::ops::{Deref, DerefMut};
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+pub struct TaskResources(HashMap<String, i64>);
+
+impl Deref for TaskResources {
+    type Target = HashMap<String, i64>;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl DerefMut for TaskResources {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
+impl TaskResources {
+    #[must_use]
+    pub fn new() -> Self {
+        TaskResources(HashMap::new())
+    }
+
+    #[must_use]
+    pub fn can_satisfy(&self, requirements: &TaskResources) -> bool {
+        requirements
+            .iter()
+            .all(|(k, v)| self.contains_key(k) && self[k] >= *v)
+    }
+
+    /// Subtracts resources from available resources.
+    /// # Errors
+    /// Returns an `Err` if the requested resources cannot be fulfilled
+    /// # Panics
+    /// It doesn't, keys are checked for ahead-of-time
+    pub fn sub(&mut self, resources: &TaskResources) -> Result<()> {
+        if self.can_satisfy(resources) {
+            for (k, v) in resources.iter() {
+                *self.get_mut(k).unwrap() -= v;
+            }
+            Ok(())
+        } else {
+            Err(anyhow!("Cannot satisfy requested resources"))
+        }
+    }
+
+    /// # Panics
+    /// It doesn't, keys are checked for ahead-of-time
+    pub fn add(&mut self, resources: &TaskResources) {
+        for (k, v) in resources.iter() {
+            if self.contains_key(k) {
+                *self.get_mut(k).unwrap() += *v;
+            } else {
+                self.insert(k.clone(), *v);
+            }
+        }
+    }
+}
 
 /// Defines the struct to parse for tasks
 #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]