Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / proxy / upstream.rs
//! Upstream pool management for reverse-proxy load balancing.
//!
//! Manages a named pool of backend servers with weighted load balancing,
//! passive health tracking (driven by request outcomes reported through
//! `mark_failed` / `mark_success`), and automatic fail-over to backup servers.
//!
//! All mutating operations are protected by [`parking_lot::RwLock`], making
//! [`UpstreamPool`] safe to share across threads (e.g. via `Arc`).

use std::net::SocketAddr;
use std::time::{Duration, Instant};

use parking_lot::RwLock;
use tracing::{debug, warn};

use super::load_balance::{create_balancer, Algorithm, LoadBalancer};

// ---------------------------------------------------------------------------
// ServerState
// ---------------------------------------------------------------------------

/// Health state of an upstream server.
///
/// A freshly constructed state (via [`Default`]) is [`ServerState::Alive`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerState {
    /// Healthy: the server may be picked for new connections.
    Alive,
    /// Administratively disabled: the server never receives traffic.
    Down,
    /// Temporarily excluded after reaching `max_fails` failures within the
    /// `fail_timeout` window.  Once `fail_timeout` has passed since the last
    /// failure the server becomes eligible for selection again, and the next
    /// successful request moves it back to [`ServerState::Alive`].
    Unavailable,
}

impl Default for ServerState {
    /// Servers start out healthy.
    fn default() -> Self {
        ServerState::Alive
    }
}

// ---------------------------------------------------------------------------
// UpstreamServer
// ---------------------------------------------------------------------------

/// A single backend server in an upstream pool.
///
/// Carries the server address, load-balancing weight, failure thresholds,
/// and current health state.  Per-server failure counters are managed
/// internally by [`UpstreamPool`] to ensure thread-safe access.
#[derive(Debug, Clone)]
pub struct UpstreamServer {
    /// Socket address of the backend server.
    pub addr: SocketAddr,
    /// Load-balancing weight.  Higher values receive proportionally more
    /// traffic.  Minimum 1 (enforced by [`UpstreamServer::with_weight`];
    /// the field itself is public and not validated).
    pub weight: u32,
    /// Number of failures within `fail_timeout` required to mark the server
    /// as [`ServerState::Unavailable`].  Minimum 1 (enforced by
    /// [`UpstreamServer::with_max_fails`]; the field itself is public and not
    /// validated).
    pub max_fails: u32,
    /// Window for failure counting and automatic recovery.
    pub fail_timeout: Duration,
    /// When `true`, this server is only used when all primary (non-backup)
    /// servers are unavailable.
    pub backup: bool,
    /// Current health state.
    pub state: ServerState,
}

impl UpstreamServer {
    /// Create a new server with sensible defaults:
    ///
    /// | Field          | Default    |
    /// |----------------|------------|
    /// | `weight`       | `1`        |
    /// | `max_fails`    | `1`        |
    /// | `fail_timeout` | 10 seconds |
    /// | `backup`       | `false`    |
    /// | `state`        | `Alive`    |
    #[must_use]
    pub fn new(addr: SocketAddr) -> Self {
        Self {
            addr,
            weight: 1,
            max_fails: 1,
            fail_timeout: Duration::from_secs(10),
            backup: false,
            state: ServerState::Alive,
        }
    }

    /// Builder: set the load-balancing weight (clamped to minimum 1).
    #[must_use]
    pub fn with_weight(mut self, weight: u32) -> Self {
        self.weight = weight.max(1);
        self
    }

    /// Builder: set the maximum failure count (clamped to minimum 1).
    #[must_use]
    pub fn with_max_fails(mut self, max_fails: u32) -> Self {
        self.max_fails = max_fails.max(1);
        self
    }

    /// Builder: set the fail-timeout duration.
    #[must_use]
    pub fn with_fail_timeout(mut self, timeout: Duration) -> Self {
        self.fail_timeout = timeout;
        self
    }

    /// Builder: mark this server as a backup.
    #[must_use]
    pub fn with_backup(mut self, backup: bool) -> Self {
        self.backup = backup;
        self
    }

    /// Builder: set the initial health state.
    #[must_use]
    pub fn with_state(mut self, state: ServerState) -> Self {
        self.state = state;
        self
    }

    /// Returns `true` if the server is not administratively disabled.
    ///
    /// Note that a temporarily [`ServerState::Unavailable`] server still
    /// counts as alive here; only [`ServerState::Down`] returns `false`.
    #[inline]
    pub fn is_alive(&self) -> bool {
        self.state != ServerState::Down
    }
}

