//! Load-balancing algorithms for upstream server selection.
//!
//! Provides a [`LoadBalancer`] trait and three concrete implementations:
//! [`RoundRobin`], [`LeastConnections`], and [`IpHash`].
//!
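//! All implementations skip servers where `is_alive()` returns `false`.
//! [`RoundRobin`] and [`IpHash`] distribute traffic in proportion to server
//! weight; [`LeastConnections`] uses weight only to break ties between
//! servers with equal connection counts.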
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::upstream::UpstreamServer;
// ---------------------------------------------------------------------------
// Algorithm
// ---------------------------------------------------------------------------
/// Strategy used to choose an upstream server for each request.
///
/// Defaults to [`Algorithm::RoundRobin`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Algorithm {
    /// Weighted round-robin: live servers are chosen in proportion to weight.
    RoundRobin,
    /// Fewest active connections wins; among equally loaded servers the one
    /// with the larger weight is preferred.
    LeastConnections,
    /// Hash of the client IP, so a given client keeps hitting the same server.
    IpHash,
}

impl Default for Algorithm {
    /// Returns [`Algorithm::RoundRobin`].
    fn default() -> Self {
        Algorithm::RoundRobin
    }
}
// ---------------------------------------------------------------------------
// LoadBalancer trait
// ---------------------------------------------------------------------------
/// Common interface for upstream-selection strategies.
///
/// Every implementation is expected to:
/// - never select a server whose `is_alive()` returns `false`;
/// - take server weights into account when choosing;
/// - rely on lock-free internals, so one balancer can be shared across
///   threads (hence the `Send + Sync` supertraits).
pub trait LoadBalancer: Send + Sync {
    /// Picks a server from `servers` and returns its index in that slice.
    ///
    /// Returns `None` when no candidate is alive (including an empty slice).
    /// Weight-proportional implementations also return `None` when every
    /// live server has weight 0.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize>;

    /// Reports which [`Algorithm`] this balancer implements.
    fn algorithm(&self) -> Algorithm;
}
// ---------------------------------------------------------------------------
// RoundRobin
// ---------------------------------------------------------------------------
/// Weighted round-robin load balancer.
///
/// Each live server owns a run of `weight` consecutive slots in a repeating
/// cycle. With live weights `[3, 1]` the selection sequence is
/// `0, 0, 0, 1, 0, 0, 0, 1, …`, so over a full cycle a server with weight 3
/// receives three times the traffic of a server with weight 1.
///
/// Note: this is plain (block) weighted round-robin, not nginx's "smooth"
/// variant. A heavy server receives its share in consecutive bursts rather
/// than interleaved with the other servers.
///
/// Thread-safe: the cursor is an atomic counter shared by all callers.
#[derive(Debug)]
pub struct RoundRobin {
    /// Position in the weighted cycle; advanced (wrapping) once per selection.
    cursor: AtomicUsize,
}

impl RoundRobin {
    /// Creates a balancer whose cycle starts at the first live server with a
    /// non-zero weight.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cursor: AtomicUsize::new(0),
        }
    }
}

impl Default for RoundRobin {
    /// Equivalent to [`RoundRobin::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl LoadBalancer for RoundRobin {
    /// Returns the next live server in the weighted cycle.
    ///
    /// Every call advances the shared cursor by one. The cursor, reduced
    /// modulo the total weight of the live servers, falls into exactly one
    /// live server's block of `weight` consecutive slots. Down servers, and
    /// live servers with weight 0, are skipped by the cycle. Returns `None`
    /// when the total live weight is 0 (no live servers, or all weight 0).
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        // Accumulate in u64 so that several large `u32` weights cannot
        // overflow (a panic in debug builds, a silent wrap in release).
        let total: u64 = servers
            .iter()
            .filter(|s| s.is_alive())
            .map(|s| u64::from(s.weight))
            .sum();
        if total == 0 {
            return None;
        }

        // Relaxed suffices: each caller only needs its own cursor value; the
        // counter does not publish any other memory.
        let pos = self.cursor.fetch_add(1, Ordering::Relaxed);
        // `usize -> u64` is lossless on 32- and 64-bit targets. Truncating to
        // u32 instead would skip or repeat slots every time the cursor
        // crossed a multiple of 2^32.
        let target = (pos as u64) % total;

        let mut cumulative = 0u64;
        for (i, server) in servers.iter().enumerate() {
            if !server.is_alive() {
                continue;
            }
            cumulative += u64::from(server.weight);
            if target < cumulative {
                return Some(i);
            }
        }

        // Reachable only if `is_alive()` results changed between the two
        // passes above (the second pass then sums to less than `total`).
        // Fall back to the last live server, or `None` if none remain.
        servers.iter().rposition(|s| s.is_alive())
    }

    fn algorithm(&self) -> Algorithm {
        Algorithm::RoundRobin
    }
}
// ---------------------------------------------------------------------------
// LeastConnections
// ---------------------------------------------------------------------------
/// Least-connections load balancer with weight tie-breaking.
///
/// Picks the live server with the fewest active connections. When several
/// servers share the lowest count, the one with the highest weight wins.
///
/// Selection does not touch the counters itself. Call
/// [`increment`](LeastConnections::increment) after assigning a request and
/// [`decrement`](LeastConnections::decrement) once the upstream connection
/// closes.
///
/// Counters exist only for indices `0..server_count` passed to
/// [`new`](LeastConnections::new). Indices outside that range always report 0
/// connections, and `increment`/`decrement` ignore them. Such servers
/// therefore look permanently idle and keep being preferred over busier
/// tracked servers, so size the balancer to match the server list.
///
/// Thread-safe: connection counts are atomic integers.
#[derive(Debug)]
pub struct LeastConnections {
    /// Active connection count per server index.
    active: Vec<AtomicUsize>,
}

impl LeastConnections {
    /// Creates a balancer tracking `server_count` servers, each starting at
    /// zero active connections.
    #[must_use]
    pub fn new(server_count: usize) -> Self {
        Self {
            active: (0..server_count).map(|_| AtomicUsize::new(0)).collect(),
        }
    }

    /// Records a new connection to the server at `index`.
    ///
    /// Out-of-range indices are ignored.
    pub fn increment(&self, index: usize) {
        if let Some(slot) = self.active.get(index) {
            slot.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Records a closed connection to the server at `index`.
    ///
    /// The count saturates at zero, so an unmatched `decrement` cannot
    /// underflow it. Out-of-range indices are ignored.
    pub fn decrement(&self, index: usize) {
        if let Some(slot) = self.active.get(index) {
            // The closure always returns `Some`, so `fetch_update` cannot
            // fail. Its `Result` only carries the previous value, which is
            // not needed here.
            let _ = slot.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                Some(v.saturating_sub(1))
            });
        }
    }

    /// Returns the active connection count for the server at `index`, or 0
    /// if `index` is not tracked.
    pub fn connections(&self, index: usize) -> usize {
        self.active
            .get(index)
            .map_or(0, |slot| slot.load(Ordering::Relaxed))
    }
}
impl LoadBalancer for LeastConnections {
    /// Returns the live server with the fewest active connections.
    ///
    /// Equal connection counts are resolved in favour of the higher weight.
    /// Exact ties (same count and same weight) go to the lowest index.
    /// Returns `None` if no server is alive.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        servers
            .iter()
            .enumerate()
            .filter(|(_, server)| server.is_alive())
            // `min_by_key` keeps the first of several equal minima, so the
            // lowest index wins an exact tie; `Reverse` ranks heavier
            // servers ahead of lighter ones at the same connection count.
            .min_by_key(|&(idx, server)| {
                (self.connections(idx), std::cmp::Reverse(server.weight))
            })
            .map(|(idx, _)| idx)
    }

    fn algorithm(&self) -> Algorithm {
        Algorithm::LeastConnections
    }
}
// ---------------------------------------------------------------------------
// IpHash
// ---------------------------------------------------------------------------
/// IP-hash load balancer for client-to-server affinity.
///
/// Hashes the bound client IP to pick a server deterministically. The same
/// IP maps to the same backend as long as the set of live servers and their
/// weights stay the same. Servers with a higher weight own a proportionally
/// larger share of the hash space.
///
/// Bind the client IP at construction with [`new`](IpHash::new), or later
/// with [`set_client_ip`](IpHash::set_client_ip).
///
/// Thread-safety: selection only reads the stored IP, so a shared `&IpHash`
/// can be used from several threads. Changing the IP requires `&mut self`.
/// If a single instance is shared and re-bound per request, protect it with a
/// `Mutex` or `RwLock`, or create one `IpHash` per request instead.
#[derive(Debug, Clone)]
pub struct IpHash {
    client_ip: IpAddr,
}

impl IpHash {
    /// Creates a balancer bound to `client_ip`.
    #[must_use]
    pub fn new(client_ip: IpAddr) -> Self {
        Self { client_ip }
    }

    /// Re-binds the balancer to `ip`; subsequent selections hash this address.
    pub fn set_client_ip(&mut self, ip: IpAddr) {
        self.client_ip = ip;
    }

    /// Returns the client IP currently used for hashing.
    #[must_use]
    pub fn client_ip(&self) -> IpAddr {
        self.client_ip
    }
}
impl LoadBalancer for IpHash {
    /// Maps the bound client IP to a live server, weighted by server weight.
    ///
    /// The IP's hash, reduced modulo the total weight of the live servers,
    /// selects one live server's block of `weight` consecutive slots. The same
    /// IP therefore keeps hitting the same server while the live set and the
    /// weights are unchanged. Returns `None` when the total live weight is 0.
    ///
    /// NOTE: the hash comes from `DefaultHasher`, whose algorithm is
    /// unspecified across Rust releases. The bytes std types feed to a
    /// `Hasher` are not promised to be stable between compiler versions
    /// either. The mapping is stable for a given build, but affinity may be
    /// reshuffled after a toolchain upgrade, and proxies built with different
    /// toolchains may disagree. If cross-version stability matters, switch to
    /// a fixed hash over the address octets.
    fn next_server(&self, servers: &[UpstreamServer]) -> Option<usize> {
        // Accumulate in u64 so several large `u32` weights cannot overflow
        // (a panic in debug builds, a silent wrap in release).
        let total: u64 = servers
            .iter()
            .filter(|s| s.is_alive())
            .map(|s| u64::from(s.weight))
            .sum();
        if total == 0 {
            return None;
        }

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.client_ip.hash(&mut hasher);
        // Reduce the full 64-bit hash. Truncating to u32 first would leave
        // every slot above u32::MAX unreachable now that `total` is a u64.
        let target = hasher.finish() % total;

        let mut cumulative = 0u64;
        for (i, server) in servers.iter().enumerate() {
            if !server.is_alive() {
                continue;
            }
            cumulative += u64::from(server.weight);
            if target < cumulative {
                return Some(i);
            }
        }

        // Reachable only if `is_alive()` results changed between the two
        // passes above (the second pass then sums to less than `total`).
        // Fall back to the last live server, or `None` if none remain.
        servers.iter().rposition(|s| s.is_alive())
    }

    fn algorithm(&self) -> Algorithm {
        Algorithm::IpHash
    }
}
// ---------------------------------------------------------------------------
// Factory
// ---------------------------------------------------------------------------
/// Builds a boxed [`LoadBalancer`] implementing `algorithm`.
///
/// - [`Algorithm::RoundRobin`]: needs no configuration; `client_ip` and
///   `server_count` are ignored.
/// - [`Algorithm::LeastConnections`]: allocates connection counters for
///   `server_count` servers.
/// - [`Algorithm::IpHash`]: binds to `client_ip`, or to `0.0.0.0` when it is
///   `None`. Every request built with `None` therefore hashes identically
///   and lands on the same server.
pub fn create_balancer(
    algorithm: Algorithm,
    client_ip: Option<IpAddr>,
    server_count: usize,
) -> Box<dyn LoadBalancer> {
    match algorithm {
        Algorithm::IpHash => {
            let unspecified = IpAddr::from(std::net::Ipv4Addr::UNSPECIFIED);
            Box::new(IpHash::new(client_ip.unwrap_or(unspecified)))
        }
        Algorithm::LeastConnections => Box::new(LeastConnections::new(server_count)),
        Algorithm::RoundRobin => Box::new(RoundRobin::new()),
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    // Unit tests for the three balancers, the factory and `Algorithm`.
    // Every fixture server listens on 127.0.0.1 and differs only by port,
    // weight and state.
    use super::super::upstream::ServerState;
    use super::*;
    use std::net::SocketAddr;

    // -- helpers -------------------------------------------------------------

    /// Server on 127.0.0.1:`port` with `UpstreamServer::new` defaults.
    /// NOTE(review): several tests assume the default is alive with weight 1
    /// (e.g. `rr_cycles_through_servers`); confirm against `upstream`.
    fn server(port: u16) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port))
    }

    /// Server on 127.0.0.1:`port` with an explicit `weight`.
    fn weighted(port: u16, weight: u32) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port)).with_weight(weight)
    }

    /// Server on 127.0.0.1:`port` in `ServerState::Down`. The tests rely on
    /// this state making `is_alive()` return `false`.
    fn down(port: u16) -> UpstreamServer {
        UpstreamServer::new(SocketAddr::new([127, 0, 0, 1].into(), port))
            .with_state(ServerState::Down)
    }

    // -- RoundRobin ----------------------------------------------------------

    #[test]
    fn rr_empty_returns_none() {
        let rr = RoundRobin::new();
        assert_eq!(rr.next_server(&[]), None);
    }

    #[test]
    fn rr_single_server_always_same() {
        let rr = RoundRobin::new();
        let servers = [server(8001)];
        for _ in 0..5 {
            assert_eq!(rr.next_server(&servers), Some(0));
        }
    }

    #[test]
    fn rr_cycles_through_servers() {
        let rr = RoundRobin::new();
        let servers = [server(8001), server(8002), server(8003)];
        // Two full cycles over equally weighted servers: 0, 1, 2, 0, 1, 2.
        for expected in 0..6 {
            assert_eq!(rr.next_server(&servers), Some(expected % 3));
        }
    }

    #[test]
    fn rr_weighted_distribution() {
        let rr = RoundRobin::new();
        let servers = [weighted(8001, 3), server(8002)];
        let mut counts = [0usize; 2];
        for _ in 0..400 {
            if let Some(idx) = rr.next_server(&servers) {
                counts[idx] += 1;
            }
        }
        // Server 0 (weight 3) should get ~3x the traffic of server 1.
        // The bound is deliberately looser than the exact 3:1 ratio.
        assert!(counts[0] > counts[1] * 2, "counts: {:?}", counts);
    }

    #[test]
    fn rr_skips_down_servers() {
        let rr = RoundRobin::new();
        let servers = [down(8001), server(8002), down(8003)];
        // Only index 1 is alive, so every call must return it.
        for _ in 0..5 {
            assert_eq!(rr.next_server(&servers), Some(1));
        }
    }

    #[test]
    fn rr_all_down_returns_none() {
        let rr = RoundRobin::new();
        let servers = [down(8001), down(8002)];
        assert_eq!(rr.next_server(&servers), None);
    }

    #[test]
    fn rr_algorithm_variant() {
        assert_eq!(RoundRobin::new().algorithm(), Algorithm::RoundRobin);
    }

    // -- LeastConnections ----------------------------------------------------

    #[test]
    fn lc_picks_fewest_connections() {
        let lc = LeastConnections::new(3);
        let servers = [server(8001), server(8002), server(8003)];
        lc.increment(0);
        lc.increment(0);
        lc.increment(2);
        // Server 0: 2 conns, server 1: 0, server 2: 1 -> picks 1
        assert_eq!(lc.next_server(&servers), Some(1));
    }

    #[test]
    fn lc_tie_breaks_by_weight() {
        let lc = LeastConnections::new(3);
        let servers = [weighted(8001, 1), weighted(8002, 5), weighted(8003, 1)];
        // All have 0 connections; server 1 has highest weight.
        assert_eq!(lc.next_server(&servers), Some(1));
    }

    #[test]
    fn lc_skips_down_servers() {
        let lc = LeastConnections::new(3);
        let servers = [server(8001), down(8002), server(8003)];
        lc.increment(0);
        // Server 0: 1 conn, server 1: down, server 2: 0 conns -> picks 2
        assert_eq!(lc.next_server(&servers), Some(2));
    }

    #[test]
    fn lc_increment_decrement_tracking() {
        let lc = LeastConnections::new(2);
        lc.increment(0);
        lc.increment(0);
        assert_eq!(lc.connections(0), 2);
        assert_eq!(lc.connections(1), 0);
        lc.decrement(0);
        assert_eq!(lc.connections(0), 1);
    }

    #[test]
    fn lc_decrement_saturates_at_zero() {
        let lc = LeastConnections::new(1);
        // Decrementing an idle counter must not wrap to usize::MAX.
        lc.decrement(0);
        assert_eq!(lc.connections(0), 0);
    }

    #[test]
    fn lc_out_of_bounds_is_noop() {
        let lc = LeastConnections::new(2);
        // Index 99 is not tracked: updates are ignored and reads report 0.
        lc.increment(99);
        lc.decrement(99);
        assert_eq!(lc.connections(99), 0);
    }

    #[test]
    fn lc_algorithm_variant() {
        assert_eq!(
            LeastConnections::new(1).algorithm(),
            Algorithm::LeastConnections
        );
    }

    // -- IpHash --------------------------------------------------------------

    #[test]
    fn ip_hash_deterministic() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        let h = IpHash::new(ip);
        let first = h.next_server(&servers);
        // Fresh balancers bound to the same IP must agree with the first pick.
        for _ in 0..20 {
            let h2 = IpHash::new(ip);
            assert_eq!(h2.next_server(&servers), first);
        }
    }

    #[test]
    fn ip_hash_different_ips_distribute() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let mut seen = std::collections::HashSet::new();
        for i in 0..256u16 {
            let ip: IpAddr = format!("10.{}.0.1", i).parse().unwrap();
            let h = IpHash::new(ip);
            if let Some(idx) = h.next_server(&servers) {
                seen.insert(idx);
            }
        }
        // 256 IPs over 4 servers should hit all 4.
        assert_eq!(seen.len(), 4, "only hit {:?}", seen);
    }

    #[test]
    fn ip_hash_skips_down_servers() {
        let servers = [down(8001), server(8002), down(8003)];
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let h = IpHash::new(ip);
        assert_eq!(h.next_server(&servers), Some(1));
    }

    #[test]
    fn ip_hash_weighted_distribution() {
        let servers = [weighted(8001, 10), weighted(8002, 1)];
        let mut hits = [0usize; 2];
        // 1100 distinct IPs; exact counts depend on the hash, so only the
        // ratio is checked.
        for i in 0..1100u16 {
            let ip: IpAddr = format!("10.{}.{}.1", i / 256, i % 256).parse().unwrap();
            let h = IpHash::new(ip);
            if let Some(idx) = h.next_server(&servers) {
                hits[idx] += 1;
            }
        }
        // Server 0 (weight 10) should dominate.
        assert!(hits[0] > hits[1] * 5, "hits: {:?}", hits);
    }

    #[test]
    fn ip_hash_set_client_ip() {
        let servers = [server(8001), server(8002), server(8003), server(8004)];
        let mut h = IpHash::new("10.0.0.1".parse().unwrap());
        let first = h.next_server(&servers);
        // Changing the IP should (likely) change the result.
        // A single new IP may legitimately map to the same server, so several
        // are tried before concluding `set_client_ip` had no effect.
        let mut changed = false;
        for i in 2..50u8 {
            h.set_client_ip(format!("10.0.0.{}", i).parse().unwrap());
            if h.next_server(&servers) != first {
                changed = true;
                break;
            }
        }
        assert!(changed, "set_client_ip had no effect over 48 IPs");
    }

    #[test]
    fn ip_hash_all_down_returns_none() {
        let servers = [down(8001), down(8002)];
        let h = IpHash::new("10.0.0.1".parse().unwrap());
        assert_eq!(h.next_server(&servers), None);
    }

    #[test]
    fn ip_hash_algorithm_variant() {
        let h = IpHash::new("10.0.0.1".parse().unwrap());
        assert_eq!(h.algorithm(), Algorithm::IpHash);
    }

    // -- Factory -------------------------------------------------------------

    #[test]
    fn create_balancer_round_robin() {
        let balancer = create_balancer(Algorithm::RoundRobin, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::RoundRobin);
    }

    #[test]
    fn create_balancer_least_conn() {
        let balancer = create_balancer(Algorithm::LeastConnections, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::LeastConnections);
    }

    #[test]
    fn create_balancer_ip_hash() {
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let balancer = create_balancer(Algorithm::IpHash, Some(ip), 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
        assert_eq!(balancer.algorithm(), Algorithm::IpHash);
    }

    #[test]
    fn create_balancer_ip_hash_default_ip() {
        // With `None` the factory binds 0.0.0.0; selection must still work.
        let balancer = create_balancer(Algorithm::IpHash, None, 2);
        let servers = [server(8001), server(8002)];
        assert!(balancer.next_server(&servers).is_some());
    }

    // -- Algorithm -----------------------------------------------------------

    #[test]
    fn algorithm_default_is_round_robin() {
        assert_eq!(Algorithm::default(), Algorithm::RoundRobin);
    }

    #[test]
    fn algorithm_debug_format() {
        assert_eq!(format!("{:?}", Algorithm::RoundRobin), "RoundRobin");
        assert_eq!(
            format!("{:?}", Algorithm::LeastConnections),
            "LeastConnections"
        );
        assert_eq!(format!("{:?}", Algorithm::IpHash), "IpHash");
    }

    #[test]
    fn algorithm_clone_copy() {
        // Using `a` again after binding it to `b` and `c` only compiles
        // because `Algorithm: Copy`.
        let a = Algorithm::IpHash;
        let b = a;
        let c = a;
        assert_eq!(a, b);
        assert_eq!(a, c);
    }
}