pub mod health_check;
pub mod load_balance;
pub mod upstream;
pub use load_balance::{
create_balancer, Algorithm, IpHash, LeastConnections, LoadBalancer, RoundRobin,
};
pub use upstream::{ServerState, UpstreamPool, UpstreamServer};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, warn};
use crate::http::is_hop_by_hop_header;
use crate::http::request::{Request, Version};
use crate::http::response::Response;
use crate::http::status::HttpStatusCode;
/// Upper bound on establishing the TCP connection to an upstream; applied to
/// `TcpStream::connect` by both [`forward`] and [`websocket_tunnel`].
const CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
/// Upper bound on reading the upstream's complete response (to EOF) in
/// [`forward`].
const READ_TIMEOUT: Duration = Duration::from_secs(60);
/// A resolved upstream endpoint (`host:port`).
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Hostname or IP literal of the upstream. An IPv6 literal may be stored
    /// either bracketed (`[::1]`) or bare (`::1`).
    pub host: String,
    /// TCP port of the upstream.
    pub port: u16,
}

impl UpstreamTarget {
    /// Render the endpoint as a connectable `host:port` string.
    ///
    /// A host that contains `:` but is not already bracketed can only be a
    /// bare IPv6 literal; it is wrapped in `[...]` so the port separator
    /// stays unambiguous (`::1` + `8090` becomes `[::1]:8090`, not
    /// `::1:8090`). Every other host is emitted unchanged.
    pub fn addr(&self) -> String {
        let bare_ipv6 = self.host.contains(':') && !self.host.starts_with('[');
        if bare_ipv6 {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
/// Parse a `proxy_pass` value such as `http://127.0.0.1:8090` (or a bare
/// `127.0.0.1:8090`) into a connectable target. The path component, if any,
/// is ignored: like nginx with a host-only `proxy_pass`, the original request
/// URI is forwarded unchanged.
///
/// A leading `http://` or `https://` is stripped and not otherwise
/// interpreted; when no port is given the port defaults to 80 in both cases.
///
/// Bracketed IPv6 literals are accepted with or without a port
/// (`[::1]:8090`, `[::1]`); the brackets are kept in the returned `host`.
///
/// Returns `None` when the host is empty, the port is not a valid `u16`, or
/// a bracketed literal is unterminated, empty, or followed by anything other
/// than `:port`.
pub fn parse_target(s: &str) -> Option<UpstreamTarget> {
    let s = s.trim();
    let rest = s
        .strip_prefix("http://")
        .or_else(|| s.strip_prefix("https://"))
        .unwrap_or(s);
    // Drop any path/query.
    let authority = rest.split(['/', '?']).next().unwrap_or(rest);

    let (host, port): (&str, u16) = if authority.starts_with('[') {
        // IPv6 literal: its own colons must not be mistaken for the port
        // separator, so split on the closing bracket instead of the last ':'.
        let close = authority.find(']')?;
        let (literal, tail) = authority.split_at(close + 1);
        if literal.len() <= 2 {
            // "[]" carries no address at all.
            return None;
        }
        let port = if tail.is_empty() {
            80
        } else {
            tail.strip_prefix(':')?.parse().ok()?
        };
        (literal, port)
    } else {
        match authority.rsplit_once(':') {
            Some((h, p)) => (h, p.parse().ok()?),
            None => (authority, 80),
        }
    };

    if host.is_empty() {
        return None;
    }
    Some(UpstreamTarget {
        host: host.to_string(),
        port,
    })
}
/// Proxy a fully buffered request to `target` and hand back the upstream's
/// reply as a [`Response`].
///
/// Steps, in order:
/// 1. open a TCP connection to `target.addr()` (bounded by `CONNECT_TIMEOUT`);
/// 2. write the request produced by `build_upstream_request`;
/// 3. read the socket to EOF (bounded by `READ_TIMEOUT`);
/// 4. parse the collected bytes with `parse_upstream_response`.
///
/// Failures never propagate as errors: a connect/write/read error or an
/// unparsable reply yields `Response::bad_gateway()`, while either timeout
/// yields `Response::gateway_timeout()`. Each failure is logged via `warn!`.
pub async fn forward(req: &Request, target: &UpstreamTarget, scheme: &str) -> Response {
    let addr = target.addr();

    let connect_outcome =
        tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await;
    let mut upstream = match connect_outcome {
        Err(_) => {
            warn!("proxy: connect to {addr} timed out");
            return Response::gateway_timeout();
        }
        Ok(Err(e)) => {
            warn!("proxy: connect to {addr} failed: {e}");
            return Response::bad_gateway();
        }
        Ok(Ok(socket)) => socket,
    };
    // Best effort only: a failure to set TCP_NODELAY is ignored.
    let _ = upstream.set_nodelay(true);

    let outbound = build_upstream_request(req, target, scheme);
    if let Err(e) = upstream.write_all(&outbound).await {
        warn!("proxy: write to {addr} failed: {e}");
        return Response::bad_gateway();
    }
    if let Err(e) = upstream.flush().await {
        warn!("proxy: flush to {addr} failed: {e}");
        return Response::bad_gateway();
    }

    // The outbound request carries `Connection: close`, so the upstream is
    // expected to close the socket after the full response; everything read
    // up to EOF is treated as the complete message.
    let mut raw = Vec::with_capacity(8192);
    let read_outcome = tokio::time::timeout(READ_TIMEOUT, upstream.read_to_end(&mut raw)).await;
    match read_outcome {
        Err(_) => {
            warn!("proxy: read from {addr} timed out");
            return Response::gateway_timeout();
        }
        Ok(Err(e)) => {
            warn!("proxy: read from {addr} failed: {e}");
            return Response::bad_gateway();
        }
        Ok(Ok(_)) => {}
    }

    parse_upstream_response(&raw).unwrap_or_else(|| {
        warn!("proxy: malformed response from {addr}");
        Response::bad_gateway()
    })
}
/// Serialize the outbound request to the upstream, applying nginx-style
/// proxy headers.
fn build_upstream_request(req: &Request, target: &UpstreamTarget, scheme: &str) -> Vec<u8> {
let mut buf = Vec::with_capacity(512);
buf.extend_from_slice(req.method().as_bytes());
buf.push(b' ');
buf.extend_from_slice(req.uri().as_bytes());
buf.extend_from_slice(b" HTTP/1.1\r\n");
// Preserve the client's Host header (proxy_set_header Host $host); fall
// back to the upstream authority if the client sent none.
let host = req.host().unwrap_or(&target.host);
let client_ip = req
.remote_addr
.map(|a| a.ip().to_string())
.unwrap_or_else(|| "unknown".to_string());
for (name, values) in req.headers.iter() {
let lname = name.as_str();
if is_hop_by_hop_header(lname)
|| lname == "host"
|| lname == "x-forwarded-for"
|| lname == "x-forwarded-proto"
|| lname == "x-real-ip"
{
continue;
}
for value in values {
buf.extend_from_slice(name.as_original().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
}
let xff = match req.header("x-forwarded-for") {
Some(existing) => format!("{existing}, {client_ip}"),
None => client_ip.clone(),
};
write_header(&mut buf, "Host", host);
write_header(&mut buf, "X-Real-IP", &client_ip);
write_header(&mut buf, "X-Forwarded-For", &xff);
write_header(&mut buf, "X-Forwarded-Proto", scheme);
write_header(&mut buf, "Connection", "close");
buf.extend_from_slice(b"\r\n");
if let Some(ref body) = req.body {
buf.extend_from_slice(body);
}
buf
}
/// Append one header line, `name: value\r\n`, to `buf`.
///
/// Neither `name` nor `value` is validated or escaped.
#[inline]
fn write_header(buf: &mut Vec<u8>, name: &str, value: &str) {
    let pieces: [&[u8]; 4] = [name.as_bytes(), b": ", value.as_bytes(), b"\r\n"];
    for piece in pieces {
        buf.extend_from_slice(piece);
    }
}
/// Parse a raw upstream HTTP/1.x response into a [`Response`].
fn parse_upstream_response(raw: &[u8]) -> Option<Response> {
let split = find_double_crlf(raw)?;
let head = &raw[..split];
let body_raw = &raw[split + 4..];
let mut lines = head.split(|&b| b == b'\n');
let status_line = lines.next()?;
let status_line = std::str::from_utf8(trim_cr(status_line)).ok()?;
let mut parts = status_line.splitn(3, ' ');
let _version = parts.next()?;
let code: u16 = parts.next()?.parse().ok()?;
let status = HttpStatusCode::from_u16(code).unwrap_or(HttpStatusCode::OK);
let mut response = Response::new().status(status);
response.version = Version::Http11;
let mut chunked = false;
for line in lines {
let line = trim_cr(line);
if line.is_empty() {
continue;
}
let line = match std::str::from_utf8(line) {
Ok(s) => s,
Err(_) => continue,
};
let (name, value) = match line.split_once(':') {
Some((n, v)) => (n.trim(), v.trim()),
None => continue,
};
let lname = name.to_ascii_lowercase();
if lname == "transfer-encoding" {
if value.to_ascii_lowercase().contains("chunked") {
chunked = true;
}
continue; // we re-frame with Content-Length
}
if is_hop_by_hop_header(&lname) || lname == "content-length" {
continue;
}
response.headers.insert(name, value);
}
let body = if chunked {
dechunk(body_raw)
} else {
body_raw.to_vec()
};
Some(response.with_body_bytes(body))
}
/// Decode a chunked transfer-coded body into its concatenated payload.
///
/// Decoding is lenient and never fails: it stops at the zero-size chunk, at
/// the first size line that is not valid hexadecimal, or when no further
/// CRLF-terminated size line exists. A final chunk shorter than its declared
/// size contributes whatever bytes are present. Chunk extensions (after `;`)
/// and trailers are ignored.
fn dechunk(mut data: &[u8]) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(data.len());

    while let Some(line_end) = find_crlf(data) {
        let size_line = trim_cr(&data[..line_end]);
        // Anything after ';' is a chunk extension, not part of the size.
        let hex = size_line
            .split(|&b| b == b';')
            .next()
            .unwrap_or(size_line);
        let declared = std::str::from_utf8(hex)
            .ok()
            .map(str::trim)
            .and_then(|digits| usize::from_str_radix(digits, 16).ok());

        // Stop on an unparsable size or on the terminating zero-size chunk.
        let chunk_len = match declared {
            Some(n) if n != 0 => n,
            _ => break,
        };

        let after_size_line = &data[line_end + 2..];
        if after_size_line.len() < chunk_len {
            // Truncated chunk: keep what arrived.
            decoded.extend_from_slice(after_size_line);
            break;
        }

        let (chunk, rest) = after_size_line.split_at(chunk_len);
        decoded.extend_from_slice(chunk);
        // Step over the two bytes following the chunk data when present.
        data = rest.get(2..).unwrap_or(rest);
    }

    decoded
}
/// Strip a single trailing `\r` from `line`, if present.
///
/// Only one byte is ever removed, so `b"a\r\r"` becomes `b"a\r"`; input
/// without a trailing `\r` (including the empty slice) is returned as-is.
#[inline]
fn trim_cr(line: &[u8]) -> &[u8] {
    line.strip_suffix(b"\r").unwrap_or(line)
}
/// Byte offset of the first `\r\n\r\n` sequence in `buf`, or `None` if the
/// sequence does not occur.
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
    const SEPARATOR: &[u8] = b"\r\n\r\n";
    (0..buf.len()).find(|&start| buf[start..].starts_with(SEPARATOR))
}
/// Byte offset of the first `\r\n` pair in `buf`, or `None` if there is none.
fn find_crlf(buf: &[u8]) -> Option<usize> {
    buf.iter()
        .zip(buf.iter().skip(1))
        .position(|(&first, &second)| first == b'\r' && second == b'\n')
}
/// Tunnel a connection-upgrade (e.g. WebSocket) between a client stream and an
/// upstream TCP socket.
///
/// `head` is the already-parsed-and-serialized client request (the upgrade
/// handshake) to replay to the upstream; after that the two sockets are
/// spliced bidirectionally until either side closes.
///
/// # Errors
///
/// * Connecting to the upstream fails: a `502 Bad Gateway` is written to the
///   client (best effort) and the connect error is returned.
/// * Connecting exceeds `CONNECT_TIMEOUT`: a `504 Gateway Timeout` is written
///   to the client (best effort) and an [`std::io::ErrorKind::TimedOut`]
///   error is returned.
/// * Replaying `head` to the upstream fails: that I/O error is returned.
///
/// An error raised while the tunnel is already running is only logged at
/// debug level; the function still returns `Ok(())` in that case.
pub async fn websocket_tunnel<S>(
    client: &mut S,
    head: &[u8],
    target: &UpstreamTarget,
) -> std::io::Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let addr = target.addr();
    let mut upstream = match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await
    {
        Ok(Ok(s)) => s,
        Ok(Err(e)) => {
            warn!("ws: connect to {addr} failed: {e}");
            let _ = client
                .write_all(
                    b"HTTP/1.1 502 Bad Gateway\r\nConnection: close\r\nContent-Length: 0\r\n\r\n",
                )
                .await;
            return Err(e);
        }
        Err(_) => {
            warn!("ws: connect to {addr} timed out");
            // Mirror the connect-failure branch: tell the client why the
            // handshake is not going to complete, then surface the failure
            // to the caller instead of reporting success.
            let _ = client
                .write_all(
                    b"HTTP/1.1 504 Gateway Timeout\r\nConnection: close\r\nContent-Length: 0\r\n\r\n",
                )
                .await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!("connect to {addr} timed out"),
            ));
        }
    };
    // Best effort only: a failure to set TCP_NODELAY is ignored.
    let _ = upstream.set_nodelay(true);

    // Replay the client's handshake before splicing the two streams.
    upstream.write_all(head).await?;
    upstream.flush().await?;
    debug!("ws: tunnel established to {addr}");

    match tokio::io::copy_bidirectional(client, &mut upstream).await {
        Ok((c2u, u2c)) => debug!("ws: tunnel closed ({c2u} up, {u2c} down)"),
        Err(e) => debug!("ws: tunnel error: {e}"),
    }
    Ok(())
}
/// Decide whether `req` asks for a connection upgrade (e.g. WebSocket) that
/// should be tunnelled instead of buffered.
///
/// True only when both hold: the `connection` header, compared
/// case-insensitively, contains the text `upgrade`, and an `upgrade` header
/// is present (its value is not inspected).
pub fn is_upgrade_request(req: &Request) -> bool {
    let connection_mentions_upgrade = match req.header("connection") {
        Some(value) => value.to_ascii_lowercase().contains("upgrade"),
        None => false,
    };
    let upgrade_header_present = req.header("upgrade").is_some();

    if !connection_mentions_upgrade {
        return false;
    }
    upgrade_header_present
}