// ---------------------------------------------------------------------------
// ServerTracker  (private)
// ---------------------------------------------------------------------------

/// Mutable per-server health tracking data.
///
/// Kept separate from [`UpstreamServer`] so the public struct can be
/// cheaply cloned / inspected without exposing internal counters.
#[derive(Debug)]
struct ServerTracker {
    /// Failures recorded since the last reset.  Reset to zero by a success,
    /// or before counting a new failure that arrives `fail_timeout` or more
    /// after the previous one.
    fails: u32,
    /// Timestamp of the most recent failure (`None` until the first failure,
    /// and again after a success).
    last_fail: Option<Instant>,
}

impl ServerTracker {
    /// Create a tracker with no recorded failures.
    fn new() -> Self {
        Self {
            fails: 0,
            last_fail: None,
        }
    }

    /// Effective failure threshold for `server`.
    ///
    /// `UpstreamServer::max_fails` is a public field, so it can be set to `0`
    /// without going through the clamping builder.  With a raw `0`,
    /// `fails < max_fails` can never hold and an `Alive` server would be
    /// excluded forever; treating `0` as `1` keeps it selectable and matches
    /// the documented minimum.
    fn threshold(server: &UpstreamServer) -> u32 {
        server.max_fails.max(1)
    }

    /// Returns `true` when the server should receive new connections.
    ///
    /// * `Down` — never.
    /// * `Alive` — while the failure count is below the threshold.
    /// * `Unavailable` — once `fail_timeout` has elapsed since the last
    ///   failure (or immediately if no failure has been recorded).
    fn is_available(&self, server: &UpstreamServer) -> bool {
        match server.state {
            ServerState::Down => false,
            ServerState::Alive => self.fails < Self::threshold(server),
            ServerState::Unavailable => match self.last_fail {
                Some(last) => last.elapsed() >= server.fail_timeout,
                None => true,
            },
        }
    }

    /// Record a failed request.
    ///
    /// If the failure count reaches `max_fails` (treated as at least 1) an
    /// `Alive` server transitions to [`ServerState::Unavailable`].  If the
    /// previous failure occurred outside the `fail_timeout` window the
    /// counter is reset first, so a recovered server does not immediately
    /// become unavailable again.
    fn record_failure(&mut self, server: &mut UpstreamServer) {
        // Reset counter when the fail-timeout has elapsed since the last
        // failure (treat as a fresh start).
        if let Some(last) = self.last_fail {
            if last.elapsed() >= server.fail_timeout {
                self.fails = 0;
            }
        }

        // Saturate so an unbroken failure streak can never overflow the
        // counter (which would panic in debug builds).
        self.fails = self.fails.saturating_add(1);
        self.last_fail = Some(Instant::now());

        if self.fails >= Self::threshold(server) && server.state == ServerState::Alive {
            server.state = ServerState::Unavailable;
            warn!(
                addr = %server.addr,
                fails = self.fails,
                max_fails = server.max_fails,
                fail_timeout = ?server.fail_timeout,
                "upstream server marked unavailable"
            );
        }
    }

    /// Record a successful request.
    ///
    /// Resets the failure counter and transitions an `Unavailable` server
    /// back to `Alive`.  A `Down` server keeps its state.
    fn record_success(&mut self, server: &mut UpstreamServer) {
        if self.fails > 0 || server.state == ServerState::Unavailable {
            debug!(
                addr = %server.addr,
                prev_fails = self.fails,
                "upstream server recovered"
            );
        }
        self.fails = 0;
        self.last_fail = None;
        if server.state == ServerState::Unavailable {
            server.state = ServerState::Alive;
        }
    }
}

// ---------------------------------------------------------------------------
// UpstreamPool
// ---------------------------------------------------------------------------

/// A named pool of upstream servers with load balancing and health tracking.
///
/// Concurrency: every server and its private failure tracker live together
/// behind a single [`parking_lot::RwLock`].  The boxed [`LoadBalancer`] sits
/// outside that lock and is called through `&self`; it is expected to
/// synchronize internally (atomics).  Share the pool between tasks with
/// `Arc<UpstreamPool>`.
///
/// # Example
///
/// ```ignore
/// use std::net::SocketAddr;
/// use std::sync::Arc;
///
/// let pool = Arc::new(
///     UpstreamPool::new(
///         "backend",
///         vec![
///             UpstreamServer::new("10.0.0.1:8080".parse().unwrap())
///                 .with_weight(3),
///             UpstreamServer::new("10.0.0.2:8080".parse().unwrap()),
///         ],
///         Algorithm::WeightedRoundRobin,
///     )
///     .with_keepalive(64),
/// );
///
/// // In a request handler:
/// if let Some(server) = pool.next() {
///     // proxy to server.addr
/// }
/// ```
pub struct UpstreamPool {
    /// Human-readable pool name, included in log messages.
    name: String,
    /// Each server paired with its health tracker; read-locked for
    /// selection/introspection, write-locked for health updates.
    servers: RwLock<Vec<(UpstreamServer, ServerTracker)>>,
    /// Selection strategy, created for the initial server count; called
    /// without holding the `servers` write lock.
    balancer: Box<dyn LoadBalancer>,
    /// Maximum keepalive connections per upstream server (defaults to 32).
    keepalive: usize,
}

