Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / proxy / mod.rs
pub mod health_check;
pub mod load_balance;
pub mod upstream;

pub use load_balance::{
    create_balancer, Algorithm, IpHash, LeastConnections, LoadBalancer, RoundRobin,
};
pub use upstream::{ServerState, UpstreamPool, UpstreamServer};

use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, warn};

use crate::http::is_hop_by_hop_header;
use crate::http::request::{Request, Version};
use crate::http::response::Response;
use crate::http::status::HttpStatusCode;

/// Upper bound on establishing the TCP connection to an upstream; applied
/// around `TcpStream::connect` in both `forward` and `websocket_tunnel`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
/// Upper bound on reading the *entire* upstream response in `forward` (it
/// wraps the whole `read_to_end` call, not each individual read).
const READ_TIMEOUT: Duration = Duration::from_secs(60);

/// Connectable upstream endpoint, kept as a separate host and port.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Host part of the endpoint, stored exactly as supplied.
    pub host: String,
    /// TCP port of the endpoint.
    pub port: u16,
}

impl UpstreamTarget {
    /// Renders the endpoint as a single `host:port` string.
    pub fn addr(&self) -> String {
        let Self { host, port } = self;
        [host.as_str(), ":", port.to_string().as_str()].concat()
    }
}

/// Parse a `proxy_pass` value such as `http://127.0.0.1:8090` (or a bare
/// `127.0.0.1:8090`) into a connectable target.
///
/// * A leading `http://` or `https://` is stripped; the scheme itself is not
///   retained, and the port defaults to 80 whenever none is given.
/// * Anything from the first `/` or `?` onwards is ignored: like nginx with
///   a host-only `proxy_pass`, the original request URI is forwarded
///   unchanged.
/// * Bracketed IPv6 literals are accepted with or without a port
///   (`[::1]`, `[::1]:8080`); the brackets are kept in `host` so that
///   [`UpstreamTarget::addr`] yields a valid socket address string.
///
/// Returns `None` when the host is empty, the port is not a valid `u16`, or
/// a bracketed literal is unterminated or followed by anything other than
/// `:port`.
pub fn parse_target(s: &str) -> Option<UpstreamTarget> {
    let s = s.trim();
    let rest = s
        .strip_prefix("http://")
        .or_else(|| s.strip_prefix("https://"))
        .unwrap_or(s);
    // Drop any path/query.
    let authority = rest.split(['/', '?']).next().unwrap_or(rest);

    let (host, port) = if authority.starts_with('[') {
        // IPv6 literal: the colons inside the brackets are part of the host,
        // so the port separator can only be the ':' right after ']'.
        let close = authority.find(']')?;
        let (literal, tail) = authority.split_at(close + 1);
        let port = match tail.strip_prefix(':') {
            Some(p) => p.parse::<u16>().ok()?,
            None if tail.is_empty() => 80u16,
            None => return None,
        };
        (literal, port)
    } else {
        match authority.rsplit_once(':') {
            Some((h, p)) => (h, p.parse::<u16>().ok()?),
            None => (authority, 80u16),
        }
    };

    if host.is_empty() {
        return None;
    }
    Some(UpstreamTarget {
        host: host.to_string(),
        port,
    })
}

