//! Upstream pool management for reverse-proxy load balancing.
//!
//! Manages a named pool of backend servers with weighted load balancing,
//! passive health tracking (callers report request outcomes through
//! `mark_failed` / `mark_success`), and automatic fail-over to backup servers.
//!
//! All mutating operations are protected by [`parking_lot::RwLock`], making
//! [`UpstreamPool`] safe to share across threads (e.g. via `Arc`).
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use tracing::{debug, warn};
use super::load_balance::{create_balancer, Algorithm, LoadBalancer};
// ---------------------------------------------------------------------------
// ServerState
// ---------------------------------------------------------------------------
/// Health state of an upstream server.
///
/// The [`Default`] value is [`ServerState::Alive`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ServerState {
    /// Server is healthy and eligible for new connections.
    #[default]
    Alive,
    /// Server has been administratively disabled and will not receive traffic.
    Down,
    /// Server has been temporarily disabled because it accumulated
    /// `max_fails` failures within the `fail_timeout` window. Once
    /// `fail_timeout` has elapsed since the most recent failure it becomes
    /// eligible for traffic again, and the next recorded success returns it
    /// to [`ServerState::Alive`].
    Unavailable,
}
// ---------------------------------------------------------------------------
// UpstreamServer
// ---------------------------------------------------------------------------
/// One backend endpoint belonging to an upstream pool.
///
/// Holds the endpoint's static configuration (address, weight, failure
/// thresholds, backup flag) together with its current [`ServerState`].
/// Failure counters are deliberately not stored here; [`UpstreamPool`] keeps
/// them in a private tracker paired with each server.
#[derive(Debug, Clone)]
pub struct UpstreamServer {
    /// Address the proxy connects to.
    pub addr: SocketAddr,
    /// Load-balancing weight handed to the balancer. The builder clamps it
    /// to at least 1.
    pub weight: u32,
    /// Failure count at which the server is marked
    /// [`ServerState::Unavailable`]. The builder clamps it to at least 1.
    pub max_fails: u32,
    /// Window used both to expire old failures and to decide when an
    /// unavailable server may be tried again.
    pub fail_timeout: Duration,
    /// `true` for servers that are only considered when no primary
    /// (non-backup) server is available.
    pub backup: bool,
    /// Current health state.
    pub state: ServerState,
}

impl UpstreamServer {
    /// Build a server for `addr` using the default settings:
    ///
    /// * `weight` = `1`
    /// * `max_fails` = `1`
    /// * `fail_timeout` = 10 seconds
    /// * `backup` = `false`
    /// * `state` = [`ServerState::Alive`]
    pub fn new(addr: SocketAddr) -> Self {
        let fail_timeout = Duration::from_secs(10);
        Self {
            addr,
            weight: 1,
            max_fails: 1,
            fail_timeout,
            backup: false,
            state: ServerState::Alive,
        }
    }

    /// Builder: replace the weight; values below 1 are raised to 1.
    pub fn with_weight(self, weight: u32) -> Self {
        Self {
            weight: std::cmp::max(weight, 1),
            ..self
        }
    }

    /// Builder: replace the failure threshold; values below 1 are raised to 1.
    pub fn with_max_fails(self, max_fails: u32) -> Self {
        Self {
            max_fails: std::cmp::max(max_fails, 1),
            ..self
        }
    }

    /// Builder: replace the fail-timeout window.
    pub fn with_fail_timeout(self, timeout: Duration) -> Self {
        Self {
            fail_timeout: timeout,
            ..self
        }
    }

    /// Builder: choose whether this server is a backup.
    pub fn with_backup(self, backup: bool) -> Self {
        Self { backup, ..self }
    }

    /// Builder: choose the initial health state.
    pub fn with_state(self, state: ServerState) -> Self {
        Self { state, ..self }
    }

    /// `false` only for [`ServerState::Down`]; an `Unavailable` server still
    /// counts as alive here because it has not been administratively disabled.
    #[inline]
    pub fn is_alive(&self) -> bool {
        !matches!(self.state, ServerState::Down)
    }
}
// ---------------------------------------------------------------------------
// ServerTracker (private)
// ---------------------------------------------------------------------------
/// Mutable per-server health tracking data.
///
/// Kept separate from [`UpstreamServer`] so the public struct can be
/// cheaply cloned / inspected without exposing internal counters.
#[derive(Debug)]
struct ServerTracker {
    /// Failures recorded in the current window. Cleared by a success, or by
    /// a new failure that arrives after `fail_timeout` has passed since the
    /// previous one.
    fails: u32,
    /// Timestamp of the most recent failure; `None` while the counter is
    /// clean.
    last_fail: Option<Instant>,
}