impl UpstreamPool {
    /// Create a new upstream pool.
    ///
    /// Every server starts with a fresh failure tracker (no recorded
    /// failures), the load balancer is built for `servers.len()` servers,
    /// and the per-server keepalive defaults to `32`.
    ///
    /// # Panics
    /// Panics if `servers` is empty.
    pub fn new(
        name: impl Into<String>,
        servers: Vec<UpstreamServer>,
        algorithm: Algorithm,
    ) -> Self {
        assert!(
            !servers.is_empty(),
            "upstream pool requires at least one server"
        );

        let name = name.into();
        let count = servers.len();

        let entries: Vec<_> = servers
            .into_iter()
            .map(|s| (s, ServerTracker::new()))
            .collect();

        debug!(
            pool = %name,
            count = count,
            algorithm = ?algorithm,
            "upstream pool created"
        );

        Self {
            name,
            servers: RwLock::new(entries),
            balancer: create_balancer(algorithm, None, count),
            keepalive: 32,
        }
    }

    /// Set the number of keepalive connections per upstream server.
    #[must_use]
    pub fn with_keepalive(mut self, keepalive: usize) -> Self {
        self.keepalive = keepalive;
        self
    }

    /// Returns the pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the configured per-server keepalive count.
    pub fn keepalive(&self) -> usize {
        self.keepalive
    }

    // -- server selection -------------------------------------------------

    /// Select the next available server via the configured load balancer.
    ///
    /// Selection strategy:
    /// 1. Filter to primary (non-backup) servers that are currently available.
    /// 2. If no primary servers are available, fall back to backup servers.
    /// 3. If nothing is available at all, return `None`.
    ///
    /// The returned server is a clone; the internal lock is released before
    /// this method returns.
    pub fn next(&self) -> Option<UpstreamServer> {
        let servers = self.servers.read();

        // Try primary servers first.
        if let Some(server) = self.select_from_group(&servers, false) {
            return Some(server);
        }

        // Fall back to backups.
        if let Some(server) = self.select_from_group(&servers, true) {
            warn!(pool = %self.name, "all primary servers down, using backup");
            return Some(server);
        }

        warn!(pool = %self.name, "no upstream servers available");
        None
    }

    // -- health management ------------------------------------------------

    /// Record a failure for the server at `addr`.
    ///
    /// Increments the failure counter.  When the counter reaches `max_fails`
    /// the server transitions to [`ServerState::Unavailable`] and will not
    /// receive new connections until `fail_timeout` elapses.  Addresses not
    /// in the pool are ignored.
    pub fn mark_failed(&self, addr: &SocketAddr) {
        self.update_entry(addr, |tracker, server| tracker.record_failure(server));
    }

    /// Record a successful response from the server at `addr`.
    ///
    /// Resets the failure counter.  If the server was `Unavailable` it is
    /// restored to `Alive`.  Addresses not in the pool are ignored.
    pub fn mark_success(&self, addr: &SocketAddr) {
        self.update_entry(addr, |tracker, server| tracker.record_success(server));
    }

    // -- introspection ----------------------------------------------------

    /// Return clones of all currently available servers (primary and backup).
    pub fn active_servers(&self) -> Vec<UpstreamServer> {
        let servers = self.servers.read();
        servers
            .iter()
            .filter(|(s, t)| t.is_available(s))
            .map(|(s, _)| s.clone())
            .collect()
    }

    /// Return clones of every server in the pool regardless of state.
    pub fn all_servers(&self) -> Vec<UpstreamServer> {
        self.servers.read().iter().map(|(s, _)| s.clone()).collect()
    }

    /// Total number of servers in the pool.
    pub fn server_count(&self) -> usize {
        self.servers.read().len()
    }

    /// Number of servers (primary and backup) currently accepting connections.
    pub fn available_count(&self) -> usize {
        self.servers
            .read()
            .iter()
            .filter(|(s, t)| t.is_available(s))
            .count()
    }