/// Relay a fully buffered request to `target` and hand back the upstream's
/// reply as a [`Response`].
///
/// The exchange is a single-shot HTTP/1.1 round trip intended for
/// non-streaming responses: the request is serialized by
/// `build_upstream_request` (which asks the upstream to close the
/// connection), the whole reply is read until EOF, and the bytes are decoded
/// by `parse_upstream_response`.
///
/// Failure mapping:
/// * connect or read exceeds its timeout -> `Response::gateway_timeout()`
/// * connect / write / flush / read I/O error, or an unparseable reply ->
///   `Response::bad_gateway()`
pub async fn forward(req: &Request, target: &UpstreamTarget, scheme: &str) -> Response {
    let addr = target.addr();

    let connect = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await;
    let mut stream = match connect {
        Err(_) => {
            warn!("proxy: connect to {addr} timed out");
            return Response::gateway_timeout();
        }
        Ok(Err(e)) => {
            warn!("proxy: connect to {addr} failed: {e}");
            return Response::bad_gateway();
        }
        Ok(Ok(socket)) => socket,
    };
    // Best effort: a failure to set TCP_NODELAY is deliberately ignored.
    let _ = stream.set_nodelay(true);

    let outbound = build_upstream_request(req, target, scheme);
    if let Err(write_err) = stream.write_all(&outbound).await {
        warn!("proxy: write to {addr} failed: {write_err}");
        return Response::bad_gateway();
    }
    if let Err(flush_err) = stream.flush().await {
        warn!("proxy: flush to {addr} failed: {flush_err}");
        return Response::bad_gateway();
    }

    // The outbound request carries `Connection: close`, so EOF on the socket
    // marks the end of the upstream's message; the timeout covers the whole
    // read, not each chunk.
    let mut raw = Vec::with_capacity(8192);
    let read = tokio::time::timeout(READ_TIMEOUT, stream.read_to_end(&mut raw)).await;
    match read {
        Err(_) => {
            warn!("proxy: read from {addr} timed out");
            return Response::gateway_timeout();
        }
        Ok(Err(e)) => {
            warn!("proxy: read from {addr} failed: {e}");
            return Response::bad_gateway();
        }
        Ok(Ok(_)) => {}
    }

    parse_upstream_response(&raw).unwrap_or_else(|| {
        warn!("proxy: malformed response from {addr}");
        Response::bad_gateway()
    })
}

/// Serialize the outbound request to the upstream, applying nginx-style
/// proxy headers.
///
/// Layout of the produced bytes:
/// 1. request line `<method> <uri> HTTP/1.1` (the client's URI is forwarded
///    unchanged);
/// 2. every client header except hop-by-hop headers and the ones this
///    function regenerates (`Host`, `X-Forwarded-For`, `X-Forwarded-Proto`,
///    `X-Real-IP`);
/// 3. the regenerated `Host`, `X-Real-IP`, `X-Forwarded-For`,
///    `X-Forwarded-Proto` and a fixed `Connection: close`;
/// 4. a blank line followed by the request body, if any.
///
/// `scheme` is emitted verbatim as the `X-Forwarded-Proto` value.
fn build_upstream_request(req: &Request, target: &UpstreamTarget, scheme: &str) -> Vec<u8> {
    let mut buf = Vec::with_capacity(512);
    buf.extend_from_slice(req.method().as_bytes());
    buf.push(b' ');
    buf.extend_from_slice(req.uri().as_bytes());
    buf.extend_from_slice(b" HTTP/1.1\r\n");

    // Preserve the client's Host header (proxy_set_header Host $host); fall
    // back to the upstream authority if the client sent none.  The authority
    // includes the port unless it is the HTTP default (80), so an upstream on
    // a non-default port sees a Host value that matches what we connected to.
    // The fallback is only materialized when needed (`String::new()` does not
    // allocate).
    let fallback_host = match req.host() {
        Some(_) => String::new(),
        None if target.port == 80 => target.host.clone(),
        None => target.addr(),
    };
    let host = req.host().unwrap_or(&fallback_host);
    let client_ip = req
        .remote_addr
        .map(|a| a.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string());

    for (name, values) in req.headers.iter() {
        let lname = name.as_str();
        if is_hop_by_hop_header(lname)
            || lname == "host"
            || lname == "x-forwarded-for"
            || lname == "x-forwarded-proto"
            || lname == "x-real-ip"
        {
            continue;
        }
        for value in values {
            buf.extend_from_slice(name.as_original().as_bytes());
            buf.extend_from_slice(b": ");
            buf.extend_from_slice(value.as_bytes());
            buf.extend_from_slice(b"\r\n");
        }
    }

    write_header(&mut buf, "Host", host);
    write_header(&mut buf, "X-Real-IP", &client_ip);

    // Append our peer to any chain the client supplied; `client_ip` is no
    // longer needed afterwards, so it is moved rather than cloned.
    let xff = match req.header("x-forwarded-for") {
        Some(existing) => format!("{existing}, {client_ip}"),
        None => client_ip,
    };
    write_header(&mut buf, "X-Forwarded-For", &xff);
    write_header(&mut buf, "X-Forwarded-Proto", scheme);
    write_header(&mut buf, "Connection", "close");
    buf.extend_from_slice(b"\r\n");

    // NOTE(review): the body is framed only by whatever Content-Length the
    // client sent (copied in the loop above). If `is_hop_by_hop_header`
    // strips Transfer-Encoding, a chunked upload would reach the upstream
    // without framing — confirm how `Request` stores such bodies.
    if let Some(ref body) = req.body {
        buf.extend_from_slice(body);
    }
    buf
}