impl ServerTracker {
    /// Create a tracker with no recorded failures.
    fn new() -> Self {
        Self {
            fails: 0,
            last_fail: None,
        }
    }

    /// Failure threshold actually enforced for `server`.
    ///
    /// `max_fails` is a public field, so a struct literal can bypass the
    /// builder's clamp and set it to 0. Taken literally, 0 would make an
    /// `Alive` server permanently unselectable (`fails < 0` is never true),
    /// so the documented minimum of 1 is applied here as well.
    fn threshold(server: &UpstreamServer) -> u32 {
        server.max_fails.max(1)
    }

    /// Returns `true` when the server should receive new connections.
    ///
    /// * `Down` servers are never available.
    /// * `Alive` servers are available while their failure count is below
    ///   the threshold.
    /// * `Unavailable` servers become available again once `fail_timeout`
    ///   has elapsed since the last failure (or immediately if no failure
    ///   was ever recorded, e.g. the state was set by the builder).
    fn is_available(&self, server: &UpstreamServer) -> bool {
        match server.state {
            ServerState::Down => false,
            ServerState::Alive => self.fails < Self::threshold(server),
            ServerState::Unavailable => match self.last_fail {
                Some(last) => last.elapsed() >= server.fail_timeout,
                None => true,
            },
        }
    }

    /// Record a failed request.
    ///
    /// If the failure count reaches the threshold the server transitions
    /// from [`ServerState::Alive`] to [`ServerState::Unavailable`]. If the
    /// previous failure occurred outside the `fail_timeout` window the
    /// counter is reset first, so stale failures do not count towards the
    /// threshold.
    fn record_failure(&mut self, server: &mut UpstreamServer) {
        // Reset counter when the fail-timeout has elapsed since the last
        // failure (treat as a fresh start).
        if let Some(last) = self.last_fail {
            if last.elapsed() >= server.fail_timeout {
                self.fails = 0;
            }
        }
        // Saturate rather than overflow: a sustained burst of failures must
        // never panic (debug builds) or wrap back to zero (release builds).
        self.fails = self.fails.saturating_add(1);
        self.last_fail = Some(Instant::now());
        // Only `Alive` servers are demoted; `Down` stays `Down`, and an
        // already `Unavailable` server just has its window extended.
        if self.fails >= Self::threshold(server) && server.state == ServerState::Alive {
            server.state = ServerState::Unavailable;
            warn!(
                addr = %server.addr,
                fails = self.fails,
                max_fails = server.max_fails,
                fail_timeout = ?server.fail_timeout,
                "upstream server marked unavailable"
            );
        }
    }

    /// Record a successful request.
    ///
    /// Resets the failure counter and transitions an `Unavailable` server
    /// back to `Alive`. A `Down` server keeps its state.
    fn record_success(&mut self, server: &mut UpstreamServer) {
        if self.fails > 0 || server.state == ServerState::Unavailable {
            debug!(
                addr = %server.addr,
                prev_fails = self.fails,
                "upstream server recovered"
            );
        }
        self.fails = 0;
        self.last_fail = None;
        if server.state == ServerState::Unavailable {
            server.state = ServerState::Alive;
        }
    }
}
// ---------------------------------------------------------------------------
// UpstreamPool
// ---------------------------------------------------------------------------
/// A named pool of upstream servers with load balancing and health tracking.
///
/// The server list and the per-server health counters live behind a
/// [`parking_lot::RwLock`]. The [`LoadBalancer`] is called through `&self`
/// without any additional locking, so whatever state it keeps has to be
/// synchronised by the balancer itself.
/// Share across tasks with `Arc<UpstreamPool>`.
///
/// # Example
///
/// ```ignore
/// use std::net::SocketAddr;
/// use std::sync::Arc;
///
/// let pool = Arc::new(
///     UpstreamPool::new(
///         "backend",
///         vec![
///             UpstreamServer::new("10.0.0.1:8080".parse().unwrap())
///                 .with_weight(3),
///             UpstreamServer::new("10.0.0.2:8080".parse().unwrap()),
///         ],
///         Algorithm::WeightedRoundRobin,
///     )
///     .with_keepalive(64),
/// );
///
/// // In a request handler:
/// if let Some(server) = pool.next() {
///     // proxy to server.addr
/// }
/// ```
pub struct UpstreamPool {
    /// Human-readable pool name (used in log messages).
    name: String,
    /// Server list with paired health trackers, guarded for concurrent access.
    servers: RwLock<Vec<(UpstreamServer, ServerTracker)>>,
    /// Load balancer; invoked without holding any pool-level write lock.
    balancer: Box<dyn LoadBalancer>,
    /// Maximum keepalive connections per upstream server. The pool only
    /// stores this value; it is exposed through [`UpstreamPool::keepalive`].
    keepalive: usize,
}