    // -- private ----------------------------------------------------------

    /// Apply `update` to the tracker/server pair registered for `addr` while
    /// holding the write lock.  Does nothing if `addr` is not in the pool.
    fn update_entry<F>(&self, addr: &SocketAddr, update: F)
    where
        F: FnOnce(&mut ServerTracker, &mut UpstreamServer),
    {
        let mut servers = self.servers.write();
        if let Some((server, tracker)) = servers.iter_mut().find(|(s, _)| s.addr == *addr) {
            update(tracker, server);
        }
    }

    /// Pick one server from the primary (`backup == false`) or backup
    /// (`backup == true`) group using the load balancer.
    ///
    /// Only servers whose tracker reports them available are offered to the
    /// balancer; this already excludes `Down` servers.  The balancer returns
    /// an index into that candidate list, and the chosen server is returned
    /// as an owned value so the caller can release the lock.
    fn select_from_group(
        &self,
        servers: &[(UpstreamServer, ServerTracker)],
        backup: bool,
    ) -> Option<UpstreamServer> {
        let mut candidates: Vec<UpstreamServer> = servers
            .iter()
            .filter(|(s, t)| s.backup == backup && t.is_available(s))
            .map(|(s, _)| s.clone())
            .collect();

        if candidates.is_empty() {
            return None;
        }

        let selected = self.balancer.next_server(&candidates)?;
        if selected >= candidates.len() {
            // A well-behaved balancer never does this; report it and degrade
            // to "no server" instead of panicking on an out-of-range index.
            warn!(
                pool = %self.name,
                index = selected,
                candidates = candidates.len(),
                "load balancer returned out-of-range index"
            );
            return None;
        }

        // Move the chosen server out instead of cloning it a second time;
        // the remaining candidates are discarded.
        Some(candidates.swap_remove(selected))
    }
}

impl std::fmt::Debug for UpstreamPool {
    /// Summarize the pool (name, size, keepalive, algorithm) without dumping
    /// per-server details or tracker internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.server_count();
        let algorithm = self.balancer.algorithm();

        let mut out = f.debug_struct("UpstreamPool");
        out.field("name", &self.name);
        out.field("server_count", &total);
        out.field("keepalive", &self.keepalive);
        out.field("algorithm", &algorithm);
        out.finish()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    fn addr(port: u16) -> SocketAddr {
        SocketAddr::new([127, 0, 0, 1].into(), port)
    }

    fn pool_with_servers(addrs: &[u16]) -> UpstreamPool {
        let servers: Vec<_> = addrs
            .iter()
            .map(|&p| UpstreamServer::new(addr(p)))
            .collect();
        UpstreamPool::new("test", servers, Algorithm::RoundRobin)
    }

