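//! Request processing pipeline (veld `src/core/pipeline.rs`): selects the
//! `server` block and `location` for each request, then answers with a
//! `return`, the reverse proxy, or the static-file handler.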
use std::collections::HashMap;
use std::sync::Arc;

use crate::config::directives::{LocationConfig, ParsedConfig, ServerConfig};
use crate::config::LocationMatch;
use crate::handler::static_file::StaticFileHandler;
use crate::http::request::Request;
use crate::http::response::Response;
use crate::http::status::HttpStatusCode;
use crate::proxy::{self, UpstreamTarget};

/// Request processing pipeline.
///
/// Owns the parsed configuration (behind an `Arc`) together with the
/// static-file handlers built for it in [`PipelineProcessor::new`].
pub struct PipelineProcessor {
    /// Parsed configuration; `server` blocks are read from `config.http`.
    config: Arc<ParsedConfig>,
    /// One handler per distinct document root, keyed by the lossy string
    /// form of `server.root`.
    static_handlers: HashMap<String, StaticFileHandler>,
}

impl PipelineProcessor {
    /// Build a pipeline over `config`.
    ///
    /// One [`StaticFileHandler`] is created per distinct `server.root`; the
    /// gzip settings from the `http` block are passed to every handler.
    pub fn new(config: Arc<ParsedConfig>) -> Self {
        // Key handlers by document root, not server_name: several server
        // blocks can share a name (e.g. the :80 redirect block lists every
        // host) yet have different roots, and several can share a root.
        let mut static_handlers = HashMap::new();
        for server in &config.http.servers {
            let key = server.root.to_string_lossy().into_owned();
            static_handlers.entry(key).or_insert_with(|| {
                StaticFileHandler::new(
                    server.root.clone(),
                    server.index.clone(),
                    // NOTE(review): positional flags; their meaning is defined
                    // by `StaticFileHandler::new` — confirm before changing.
                    false,
                    true,
                    config.http.gzip.enabled,
                    config.http.gzip.types.clone(),
                    config.http.gzip.min_length,
                )
            });
        }

        Self {
            config,
            static_handlers,
        }
    }

    /// If this request is a connection upgrade (WebSocket) that routes to a
    /// `proxy_pass` location, return the upstream target so the connection
    /// layer can splice the sockets instead of buffering.
    ///
    /// Routing mirrors [`PipelineProcessor::process`]: the same server
    /// selection (including its fallback), and a server- or location-level
    /// `return` takes priority over `proxy_pass`, so an upgrade request can
    /// never bypass a configured redirect.
    pub fn upgrade_target(&self, request: &Request, port: u16) -> Option<UpstreamTarget> {
        if !proxy::is_upgrade_request(request) {
            return None;
        }
        let server = self.resolve_server(request, port)?;
        if server.return_directive.is_some() {
            return None;
        }
        let location = self.find_location(request, server)?;
        if location.return_directive.is_some() {
            return None;
        }
        let pass = location.proxy_pass.as_ref()?;
        proxy::parse_target(pass)
    }

    /// Process a request and produce a response.  `scheme` is `"http"` or
    /// `"https"` depending on the listener that accepted the connection; it
    /// drives `X-Forwarded-Proto` and `$scheme` interpolation.
    ///
    /// Order of precedence: server-level `return`, then the matched
    /// location's `return`, `proxy_pass`, and finally static files.
    pub async fn process(&self, request: &Request, scheme: &str, port: u16) -> Response {
        let server_config = match self.resolve_server(request, port) {
            Some(s) => s,
            // Only reachable when the configuration has no servers at all.
            None => return Response::internal_error(),
        };

        // Server-level `return` (HTTP->HTTPS redirect block) takes priority.
        if let Some((code, target)) = &server_config.return_directive {
            return build_return(*code, target.as_deref(), request, scheme);
        }

        if let Some(location) = self.find_location(request, server_config) {
            // Location-level `return`.
            if let Some((code, target)) = &location.return_directive {
                return build_return(*code, target.as_deref(), request, scheme);
            }

            // Reverse proxy.
            if let Some(pass) = &location.proxy_pass {
                return match proxy::parse_target(pass) {
                    Some(target) => proxy::forward(request, &target, scheme).await,
                    None => Response::bad_gateway(),
                };
            }

            // Static file, then decorate with add_header / expires.
            let mut resp = self.serve_static(server_config, request);
            apply_location_headers(&mut resp, location);
            return resp;
        }

        // No location matched: serve from the server root.
        let mut resp = self.serve_static(server_config, request);
        // Apply add_header/expires from a catch-all `location /` if present.
        if let Some(loc) = server_config
            .locations
            .iter()
            .find(|l| matches!(l.match_type, LocationMatch::Prefix) && l.path == "/")
        {
            apply_location_headers(&mut resp, loc);
        }
        resp
    }

    /// Server block for this request: the `find_server` result, or the first
    /// configured server when nothing listens on `port`. Shared by `process`
    /// and `upgrade_target` so both route a request identically.
    fn resolve_server(&self, request: &Request, port: u16) -> Option<&ServerConfig> {
        self.find_server(request, port)
            .or_else(|| self.config.http.servers.first())
    }

    /// Serve `request` from the static handler registered for `server.root`.
    fn serve_static(&self, server: &ServerConfig, request: &Request) -> Response {
        // Borrowed lookup: no per-request allocation when the root is UTF-8.
        let root = server.root.to_string_lossy();
        // `new` registers every configured root, so the any-handler fallback
        // only guards against a `server` that did not come from this config.
        match self
            .static_handlers
            .get(&*root)
            .or_else(|| self.static_handlers.values().next())
        {
            Some(handler) => handler.handle(request),
            None => Response::not_found(),
        }
    }

    /// Host name from a `Host` header value: the port is removed (bracketed
    /// IPv6 literals such as `[::1]:8080` stay whole) and a trailing
    /// root-label dot is dropped.
    fn bare_hostname(host: &str) -> &str {
        let name = match host.strip_prefix('[') {
            Some(rest) => match rest.find(']') {
                Some(close) => &host[..close + 2],
                None => host,
            },
            None => host.split(':').next().unwrap_or(host),
        };
        name.strip_suffix('.').unwrap_or(name)
    }

    /// Pick the server block for a request that arrived on `port`.
    ///
    /// Only servers that actually listen on that port are candidates —
    /// mirrors nginx, where a request on :80 never matches a :443-only
    /// server block. Among them: a case-insensitive `server_name` match,
    /// else a server named `_`/empty, else the first one on the port.
    fn find_server(&self, request: &Request, port: u16) -> Option<&ServerConfig> {
        let hostname = Self::bare_hostname(request.host().unwrap_or(""));

        let servers = &self.config.http.servers;
        let on_port = move || {
            servers
                .iter()
                .filter(move |s| s.listen.iter().any(|l| l.port == port))
        };

        // Host names are case-insensitive, so `Example.COM` must match
        // `server_name example.com`.
        on_port()
            .find(|s| {
                s.server_name
                    .iter()
                    .any(|name| name.eq_ignore_ascii_case(hostname))
            })
            .or_else(|| {
                on_port().find(|s| s.server_name.iter().any(|n| n == "_" || n.is_empty()))
            })
            .or_else(|| on_port().next())
    }

    /// Pick the location block of `server` that handles `request.path`.
    fn find_location<'a>(
        &self,
        request: &Request,
        server: &'a ServerConfig,
    ) -> Option<&'a LocationConfig> {
        let path = &request.path;

        // nginx precedence: exact (`=`) > regex (`~`/`~*`, first match in
        // config order) > longest matching prefix. As in nginx, an exact hit
        // ends the search immediately.
        let mut regex: Option<&LocationConfig> = None;
        let mut prefix: Option<&LocationConfig> = None;
        let mut prefix_len = 0usize;

        for location in &server.locations {
            match location.match_type {
                LocationMatch::Exact => {
                    if path == &location.path {
                        return Some(location);
                    }
                }
                LocationMatch::Prefix => {
                    // `>=` keeps the later of two equal-length prefixes.
                    if path.starts_with(&location.path) && location.path.len() >= prefix_len {
                        prefix = Some(location);
                        prefix_len = location.path.len();
                    }
                }
                LocationMatch::Regex | LocationMatch::RegexNoCase => {
                    let ci = matches!(location.match_type, LocationMatch::RegexNoCase);
                    if regex.is_none() && regex_match(&location.path, path, ci) {
                        regex = Some(location);
                    }
                }
            }
        }

        regex.or(prefix)
    }
}

/// Build the response for an nginx-style `return` directive.
///
/// `code` is the configured status; codes `HttpStatusCode` does not know
/// fall back to 301. With a 3xx code `target` becomes the `Location` header
/// (empty body); with any other code `target` is sent as the body. Without a
/// target the body is empty.
///
/// `$host`, `$request_uri`, `$scheme` and `$uri` (also as `${name}`) are
/// interpolated, as in a typical `return 301 https://$host$request_uri;`.
fn build_return(code: u16, target: Option<&str>, request: &Request, scheme: &str) -> Response {
    let status = HttpStatusCode::from_u16(code).unwrap_or(HttpStatusCode::MOVED_PERMANENTLY);
    match target {
        Some(template) => {
            // nginx's `$host` has no port and is lower-cased; the raw header
            // would turn `Host: example.com:80` into `https://example.com:80/`.
            let host = nginx_host(request.host().unwrap_or(""));
            let vars: [(&str, &str); 4] = [
                ("host", host.as_str()),
                ("request_uri", request.uri()),
                ("scheme", scheme),
                ("uri", &request.path),
            ];
            let url = interpolate(template, &vars);
            if (300..400).contains(&code) {
                Response::new()
                    .status(status)
                    .header("Location", &url)
                    .body_str("")
            } else {
                // Non-3xx return with a second argument: use it as the body.
                Response::new().status(status).body_str(&url)
            }
        }
        None => Response::new().status(status).body_str(""),
    }
}

/// Lower-cased host name from a `Host` header value with any port removed
/// (bracketed IPv6 literals are kept whole) and a trailing dot dropped —
/// the form nginx exposes as `$host`.
fn nginx_host(raw: &str) -> String {
    let name = match raw.strip_prefix('[') {
        Some(rest) => match rest.find(']') {
            Some(close) => &raw[..close + 2],
            None => raw,
        },
        None => raw.split(':').next().unwrap_or(raw),
    };
    name.strip_suffix('.').unwrap_or(name).to_ascii_lowercase()
}

/// Expand `$name` / `${name}` references in `template` from `vars`.
///
/// Unknown names (and a lone `$`) are copied through verbatim. The scan is
/// single-pass: substituted values — which come from the request and are
/// untrusted — are never re-scanned, and `$hostname` is not mistaken for
/// `$host` followed by `name`.
fn interpolate(template: &str, vars: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];
        let (name, tail) = match after.strip_prefix('{') {
            Some(braced) => match braced.find('}') {
                Some(close) => (&braced[..close], &braced[close + 1..]),
                None => ("", after),
            },
            None => {
                let len = after
                    .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
                    .unwrap_or(after.len());
                (&after[..len], &after[len..])
            }
        };
        match vars.iter().find(|(key, _)| !name.is_empty() && *key == name) {
            Some((_, value)) => {
                out.push_str(value);
                rest = tail;
            }
            None => {
                out.push('$');
                rest = after;
            }
        }
    }
    out.push_str(rest);
    out
}

/// Apply a location's `add_header` and `expires` directives to a response.
///
/// As in nginx, both directives only take effect when the status is 200,
/// 201, 204, 206, 301, 302, 303, 304, 307 or 308. Error responses such as
/// 404 are left undecorated so they are not marked cacheable, while partial
/// content (206, range requests) keeps its caching headers.
fn apply_location_headers(resp: &mut Response, location: &LocationConfig) {
    const DECORATED: [u16; 10] = [200, 201, 204, 206, 301, 302, 303, 304, 307, 308];
    let decorated = DECORATED
        .iter()
        .any(|&code| HttpStatusCode::from_u16(code).map_or(false, |s| s == resp.status));
    if !decorated {
        return;
    }
    if let Some(secs) = location.expires {
        resp.set_header("Cache-Control", &format!("max-age={secs}"));
    }
    // nginx `add_header` appends, so a directive can coexist with the
    // Cache-Control emitted by `expires` (both lines reach the client).
    for (name, value) in &location.add_headers {
        resp.headers.append(name.as_str(), value.as_str());
    }
}

/// Match the small subset of regex used by typical nginx locations.
///
/// Supported exactly:
/// * extension groups `[^][LIT][.*]\.(ext1|ext2|...)$`, including the
///   non-capturing `(?:...)` form and optional chars such as `woff2?`;
/// * pure literals with optional `^` / `$` anchors and `\`-escaped
///   punctuation, e.g. `\.php$` or `^/api`.
///
/// Anything more exotic falls back to a substring test of the pattern with
/// its anchors trimmed. `case_insensitive` folds ASCII case of both sides.
fn regex_match(pattern: &str, path: &str, case_insensitive: bool) -> bool {
    let (pat, hay) = if case_insensitive {
        (pattern.to_ascii_lowercase(), path.to_ascii_lowercase())
    } else {
        (pattern.to_owned(), path.to_owned())
    };
    match_extension_group(&pat, &hay).unwrap_or_else(|| match_literal(&pat, &hay))
}

/// Extension-alternation form. Returns `None` when `pat` is not of that shape
/// so the caller can fall back.
fn match_extension_group(pat: &str, hay: &str) -> Option<bool> {
    let before_close = pat.strip_suffix(")$")?;
    // The extension group is the last one; searching backwards from the end
    // also guarantees `open` precedes the closing paren (no inverted slice).
    let open = before_close.rfind('(')?;
    let group = &before_close[open + 1..];
    let group = group.strip_prefix("?:").unwrap_or(group);
    let head = before_close[..open].strip_suffix("\\.")?;

    // What must precede the final `.ext`: `^LIT` (prefix), `LIT.*`
    // (anywhere), `^LIT.*` (prefix), or `LIT` (immediately before).
    let (anchored, head) = match head.strip_prefix('^') {
        Some(rest) => (true, rest),
        None => (false, head),
    };
    let (wildcard, head) = match head.strip_suffix(".*") {
        Some(rest) => (true, rest),
        None => (false, head),
    };
    let literal = unescape_literal(head)?;

    let hit = group.split('|').flat_map(expand_optional).any(|ext| {
        match hay.strip_suffix(ext.as_str()).and_then(|s| s.strip_suffix('.')) {
            Some(stem) => match (anchored, wildcard) {
                (true, true) => stem.starts_with(literal.as_str()),
                (true, false) => stem == literal.as_str(),
                (false, true) => stem.contains(literal.as_str()),
                (false, false) => stem.ends_with(literal.as_str()),
            },
            None => false,
        }
    });
    Some(hit)
}

/// Literal fallback honouring `^`/`$` anchors and `\`-escapes. Patterns that
/// still contain regex syntax degrade to a substring test of the raw text.
fn match_literal(pat: &str, hay: &str) -> bool {
    let (start, body) = match pat.strip_prefix('^') {
        Some(rest) => (true, rest),
        None => (false, pat),
    };
    let (end, body) = strip_end_anchor(body);
    match unescape_literal(body) {
        Some(lit) => match (start, end) {
            (true, true) => hay == lit.as_str(),
            (true, false) => hay.starts_with(lit.as_str()),
            (false, true) => hay.ends_with(lit.as_str()),
            (false, false) => hay.contains(lit.as_str()),
        },
        None => hay.contains(pat.trim_start_matches('^').trim_end_matches('$')),
    }
}

/// Split a trailing, unescaped `$` anchor off `pat`.
fn strip_end_anchor(pat: &str) -> (bool, &str) {
    match pat.strip_suffix('$') {
        // An odd run of backslashes before `$` means the `$` is escaped.
        Some(rest) if rest.chars().rev().take_while(|&c| c == '\\').count() % 2 == 0 => {
            (true, rest)
        }
        _ => (false, pat),
    }
}

/// Decode `s` as a plain literal: `\x` (x punctuation) becomes `x`. Returns
/// `None` if `s` contains an unescaped metacharacter, an escaped letter or
/// digit (a class like `\d`), or ends in a dangling `\`.
fn unescape_literal(s: &str) -> Option<String> {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        match c {
            '\\' => match chars.next() {
                Some(escaped) if !escaped.is_ascii_alphanumeric() => out.push(escaped),
                _ => return None,
            },
            '.' | '^' | '$' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '|' => {
                return None
            }
            _ => out.push(c),
        }
    }
    Some(out)
}

/// Expand every `?`-optional char in a regex token, e.g. `woff2?` into
/// `["woff2", "woff"]` and `x?y?` into `["xy", "x", "y", ""]`. A leading
/// `?` has nothing to make optional and leaves the token verbatim.
///
/// The candidate count doubles per `?`; tokens come from configuration,
/// where only a handful of `?` is realistic.
fn expand_optional(tok: &str) -> Vec<String> {
    let q = match tok.find('?') {
        Some(q) if q > 0 => q,
        _ => return vec![tok.to_string()],
    };
    let with = &tok[..q];
    // Drop the whole last char, not one byte, so multi-byte chars can't split.
    let without = with.char_indices().last().map_or("", |(i, _)| &with[..i]);
    let tails = expand_optional(&tok[q + 1..]);
    let mut out = Vec::with_capacity(tails.len() * 2);
    out.extend(tails.iter().map(|t| format!("{with}{t}")));
    out.extend(tails.iter().map(|t| format!("{without}{t}")));
    out
}