impl UpstreamPool {
    /// Create a new upstream pool with a keepalive setting of 32.
    ///
    /// Every server starts with a clean failure tracker. The balancer is
    /// created for `algorithm` and told the total number of servers.
    ///
    /// # Panics
    /// Panics if `servers` is empty.
    pub fn new(
        name: impl Into<String>,
        servers: Vec<UpstreamServer>,
        algorithm: Algorithm,
    ) -> Self {
        assert!(
            !servers.is_empty(),
            "upstream pool requires at least one server"
        );
        let name = name.into();
        let count = servers.len();
        let entries: Vec<_> = servers
            .into_iter()
            .map(|s| (s, ServerTracker::new()))
            .collect();
        debug!(
            pool = %name,
            count = count,
            algorithm = ?algorithm,
            "upstream pool created"
        );
        Self {
            name,
            servers: RwLock::new(entries),
            balancer: create_balancer(algorithm, None, count),
            keepalive: 32,
        }
    }

    /// Set the number of keepalive connections per upstream server.
    pub fn with_keepalive(mut self, keepalive: usize) -> Self {
        self.keepalive = keepalive;
        self
    }

    /// Returns the pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the configured per-server keepalive count.
    pub fn keepalive(&self) -> usize {
        self.keepalive
    }

    // -- server selection -------------------------------------------------

    /// Select the next available server via the configured load balancer.
    ///
    /// Selection strategy:
    /// 1. Filter to primary (non-backup) servers that are currently available.
    /// 2. If no primary servers are available, fall back to backup servers.
    /// 3. If nothing is available at all, return `None`.
    ///
    /// A warning is logged whenever a backup is used or nothing is available.
    pub fn next(&self) -> Option<UpstreamServer> {
        // Pick under the read lock, but release it before logging so that a
        // slow tracing subscriber cannot hold up `mark_failed` /
        // `mark_success`, which need the write lock.
        let (picked, used_backup) = {
            let servers = self.servers.read();
            let primary = self.select_from_group(&servers, false);
            if primary.is_some() {
                (primary, false)
            } else {
                (self.select_from_group(&servers, true), true)
            }
        };

        if picked.is_none() {
            warn!(pool = %self.name, "no upstream servers available");
        } else if used_backup {
            warn!(pool = %self.name, "all primary servers down, using backup");
        }
        picked
    }

    // -- health management ------------------------------------------------

    /// Record a failure for the server at `addr`.
    ///
    /// Increments the failure counter. When the counter reaches `max_fails`
    /// the server transitions to [`ServerState::Unavailable`] and will not
    /// receive new connections until `fail_timeout` elapses.
    ///
    /// Addresses that are not part of the pool are ignored. If several
    /// entries share the address, only the first one is updated.
    pub fn mark_failed(&self, addr: &SocketAddr) {
        let mut servers = self.servers.write();
        if let Some((server, tracker)) = servers.iter_mut().find(|(s, _)| &s.addr == addr) {
            tracker.record_failure(server);
        }
    }

    /// Record a successful response from the server at `addr`.
    ///
    /// Resets the failure counter. If the server was `Unavailable` it is
    /// restored to `Alive`.
    ///
    /// Addresses that are not part of the pool are ignored. If several
    /// entries share the address, only the first one is updated.
    pub fn mark_success(&self, addr: &SocketAddr) {
        let mut servers = self.servers.write();
        if let Some((server, tracker)) = servers.iter_mut().find(|(s, _)| &s.addr == addr) {
            tracker.record_success(server);
        }
    }