    #[test]
    fn next_returns_a_server() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        let server = pool.next().expect("should select a server");
        assert!(server.addr.port() >= 8001 && server.addr.port() <= 8003);
    }

    #[test]
    fn next_cycles_through_servers() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        let ports: Vec<u16> = (0..6)
            .filter_map(|_| pool.next().map(|s| s.addr.port()))
            .collect();
        // With round-robin over 3 servers we should see the pattern repeat.
        assert_eq!(ports[0], ports[3]);
        assert_eq!(ports[1], ports[4]);
        assert_eq!(ports[2], ports[5]);
    }

    #[test]
    fn mark_failed_disables_server_after_max_fails() {
        let pool = pool_with_servers(&[8001]);
        pool.mark_failed(&addr(8001));
        // max_fails defaults to 1, so the server should now be unavailable.
        assert!(pool.next().is_none());
    }

    #[test]
    fn server_stays_available_until_max_fails_reached() {
        let servers = vec![UpstreamServer::new(addr(8001)).with_max_fails(3)];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        pool.mark_failed(&addr(8001));
        pool.mark_failed(&addr(8001));
        assert_eq!(pool.available_count(), 1);

        pool.mark_failed(&addr(8001));
        assert_eq!(pool.available_count(), 0);
    }

    #[test]
    fn unavailable_server_eligible_again_after_fail_timeout() {
        let servers = vec![UpstreamServer::new(addr(8001)).with_fail_timeout(Duration::ZERO)];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        pool.mark_failed(&addr(8001));
        // The state flips to `Unavailable`, but a zero-length window has
        // already elapsed, so the server is selectable again.
        assert_eq!(pool.all_servers()[0].state, ServerState::Unavailable);
        assert!(pool.next().is_some());
    }

    #[test]
    fn mark_success_restores_server() {
        let pool = pool_with_servers(&[8001]);
        pool.mark_failed(&addr(8001));
        assert!(pool.next().is_none());

        pool.mark_success(&addr(8001));
        assert!(pool.next().is_some());
    }

    #[test]
    fn mark_success_does_not_revive_down_server() {
        let servers = vec![UpstreamServer::new(addr(8001)).with_state(ServerState::Down)];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        pool.mark_success(&addr(8001));
        assert_eq!(pool.all_servers()[0].state, ServerState::Down);
        assert!(pool.next().is_none());
    }

    #[test]
    fn backup_server_used_when_primary_down() {
        let servers = vec![
            UpstreamServer::new(addr(8001)),
            UpstreamServer::new(addr(8002)).with_backup(true),
        ];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        // Primary is available, should get primary.
        let s = pool.next().unwrap();
        assert_eq!(s.addr.port(), 8001);

        // Disable primary.
        pool.mark_failed(&addr(8001));
        let s = pool.next().unwrap();
        assert_eq!(s.addr.port(), 8002);
    }

    #[test]
    fn no_server_when_primary_and_backup_unavailable() {
        let servers = vec![
            UpstreamServer::new(addr(8001)),
            UpstreamServer::new(addr(8002)).with_backup(true),
        ];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        pool.mark_failed(&addr(8001));
        pool.mark_failed(&addr(8002));
        assert!(pool.next().is_none());
    }

    #[test]
    fn down_server_never_selected() {
        let servers = vec![
            UpstreamServer::new(addr(8001)).with_state(ServerState::Down),
            UpstreamServer::new(addr(8002)),
        ];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);
        for _ in 0..10 {
            let s = pool.next().unwrap();
            assert_eq!(s.addr.port(), 8002);
        }
    }

    #[test]
    fn active_servers_excludes_unavailable() {
        let pool = pool_with_servers(&[8001, 8002]);
        pool.mark_failed(&addr(8001));
        let active = pool.active_servers();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].addr.port(), 8002);
    }

    #[test]
    fn all_servers_includes_everything() {
        let pool = pool_with_servers(&[8001, 8002]);
        pool.mark_failed(&addr(8001));
        assert_eq!(pool.all_servers().len(), 2);
    }

    #[test]
    fn available_count_tracks_state() {
        let pool = pool_with_servers(&[8001, 8002, 8003]);
        assert_eq!(pool.available_count(), 3);
        pool.mark_failed(&addr(8002));
        assert_eq!(pool.available_count(), 2);
    }

    #[test]
    fn unknown_addr_is_noop() {
        let pool = pool_with_servers(&[8001]);
        // Should not panic.
        pool.mark_failed(&addr(9999));
        pool.mark_success(&addr(9999));
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn weighted_distribution() {
        let servers = vec![
            UpstreamServer::new(addr(8001)).with_weight(5),
            UpstreamServer::new(addr(8002)).with_weight(1),
        ];
        let pool = UpstreamPool::new("test", servers, Algorithm::RoundRobin);

        let mut counts = [0usize; 2];
        for _ in 0..600 {
            if let Some(s) = pool.next() {
                match s.addr.port() {
                    8001 => counts[0] += 1,
                    8002 => counts[1] += 1,
                    _ => {}
                }
            }
        }
        // Server 1 (weight 5) should get roughly 5x the traffic of server 2.
        assert!(
            counts[0] > counts[1] * 3,
            "expected heavy skew, got {:?}",
            counts
        );
    }

    #[test]
    fn builder_methods() {
        let server = UpstreamServer::new(addr(8080))
            .with_weight(5)
            .with_max_fails(3)
            .with_fail_timeout(Duration::from_secs(30))
            .with_backup(true)
            .with_state(ServerState::Alive);

        assert_eq!(server.weight, 5);
        assert_eq!(server.max_fails, 3);
        assert_eq!(server.fail_timeout, Duration::from_secs(30));
        assert!(server.backup);
        assert_eq!(server.state, ServerState::Alive);
    }

    #[test]
    fn builder_clamps_minimums() {
        let server = UpstreamServer::new(addr(8080))
            .with_weight(0)
            .with_max_fails(0);
        assert_eq!(server.weight, 1);
        assert_eq!(server.max_fails, 1);
    }

    #[test]
    fn pool_debug_format() {
        let pool = pool_with_servers(&[8001]);
        let dbg = format!("{:?}", pool);
        assert!(dbg.contains("test"));
        assert!(dbg.contains("server_count"));
    }

    #[test]
    #[should_panic(expected = "at least one server")]
    fn empty_pool_panics() {
        UpstreamPool::new("empty", vec![], Algorithm::RoundRobin);
    }
}