/// Append one `name: value\r\n` header line to `buf`.
///
/// Neither `name` nor `value` is validated or escaped; they are copied
/// byte-for-byte.
#[inline]
fn write_header(buf: &mut Vec<u8>, name: &str, value: &str) {
    let pieces: [&[u8]; 4] = [name.as_bytes(), b": ", value.as_bytes(), b"\r\n"];
    for piece in pieces.iter() {
        buf.extend_from_slice(piece);
    }
}

/// Parse a raw upstream HTTP/1.x response into a [`Response`].
fn parse_upstream_response(raw: &[u8]) -> Option<Response> {
    let split = find_double_crlf(raw)?;
    let head = &raw[..split];
    let body_raw = &raw[split + 4..];

    let mut lines = head.split(|&b| b == b'\n');
    let status_line = lines.next()?;
    let status_line = std::str::from_utf8(trim_cr(status_line)).ok()?;
    let mut parts = status_line.splitn(3, ' ');
    let _version = parts.next()?;
    let code: u16 = parts.next()?.parse().ok()?;
    let status = HttpStatusCode::from_u16(code).unwrap_or(HttpStatusCode::OK);

    let mut response = Response::new().status(status);
    response.version = Version::Http11;

    let mut chunked = false;
    for line in lines {
        let line = trim_cr(line);
        if line.is_empty() {
            continue;
        }
        let line = match std::str::from_utf8(line) {
            Ok(s) => s,
            Err(_) => continue,
        };
        let (name, value) = match line.split_once(':') {
            Some((n, v)) => (n.trim(), v.trim()),
            None => continue,
        };
        let lname = name.to_ascii_lowercase();
        if lname == "transfer-encoding" {
            if value.to_ascii_lowercase().contains("chunked") {
                chunked = true;
            }
            continue; // we re-frame with Content-Length
        }
        if is_hop_by_hop_header(&lname) || lname == "content-length" {
            continue;
        }
        response.headers.insert(name, value);
    }

    let body = if chunked {
        dechunk(body_raw)
    } else {
        body_raw.to_vec()
    };

    Some(response.with_body_bytes(body))
}

/// Reassemble the payload of an HTTP/1.1 chunked transfer-coding body.
///
/// Decoding is lenient: it stops (returning what has been collected so far)
/// at the zero-size chunk, at a size line that is missing or not valid hex,
/// or when the input ends mid-chunk — in which case the partial chunk data is
/// still included. Chunk extensions (anything after `;` on the size line) and
/// trailers are ignored.
fn dechunk(mut data: &[u8]) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(data.len());
    loop {
        let eol = match find_crlf(data) {
            Some(pos) => pos,
            None => break,
        };

        // Size line: hex digits, optionally followed by `;extension`.
        let size_line = trim_cr(&data[..eol]);
        let hex = size_line
            .split(|&b| b == b';')
            .next()
            .unwrap_or(size_line);
        let size = std::str::from_utf8(hex)
            .ok()
            .and_then(|digits| usize::from_str_radix(digits.trim(), 16).ok());
        let size = match size {
            Some(n) if n != 0 => n,
            // Unparseable size or the terminating zero-size chunk.
            _ => break,
        };

        let payload = &data[eol + 2..];
        if payload.len() < size {
            // Truncated input: keep whatever arrived.
            decoded.extend_from_slice(payload);
            break;
        }
        let (chunk, after) = payload.split_at(size);
        decoded.extend_from_slice(chunk);
        // Step over the two bytes that terminate the chunk data, if present.
        data = after.get(2..).unwrap_or(after);
    }
    decoded
}