    // -- introspection ----------------------------------------------------

    /// Return clones of all currently available servers (primary and backup).
    pub fn active_servers(&self) -> Vec<UpstreamServer> {
        let servers = self.servers.read();
        servers
            .iter()
            .filter(|(s, t)| t.is_available(s))
            .map(|(s, _)| s.clone())
            .collect()
    }

    /// Return clones of every server in the pool regardless of state.
    pub fn all_servers(&self) -> Vec<UpstreamServer> {
        self.servers.read().iter().map(|(s, _)| s.clone()).collect()
    }

    /// Total number of servers in the pool.
    pub fn server_count(&self) -> usize {
        self.servers.read().len()
    }

    /// Number of servers currently accepting connections.
    pub fn available_count(&self) -> usize {
        self.servers
            .read()
            .iter()
            .filter(|(s, t)| t.is_available(s))
            .count()
    }

    // -- private ----------------------------------------------------------

    /// Pick one server from the primary (`backup == false`) or backup
    /// (`backup == true`) group using the load balancer.
    ///
    /// Returns an owned clone so the caller can release the lock. Returns
    /// `None` when the group has no available server, when the balancer
    /// declines to choose, or when the balancer reports an index outside the
    /// candidate list.
    fn select_from_group(
        &self,
        servers: &[(UpstreamServer, ServerTracker)],
        backup: bool,
    ) -> Option<UpstreamServer> {
        // Keep only servers of the requested group that the health tracker
        // considers available (this already excludes `Down` servers). Each
        // candidate is cloned exactly once because the balancer takes a
        // slice of owned servers.
        let candidates: Vec<UpstreamServer> = servers
            .iter()
            .filter(|(s, t)| s.backup == backup && t.is_available(s))
            .map(|(s, _)| s.clone())
            .collect();
        if candidates.is_empty() {
            return None;
        }

        // The balancer answers with an index into `candidates`.
        let selected = self.balancer.next_server(&candidates)?;

        // Move the chosen clone out instead of cloning it a second time.
        // `nth` yields `None` for an out-of-range index, so a misbehaving
        // balancer cannot panic the request path.
        candidates.into_iter().nth(selected)
    }
}
impl std::fmt::Debug for UpstreamPool {
    /// Compact diagnostic view: pool name, number of servers, keepalive
    /// setting and the balancer's algorithm. Individual servers and their
    /// failure counters are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("UpstreamPool");
        view.field("name", &self.name);
        // `server_count` takes the read lock briefly to read the length.
        view.field("server_count", &self.server_count());
        view.field("keepalive", &self.keepalive);
        view.field("algorithm", &self.balancer.algorithm());
        view.finish()
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    /// Loopback address on the given port.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Round-robin pool named "test" with one default server per port.
    fn pool_with_servers(addrs: &[u16]) -> UpstreamPool {
        let mut servers = Vec::with_capacity(addrs.len());
        for &port in addrs {
            servers.push(UpstreamServer::new(addr(port)));
        }
        UpstreamPool::new("test", servers, Algorithm::RoundRobin)
    }

