Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / proxy / load_balance.rs
//! Load-balancing algorithms for upstream server selection.
//!
//! Provides a [`LoadBalancer`] trait and three concrete implementations:
//! [`RoundRobin`], [`LeastConnections`], and [`IpHash`].
//!
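//! All implementations skip servers where `is_alive()` returns `false`.
//! [`RoundRobin`] and [`IpHash`] distribute traffic in proportion to server
//! weight; [`LeastConnections`] uses weight only to break ties between
//! servers with equal connection counts.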

use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::sync::atomic::{AtomicUsize, Ordering};

use super::upstream::UpstreamServer;

// ---------------------------------------------------------------------------
// Algorithm
// ---------------------------------------------------------------------------

/// Strategy used to choose an upstream server for each request.
///
/// The default strategy is [`Algorithm::RoundRobin`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Algorithm {
    /// Weighted round-robin: alive servers are picked in proportion to
    /// their weight as the rotation advances.
    RoundRobin,
    /// Prefers the alive server with the fewest active connections.
    /// Equal counts are resolved in favour of the higher weight.
    LeastConnections,
    /// Hashes the client IP so that a given client keeps landing on the
    /// same server.
    IpHash,
}

impl Default for Algorithm {
    /// Returns [`Algorithm::RoundRobin`].
    fn default() -> Algorithm {
        Algorithm::RoundRobin
    }
}

// ---------------------------------------------------------------------------
// LoadBalancer trait
// ---------------------------------------------------------------------------

/// Trait for load-balancing algorithms.
///
/// The trait is bounded by `Send + Sync` and both methods take `&self`, so
/// a balancer can be shared between threads without exclusive access.
///
/// Implementations must:
/// - Skip servers where `server.is_alive()` returns `false`
/// - Take server weights into account when choosing between alive servers
/// - Use lock-free internals for thread-safe operation
pub trait LoadBalancer: Send + Sync {
    /// Select the next server from the candidate list.
    ///
    /// `servers` is the full candidate slice, including servers that are
    /// currently down; filtering is the implementation's job.
    ///
    /// Returns the index into `servers` for the selected server,
    /// or `None` if no servers are alive.  The weight-proportional
    /// implementations in this file also return `None` when the weights of
    /// the alive servers sum to zero.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize>;

    /// Returns the [`Algorithm`] variant this balancer implements.
    fn algorithm(&self) -> Algorithm;
}

// ---------------------------------------------------------------------------
// RoundRobin
// ---------------------------------------------------------------------------

/// Weighted round-robin load balancer.
///
/// The only state is a pick counter.  Each selection maps the counter onto
/// the cumulative weights of the alive servers, so a server with weight 3
/// receives three times the picks of a server with weight 1.  Within one
/// rotation the picks for a server are consecutive (weights 3 and 1 yield
/// `0, 0, 0, 1`), not interleaved.
///
/// Thread-safe: the counter is an atomic integer.
#[derive(Default)]
pub struct RoundRobin {
    /// Number of selections handed out so far (wraps on overflow).
    cursor: AtomicUsize,
}

impl RoundRobin {
    /// Create a balancer whose rotation starts at position zero.
    pub fn new() -> Self {
        // `AtomicUsize::default()` is zero, so the derived `Default` is the
        // canonical constructor.
        Self::default()
    }
}

impl LoadBalancer for RoundRobin {
    /// Pick the next alive server in weighted rotation.
    ///
    /// The shared cursor is advanced by one and reduced modulo the summed
    /// weight of the alive servers; the first alive server whose cumulative
    /// weight exceeds that value is returned.
    ///
    /// Returns `None` when `servers` is empty, when every server is down, or
    /// when the weights of the alive servers sum to zero.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        // Accumulate in u64: a sum of u32 weights can exceed `u32::MAX`,
        // which would panic in debug builds and silently wrap in release.
        let total: u64 = servers
            .iter()
            .filter(|s| s.is_alive())
            .map(|s| u64::from(s.weight))
            .sum();

        if total == 0 {
            return None;
        }

        // `fetch_add` wraps on overflow, so advancing the cursor never panics.
        let pos = self.cursor.fetch_add(1, Ordering::Relaxed);
        // Reduce the full-width cursor.  Truncating it to u32 first makes the
        // rotation skip or repeat a slot every 2^32 picks whenever `total`
        // does not divide 2^32.  Widening usize to u64 is lossless.
        let target = (pos as u64) % total;
        let mut cumulative = 0u64;

        for (i, server) in servers.iter().enumerate() {
            if !server.is_alive() {
                continue;
            }
            cumulative += u64::from(server.weight);
            if target < cumulative {
                return Some(i);
            }
        }

        // Not reached while the alive set is the same in both passes.
        // NOTE(review): presumably `is_alive` can flip concurrently between
        // the two scans; in that case fall back to the last alive server.
        servers.iter().rposition(|s| s.is_alive())
    }

    /// Returns [`Algorithm::RoundRobin`].
    fn algorithm(&self) -> Algorithm {
        Algorithm::RoundRobin
    }
}

// ---------------------------------------------------------------------------
// LeastConnections
// ---------------------------------------------------------------------------

/// Least-connections load balancer with weight tie-breaking.
///
/// Keeps one counter of active connections per server index.  Selection
/// favours the server with the smallest counter; among servers with the
/// same count, the one with the highest weight wins.
///
/// The counters are driven by the caller: invoke
/// [`LeastConnections::increment`] after assigning a request and
/// [`LeastConnections::decrement`] once the upstream connection closes.
/// An index that has no counter (the balancer was built for fewer servers
/// than the slice holds) is reported as having 0 connections.
///
/// Thread-safe: connection counts are atomic integers.
pub struct LeastConnections {
    /// Active connection count, indexed by server position.
    active: Vec<AtomicUsize>,
}

impl LeastConnections {
    /// Create a balancer holding `server_count` counters, all at zero.
    pub fn new(server_count: usize) -> Self {
        let active = std::iter::repeat_with(|| AtomicUsize::new(0))
            .take(server_count)
            .collect();
        Self { active }
    }

    /// Counter for `index`, or `None` when the index has no slot.
    fn slot(&self, index: usize) -> Option<&AtomicUsize> {
        self.active.get(index)
    }

    /// Record a new connection to the server at `index`.
    ///
    /// Does nothing when `index` has no counter.
    pub fn increment(&self, index: usize) {
        if let Some(counter) = self.slot(index) {
            counter.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Record a closed connection to the server at `index`.
    ///
    /// The count stops at zero instead of wrapping.  Does nothing when
    /// `index` has no counter.
    pub fn decrement(&self, index: usize) {
        if let Some(counter) = self.slot(index) {
            // The closure always yields `Some`, so the update cannot fail
            // and the returned `Result` carries nothing worth inspecting.
            let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(1))
            });
        }
    }

    /// Return the active connection count for the server at `index`
    /// (0 when the index has no counter).
    pub fn connections(&self, index: usize) -> usize {
        self.slot(index)
            .map_or(0, |counter| counter.load(Ordering::Relaxed))
    }
}

impl LoadBalancer for LeastConnections {
    /// Pick the alive server with the fewest active connections.
    ///
    /// Servers with equal counts are ranked by weight, highest first; if
    /// the weights are equal too, the server that appears first in
    /// `servers` is chosen.  Returns `None` when no server is alive.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        servers
            .iter()
            .enumerate()
            .filter(|(_, server)| server.is_alive())
            // Order by connection count, then by descending weight
            // (`Reverse` makes the larger weight compare as smaller).
            // `min_by_key` keeps the first of several equal minima, so the
            // earliest index wins a full tie.
            .min_by_key(|&(index, server)| {
                (self.connections(index), std::cmp::Reverse(server.weight))
            })
            .map(|(index, _)| index)
    }

    /// Returns [`Algorithm::LeastConnections`].
    fn algorithm(&self) -> Algorithm {
        Algorithm::LeastConnections
    }
}

// ---------------------------------------------------------------------------
// IpHash
// ---------------------------------------------------------------------------

/// IP-hash load balancer for consistent client-to-server affinity.
///
/// The balancer stores a single client address and hashes it on every
/// selection, so the same address keeps mapping to the same backend while
/// the server set stays stable.  Weighted: a server with a higher weight
/// owns a proportionally larger share of the hash space.
///
/// Supply the address at construction time, or replace it with
/// [`IpHash::set_client_ip`] before a selection.
///
/// Selection only reads the stored address through `&self`.  Replacing the
/// address needs `&mut self`, so a balancer shared across threads has to be
/// wrapped in a `Mutex` or `RwLock` for updates.
pub struct IpHash {
    /// Address whose hash decides the selected server.
    client_ip: IpAddr,
}

impl IpHash {
    /// Create a balancer bound to `client_ip`.
    pub fn new(client_ip: IpAddr) -> Self {
        IpHash { client_ip }
    }

    /// Replace the bound client address; later selections hash `ip`.
    pub fn set_client_ip(&mut self, ip: IpAddr) {
        self.client_ip = ip;
    }

    /// Return the client address currently bound to this balancer.
    pub fn client_ip(&self) -> IpAddr {
        // `IpAddr` is `Copy`, so this hands out a copy of the field.
        self.client_ip
    }
}

impl LoadBalancer for IpHash {
    /// Pick the alive server that owns the bound client IP's hash slot.
    ///
    /// The client IP is hashed, reduced modulo the summed weight of the
    /// alive servers, and matched against their cumulative weights in slice
    /// order.  The result is stable for a given IP while the alive set and
    /// weights stay the same.
    ///
    /// Returns `None` when `servers` is empty, when every server is down, or
    /// when the weights of the alive servers sum to zero.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        // Accumulate in u64: a sum of u32 weights can exceed `u32::MAX`,
        // which would panic in debug builds and silently wrap in release.
        let total: u64 = servers
            .iter()
            .filter(|s| s.is_alive())
            .map(|s| u64::from(s.weight))
            .sum();

        if total == 0 {
            return None;
        }

        // `DefaultHasher::new()` is deterministic within one build, but the
        // standard library does not promise the same algorithm across Rust
        // releases, so the IP-to-server mapping may shift after a toolchain
        // upgrade.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.client_ip.hash(&mut hasher);
        let hash = hasher.finish();

        // Keep the historical mapping (low 32 bits of the hash) whenever the
        // weight sum fits in u32.  Only a larger sum needs the full 64-bit
        // hash, because a 32-bit value could never reach the upper slots.
        let target = if total <= u64::from(u32::MAX) {
            u64::from(hash as u32) % total
        } else {
            hash % total
        };

        let mut cumulative = 0u64;
        for (i, server) in servers.iter().enumerate() {
            if !server.is_alive() {
                continue;
            }
            cumulative += u64::from(server.weight);
            if target < cumulative {
                return Some(i);
            }
        }

        // Not reached while the alive set is the same in both passes.
        // NOTE(review): presumably `is_alive` can flip concurrently between
        // the two scans; in that case fall back to the last alive server.
        servers.iter().rposition(|s| s.is_alive())
    }

    /// Returns [`Algorithm::IpHash`].
    fn algorithm(&self) -> Algorithm {
        Algorithm::IpHash
    }
}

// ---------------------------------------------------------------------------
// Factory
// ---------------------------------------------------------------------------

/// Build the boxed [`LoadBalancer`] that implements `algorithm`.
///
/// Only the arguments relevant to the chosen algorithm are used:
///
/// - `RoundRobin` ignores both `client_ip` and `server_count`.
/// - `LeastConnections` allocates one counter for each of `server_count`
///   servers.
/// - `IpHash` binds to `client_ip`, or to `0.0.0.0` when it is `None`.
pub fn create_balancer(
    algorithm: Algorithm,
    client_ip: Option<IpAddr>,
    server_count: usize,
) -> Box<dyn LoadBalancer> {
    /// Address bound to an `IpHash` balancer when the caller supplies none.
    const FALLBACK_IP: IpAddr = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED);

    match algorithm {
        Algorithm::IpHash => {
            let bound_ip = match client_ip {
                Some(ip) => ip,
                None => FALLBACK_IP,
            };
            Box::new(IpHash::new(bound_ip))
        }
        Algorithm::LeastConnections => Box::new(LeastConnections::new(server_count)),
        Algorithm::RoundRobin => Box::new(RoundRobin::new()),
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    // Unit tests for the balancers and the factory.  All fixtures are
    // `UpstreamServer` values on 127.0.0.1 built by the helpers below.
    use super::super::upstream::ServerState;
    use super::*;
    use std::net::SocketAddr;

    // -- helpers -------------------------------------------------------------

    /// Upstream on 127.0.0.1:`port` built with `UpstreamServer::new` only.
    ///
    /// The tests below rely on such a server being alive with weight 1 —
    /// presumably the constructor's defaults; confirm in `upstream`.
    fn server(port: u16) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port))
    }

    /// Upstream on 127.0.0.1:`port` with an explicit `weight`.
    fn weighted(port: u16, weight: u32) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port)).with_weight(weight)
    }

    /// Upstream on 127.0.0.1:`port` put into `ServerState::Down`.
    fn down(port: u16) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port))
            .with_state(ServerState::Down)
    }

    // -- RoundRobin ----------------------------------------------------------

    #[test]
    fn rr_empty_returns_none() {
        let rr = RoundRobin::new();
        assert_eq!(rr.next_server(&[]), None);
    }

    #[test]
    fn rr_single_server_always_same() {
        let rr = RoundRobin::new();
        let servers = [server(8001)];
        for _ in 0..5 {
            assert_eq!(rr.next_server(&servers), Some(0));
        }
    }

    #[test]
    fn rr_cycles_through_servers() {
        let rr = RoundRobin::new();
        let servers = [server(8001), server(8002), server(8003)];
        // Two full rotations: picks follow slice order and wrap around.
        for expected in 0..6 {
            assert_eq!(rr.next_server(&servers), Some(expected % 3));
        }
    }

    #[test]
    fn rr_weighted_distribution() {
        let rr = RoundRobin::new();
        let servers = [weighted(8001, 3), server(8002)];
        let mut counts = [0usize; 2];
        for _ in 0..400 {
            if let Some(idx) = rr.next_server(&servers) {
                counts[idx] += 1;
            }
        }
        // Server 0 (weight 3) should get ~3x the traffic of server 1.
        // The assertion is deliberately looser: it only requires more than 2x.
        assert!(counts[0] > counts[1] * 2, "counts: {:?}", counts);
    }

    #[test]
    fn rr_skips_down_servers() {
        let rr = RoundRobin::new();
        let servers = [down(8001), server(8002), down(8003)];
        // Index 1 is the only alive server, so every pick must land on it.
        for _ in 0..5 {
            assert_eq!(rr.next_server(&servers), Some(1));
        }
    }

    #[test]
    fn rr_all_down_returns_none() {
        let rr = RoundRobin::new();
        let servers = [down(8001), down(8002)];
        assert_eq!(rr.next_server(&servers), None);
    }

    #[test]
    fn rr_algorithm_variant() {
        assert_eq!(RoundRobin::new().algorithm(), Algorithm::RoundRobin);
    }

    // -- LeastConnections ----------------------------------------------------

    #[test]
    fn lc_picks_fewest_connections() {
        let lc = LeastConnections::new(3);
        let servers = [server(8001), server(8002), server(8003)];
        lc.increment(0);
        lc.increment(0);
        lc.increment(2);
        // Server 0: 2 conns, server 1: 0, server 2: 1 -> picks 1
        assert_eq!(lc.next_server(&servers), Some(1));
    }

    #[test]
    fn lc_tie_breaks_by_weight() {
        let lc = LeastConnections::new(3);
        let servers = [weighted(8001, 1), weighted(8002, 5), weighted(8003, 1)];
        // All have 0 connections; server 1 has highest weight.
        assert_eq!(lc.next_server(&servers), Some(1));
    }

    #[test]
    fn lc_skips_down_servers() {
        let lc = LeastConnections::new(3);
        let servers = [server(8001), down(8002), server(8003)];
        lc.increment(0);
        // Server 0: 1 conn, server 1: down, server 2: 0 conns -> picks 2
        assert_eq!(lc.next_server(&servers), Some(2));
    }

    #[test]
    fn lc_increment_decrement_tracking() {
        let lc = LeastConnections::new(2);
        lc.increment(0);
        lc.increment(0);
        assert_eq!(lc.connections(0), 2);
        assert_eq!(lc.connections(1), 0);
        lc.decrement(0);
        assert_eq!(lc.connections(0), 1);
    }

    #[test]
    fn lc_decrement_saturates_at_zero() {
        let lc = LeastConnections::new(1);
        // Decrementing a counter that is already zero must leave it at zero.
        lc.decrement(0);
        assert_eq!(lc.connections(0), 0);
    }

    #[test]
    fn lc_out_of_bounds_is_noop() {
        let lc = LeastConnections::new(2);
        // Index 99 has no counter slot: updates are ignored and reads give 0.
        lc.increment(99);
        lc.decrement(99);
        assert_eq!(lc.connections(99), 0);
    }

    #[test]
    fn lc_algorithm_variant() {
        assert_eq!(
            LeastConnections::new(1).algorithm(),
            Algorithm::LeastConnections
        );
    }

    // -- IpHash --------------------------------------------------------------

    #[test]
    fn ip_hash_deterministic() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        let h = IpHash::new(ip);
        let first = h.next_server(&servers);

        // Fresh balancers bound to the same IP must agree with the first pick.
        for _ in 0..20 {
            let h2 = IpHash::new(ip);
            assert_eq!(h2.next_server(&servers), first);
        }
    }

    #[test]
    fn ip_hash_different_ips_distribute() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let mut seen = std::collections::HashSet::new();
        for i in 0..256u16 {
            let ip: IpAddr = format!("10.{}.0.1", i).parse().unwrap();
            let h = IpHash::new(ip);
            if let Some(idx) = h.next_server(&servers) {
                seen.insert(idx);
            }
        }
        // 256 IPs over 4 servers should hit all 4.
        assert_eq!(seen.len(), 4, "only hit {:?}", seen);
    }

    #[test]
    fn ip_hash_skips_down_servers() {
        let servers = [down(8001), server(8002), down(8003)];
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let h = IpHash::new(ip);
        assert_eq!(h.next_server(&servers), Some(1));
    }

    #[test]
    fn ip_hash_weighted_distribution() {
        let servers = [weighted(8001, 10), weighted(8002, 1)];
        let mut hits = [0usize; 2];
        // 1100 distinct addresses: 10.0.0.1 through 10.4.75.1.
        for i in 0..1100u16 {
            let ip: IpAddr = format!("10.{}.{}.1", i / 256, i % 256).parse().unwrap();
            let h = IpHash::new(ip);
            if let Some(idx) = h.next_server(&servers) {
                hits[idx] += 1;
            }
        }
        // Server 0 (weight 10) should dominate.
        assert!(hits[0] > hits[1] * 5, "hits: {:?}", hits);
    }

    #[test]
    fn ip_hash_set_client_ip() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let mut h = IpHash::new("10.0.0.1".parse().unwrap());
        let first = h.next_server(&servers);

        // Changing the IP should (likely) change the result.
        // The loop stops at the first address that maps to a different server.
        let mut changed = false;
        for i in 2..50u8 {
            h.set_client_ip(format!("10.0.0.{}", i).parse().unwrap());
            if h.next_server(&servers) != first {
                changed = true;
                break;
            }
        }
        assert!(changed, "set_client_ip had no effect over 48 IPs");
    }

    #[test]
    fn ip_hash_all_down_returns_none() {
        let servers = [down(8001), down(8002)];
        let h = IpHash::new("10.0.0.1".parse().unwrap());
        assert_eq!(h.next_server(&servers), None);
    }

    #[test]
    fn ip_hash_algorithm_variant() {
        let h = IpHash::new("10.0.0.1".parse().unwrap());
        assert_eq!(h.algorithm(), Algorithm::IpHash);
    }

    // -- Factory -------------------------------------------------------------

    #[test]
    fn create_balancer_round_robin() {
        let balancer = create_balancer(Algorithm::RoundRobin, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::RoundRobin);
    }

    #[test]
    fn create_balancer_least_conn() {
        let balancer = create_balancer(Algorithm::LeastConnections, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::LeastConnections);
    }

    #[test]
    fn create_balancer_ip_hash() {
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let balancer = create_balancer(Algorithm::IpHash, Some(ip), 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::IpHash);
    }

    #[test]
    fn create_balancer_ip_hash_default_ip() {
        // No client IP supplied: the factory still has to produce a usable
        // balancer (it binds the unspecified IPv4 address).
        let balancer = create_balancer(Algorithm::IpHash, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
    }

    // -- Algorithm -----------------------------------------------------------

    #[test]
    fn algorithm_default_is_round_robin() {
        assert_eq!(Algorithm::default(), Algorithm::RoundRobin);
    }

    #[test]
    fn algorithm_debug_format() {
        assert_eq!(format!("{:?}", Algorithm::RoundRobin), "RoundRobin");
        assert_eq!(
            format!("{:?}", Algorithm::LeastConnections),
            "LeastConnections"
        );
        assert_eq!(format!("{:?}", Algorithm::IpHash), "IpHash");
    }

    #[test]
    fn algorithm_clone_copy() {
        // `b` and `c` are both implicit copies; `a` staying usable afterwards
        // is what exercises `Copy`.  `clone()` is never called explicitly.
        let a = Algorithm::IpHash;
        let b = a;
        let c = a;
        assert_eq!(a, b);
        assert_eq!(a, c);
    }
}