/// Strip a single trailing `\r` from `line`, if there is one.
///
/// At most one byte is removed; a line ending in `\r\r` keeps its first `\r`.
#[inline]
fn trim_cr(line: &[u8]) -> &[u8] {
    match line {
        [head @ .., b'\r'] => head,
        _ => line,
    }
}

/// Offset of the first `\r\n\r\n` sequence in `buf`, or `None` if absent.
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
    (0..buf.len()).find(|&start| buf[start..].starts_with(b"\r\n\r\n"))
}

/// Offset of the first `\r\n` pair in `buf`, or `None` if absent.
fn find_crlf(buf: &[u8]) -> Option<usize> {
    buf.iter()
        .zip(buf.iter().skip(1))
        .position(|(&first, &second)| first == b'\r' && second == b'\n')
}

/// Tunnel a connection-upgrade (e.g. WebSocket) between a client stream and an
/// upstream TCP socket.
///
/// `head` is the already-parsed-and-serialized client request (the upgrade
/// handshake) to replay to the upstream; after that the two sockets are
/// spliced bidirectionally until either side closes.
///
/// # Errors
///
/// Returns an error when the tunnel could not be established:
/// * the upstream connect failed — the client is sent `502 Bad Gateway` and
///   the connect error is returned;
/// * the upstream connect timed out — the client is sent
///   `504 Gateway Timeout` and an error of kind
///   [`std::io::ErrorKind::TimedOut`] is returned;
/// * replaying `head` to the upstream failed — the write/flush error is
///   returned.
///
/// Once the tunnel is up, an error from the bidirectional copy is only logged
/// and the function returns `Ok(())`.
pub async fn websocket_tunnel<S>(
    client: &mut S,
    head: &[u8],
    target: &UpstreamTarget,
) -> std::io::Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let addr = target.addr();
    let mut upstream = match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&addr)).await
    {
        Ok(Ok(s)) => s,
        Ok(Err(e)) => {
            warn!("ws: connect to {addr} failed: {e}");
            // Best effort: the client may already be gone.
            let _ = client
                .write_all(
                    b"HTTP/1.1 502 Bad Gateway\r\nConnection: close\r\nContent-Length: 0\r\n\r\n",
                )
                .await;
            return Err(e);
        }
        Err(_) => {
            warn!("ws: connect to {addr} timed out");
            // Mirror the connect-failure path: tell the client why the
            // upgrade did not happen instead of silently dropping it, and
            // surface the failure to the caller.
            let _ = client
                .write_all(
                    b"HTTP/1.1 504 Gateway Timeout\r\nConnection: close\r\nContent-Length: 0\r\n\r\n",
                )
                .await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!("connect to {addr} timed out"),
            ));
        }
    };
    // Best effort: a failure to set TCP_NODELAY is deliberately ignored.
    let _ = upstream.set_nodelay(true);

    upstream.write_all(head).await?;
    upstream.flush().await?;

    debug!("ws: tunnel established to {addr}");
    match tokio::io::copy_bidirectional(client, &mut upstream).await {
        Ok((c2u, u2c)) => debug!("ws: tunnel closed ({c2u} up, {u2c} down)"),
        Err(e) => debug!("ws: tunnel error: {e}"),
    }
    Ok(())
}

/// Decide whether `req` asks for a connection upgrade (e.g. WebSocket) that
/// should be tunnelled instead of buffered.
///
/// True only when the `Connection` header contains `upgrade` (compared
/// case-insensitively, as a substring) *and* an `Upgrade` header is present.
pub fn is_upgrade_request(req: &Request) -> bool {
    let connection_asks_upgrade = match req.header("connection") {
        Some(value) => value.to_ascii_lowercase().contains("upgrade"),
        None => false,
    };
    connection_asks_upgrade && req.header("upgrade").is_some()
}