    #[test]
    fn next_returns_a_server() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        let picked = pool.next().expect("should select a server");
        assert!((8001..=8003).contains(&picked.addr.port()));
    }

    #[test]
    fn next_cycles_through_servers() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        let mut ports: Vec<u16> = Vec::new();
        for _ in 0..6 {
            if let Some(server) = pool.next() {
                ports.push(server.addr.port());
            }
        }
        // Round-robin over 3 servers: the second lap repeats the first.
        assert_eq!(ports[0], ports[3]);
        assert_eq!(ports[1], ports[4]);
        assert_eq!(ports[2], ports[5]);
    }

    #[test]
    fn mark_failed_disables_server_after_max_fails() {
        let pool = pool_with_servers(&[8001]);
        let only = addr(8001);
        pool.mark_failed(&only);
        // Default `max_fails` is 1, so a single failure takes the server out.
        assert!(pool.next().is_none());
    }

    #[test]
    fn mark_success_restores_server() {
        let pool = pool_with_servers(&[8001]);
        let only = addr(8001);
        pool.mark_failed(&only);
        assert!(pool.next().is_none());
        pool.mark_success(&only);
        assert!(pool.next().is_some());
    }

    #[test]
    fn backup_server_used_when_primary_down() {
        let primary = UpstreamServer::new(addr(8001));
        let spare = UpstreamServer::new(addr(8002)).with_backup(true);
        let pool = UpstreamPool::new("test", vec![primary, spare], Algorithm::RoundRobin);

        // While the primary is healthy it is the one handed out.
        let first = pool.next().unwrap();
        assert_eq!(first.addr.port(), 8001);

        // After the primary fails, the backup takes over.
        pool.mark_failed(&addr(8001));
        let second = pool.next().unwrap();
        assert_eq!(second.addr.port(), 8002);
    }

    #[test]
    fn down_server_never_selected() {
        let disabled = UpstreamServer::new(addr(8001)).with_state(ServerState::Down);
        let healthy = UpstreamServer::new(addr(8002));
        let pool = UpstreamPool::new("test", vec![disabled, healthy], Algorithm::RoundRobin);
        for _ in 0..10 {
            let picked = pool.next().unwrap();
            assert_eq!(picked.addr.port(), 8002);
        }
    }

    #[test]
    fn active_servers_excludes_unavailable() {
        let pool = pool_with_servers(&[8001, 8002]);
        pool.mark_failed(&addr(8001));
        let active = pool.active_servers();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].addr.port(), 8002);
    }

    #[test]
    fn all_servers_includes_everything() {
        let pool = pool_with_servers(&[8001, 8002]);
        pool.mark_failed(&addr(8001));
        let everything = pool.all_servers();
        assert_eq!(everything.len(), 2);
    }

    #[test]
    fn available_count_tracks_state() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        assert_eq!(pool.available_count(), 3);
        pool.mark_failed(&addr(8002));
        assert_eq!(pool.available_count(), 2);
    }

    #[test]
    fn unknown_addr_is_noop() {
        let pool = pool_with_servers(&[8001]);
        let stranger = addr(9999);
        // Neither call may panic or touch the real server.
        pool.mark_failed(&stranger);
        pool.mark_success(&stranger);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn weighted_distribution() {
        let heavy = UpstreamServer::new(addr(8001)).with_weight(5);
        let light = UpstreamServer::new(addr(8002)).with_weight(1);
        let pool = UpstreamPool::new("test", vec![heavy, light], Algorithm::RoundRobin);

        let mut counts = [0usize; 2];
        for _ in 0..600 {
            let port = match pool.next() {
                Some(server) => server.addr.port(),
                None => continue,
            };
            if port == 8001 {
                counts[0] += 1;
            } else if port == 8002 {
                counts[1] += 1;
            }
        }
        // The weight-5 server should see roughly 5x the traffic of the
        // weight-1 server; require at least a 3x skew.
        assert!(
            counts[0] > counts[1] * 3,
            "expected heavy skew, got {:?}",
            counts
        );
    }

    #[test]
    fn builder_methods() {
        let timeout = Duration::from_secs(30);
        let server = UpstreamServer::new(addr(8080))
            .with_weight(5)
            .with_max_fails(3)
            .with_fail_timeout(timeout)
            .with_backup(true)
            .with_state(ServerState::Alive);
        assert_eq!(server.weight, 5);
        assert_eq!(server.max_fails, 3);
        assert_eq!(server.fail_timeout, Duration::from_secs(30));
        assert!(server.backup);
    }

    #[test]
    fn builder_clamps_minimums() {
        let server = UpstreamServer::new(addr(8080))
            .with_weight(0)
            .with_max_fails(0);
        assert_eq!(server.weight, 1);
        assert_eq!(server.max_fails, 1);
    }

    #[test]
    fn pool_debug_format() {
        let pool = pool_with_servers(&[8001]);
        let rendered = format!("{:?}", pool);
        assert!(rendered.contains("test"));
        assert!(rendered.contains("server_count"));
    }

    #[test]
    #[should_panic(expected = "at least one server")]
    fn empty_pool_panics() {
        UpstreamPool::new("empty", Vec::new(), Algorithm::RoundRobin);
    }
}