use std::collections::HashMap;
use std::sync::Arc;
use crate::config::directives::{LocationConfig, ParsedConfig, ServerConfig};
use crate::config::LocationMatch;
use crate::handler::static_file::StaticFileHandler;
use crate::http::request::Request;
use crate::http::response::Response;
use crate::http::status::HttpStatusCode;
use crate::proxy::{self, UpstreamTarget};
/// Request processing pipeline: picks the server block and location for a
/// request from the parsed configuration and answers it with a `return`,
/// a proxied upstream response, or a static file.
pub struct PipelineProcessor {
    /// Parsed configuration, held by reference count.
    config: Arc<ParsedConfig>,
    /// Static-file handlers keyed by the lossy UTF-8 form of a server's
    /// `root` path; server blocks that share a root share one handler.
    static_handlers: HashMap<String, StaticFileHandler>,
}
impl PipelineProcessor {
    /// Build a processor over `config`, creating one `StaticFileHandler`
    /// per distinct server `root`.
    pub fn new(config: Arc<ParsedConfig>) -> Self {
        // Key handlers by document root, not server_name: several server
        // blocks can share a name (e.g. the :80 redirect block lists every
        // host) yet have different roots, and several can share a root.
        let mut static_handlers = HashMap::new();
        for server in &config.http.servers {
            let key = server.root.to_string_lossy().into_owned();
            static_handlers.entry(key).or_insert_with(|| {
                StaticFileHandler::new(
                    server.root.clone(),
                    server.index.clone(),
                    false,
                    true,
                    config.http.gzip.enabled,
                    config.http.gzip.types.clone(),
                    config.http.gzip.min_length,
                )
            });
        }
        Self {
            config,
            static_handlers,
        }
    }

    /// If this request is a connection upgrade (WebSocket) that routes to a
    /// `proxy_pass` location, return the upstream target so the connection
    /// layer can splice the sockets instead of buffering.
    ///
    /// Routing mirrors [`Self::process`]: the same server fallback is used,
    /// and a server- or location-level `return` takes priority over
    /// `proxy_pass`, so `None` is returned and the request is answered
    /// normally (e.g. by the HTTP->HTTPS redirect).
    pub fn upgrade_target(&self, request: &Request, port: u16) -> Option<UpstreamTarget> {
        if !proxy::is_upgrade_request(request) {
            return None;
        }
        let server = self.resolve_server(request, port)?;
        // Without these checks an upgrade request would bypass a `return`
        // that `process` honours and be spliced straight to the upstream.
        if server.return_directive.is_some() {
            return None;
        }
        let location = self.find_location(request, server)?;
        if location.return_directive.is_some() {
            return None;
        }
        let pass = location.proxy_pass.as_ref()?;
        proxy::parse_target(pass)
    }

    /// Process a request and produce a response. `scheme` is `"http"` or
    /// `"https"` depending on the listener that accepted the connection; it
    /// drives `X-Forwarded-Proto` and `$scheme` interpolation.
    pub async fn process(&self, request: &Request, scheme: &str, port: u16) -> Response {
        let server_config = match self.resolve_server(request, port) {
            Some(s) => s,
            None => return Response::internal_error(),
        };
        // Server-level `return` (HTTP->HTTPS redirect block) takes priority.
        if let Some((code, target)) = &server_config.return_directive {
            return build_return(*code, target.as_deref(), request, scheme);
        }
        if let Some(location) = self.find_location(request, server_config) {
            // Location-level `return`.
            if let Some((code, target)) = &location.return_directive {
                return build_return(*code, target.as_deref(), request, scheme);
            }
            // Reverse proxy.
            if let Some(pass) = &location.proxy_pass {
                return match proxy::parse_target(pass) {
                    Some(target) => proxy::forward(request, &target, scheme).await,
                    None => Response::bad_gateway(),
                };
            }
            // Static file, then decorate with add_header / expires.
            let mut resp = self.serve_static(server_config, request);
            apply_location_headers(&mut resp, location);
            return resp;
        }
        // No location matched: serve from the server root.
        let mut resp = self.serve_static(server_config, request);
        // Apply add_header/expires from a catch-all `location /` if present.
        // (A `location /` prefix matches every path starting with '/', so this
        // only applies to request paths that do not start with '/'.)
        if let Some(loc) = server_config
            .locations
            .iter()
            .find(|l| matches!(l.match_type, LocationMatch::Prefix) && l.path == "/")
        {
            apply_location_headers(&mut resp, loc);
        }
        resp
    }

    /// Server block that handles a request arriving on `port`, falling back to
    /// the first configured server when no block listens on that port.
    fn resolve_server(&self, request: &Request, port: u16) -> Option<&ServerConfig> {
        self.find_server(request, port)
            .or_else(|| self.config.http.servers.first())
    }

    /// Serve `request` with the static handler registered for `server.root`.
    fn serve_static(&self, server: &ServerConfig, request: &Request) -> Response {
        // `new` registers a handler for every configured server's root, so a
        // miss means `server` did not come from `self.config`. Answer 404
        // rather than serving from an arbitrary other root: HashMap iteration
        // order is unspecified, so "any handler" was nondeterministic.
        let key = server.root.to_string_lossy();
        match self.static_handlers.get(&*key) {
            Some(handler) => handler.handle(request),
            None => Response::not_found(),
        }
    }

    /// Normalise a `Host` header value for server_name matching: drop the
    /// port (keeping a bracketed IPv6 literal such as `[::1]` intact) and a
    /// single trailing dot. Case is handled by the caller's case-insensitive
    /// comparison.
    fn hostname_of(host: &str) -> &str {
        let name = match host.strip_prefix('[') {
            // "[::1]:8080" -> "[::1]"
            Some(rest) => rest.find(']').map_or(host, |end| &host[..end + 2]),
            None => host.split(':').next().unwrap_or(host),
        };
        name.strip_suffix('.').unwrap_or(name)
    }

    /// Pick the server block for a request on `port`: an exact server_name
    /// match (ASCII case-insensitive, as host names are), else the port's
    /// default server.
    fn find_server(&self, request: &Request, port: u16) -> Option<&ServerConfig> {
        let hostname = Self::hostname_of(request.host().unwrap_or(""));
        let servers = &self.config.http.servers;
        // Only servers that actually listen on the port this connection
        // arrived on are candidates — mirrors nginx, where a request on :80
        // never matches a :443-only server block.
        let listens_here = |s: &&ServerConfig| s.listen.iter().any(|l| l.port == port);
        // Exact server_name match on this port.
        if let Some(s) = servers
            .iter()
            .filter(listens_here)
            .find(|s| s.server_name.iter().any(|name| name.eq_ignore_ascii_case(hostname)))
        {
            return Some(s);
        }
        // Default server for this port (explicit `_`/empty, else the first).
        servers
            .iter()
            .filter(listens_here)
            .find(|s| s.server_name.iter().any(|n| n == "_" || n.is_empty()))
            .or_else(|| servers.iter().find(listens_here))
    }

    /// Select the location block for `request.path` within `server`.
    fn find_location<'a>(
        &self,
        request: &Request,
        server: &'a ServerConfig,
    ) -> Option<&'a LocationConfig> {
        let path = &request.path;
        // nginx precedence: exact (`=`) > regex (`~`/`~*`, first match) >
        // longest matching prefix.
        let mut exact: Option<&LocationConfig> = None;
        let mut regex: Option<&LocationConfig> = None;
        let mut prefix: Option<&LocationConfig> = None;
        let mut prefix_len = 0usize;
        for location in &server.locations {
            match location.match_type {
                LocationMatch::Exact => {
                    if path == &location.path {
                        exact = Some(location);
                    }
                }
                LocationMatch::Prefix => {
                    if path.starts_with(&location.path) && location.path.len() >= prefix_len {
                        prefix = Some(location);
                        prefix_len = location.path.len();
                    }
                }
                LocationMatch::Regex | LocationMatch::RegexNoCase => {
                    let ci = matches!(location.match_type, LocationMatch::RegexNoCase);
                    if regex.is_none() && regex_match(&location.path, path, ci) {
                        regex = Some(location);
                    }
                }
            }
        }
        exact.or(regex).or(prefix)
    }
}
/// Build the response for an nginx `return` directive.
///
/// A `code` that `HttpStatusCode::from_u16` does not recognise is replaced
/// by 301. The *effective* code then decides how `target` is used: for 3xx it
/// becomes the `Location` header, otherwise it is sent as the response body.
/// `$host`, `$request_uri`, `$scheme` and `$uri` in `target` are interpolated,
/// as in a typical `return 301 https://$host$request_uri;`.
fn build_return(code: u16, target: Option<&str>, request: &Request, scheme: &str) -> Response {
    // Keep `status` and `code` in sync so an unrecognised code cannot produce
    // a 301 status without a Location header.
    let (status, code) = HttpStatusCode::from_u16(code)
        .map(|status| (status, code))
        .unwrap_or((HttpStatusCode::MOVED_PERMANENTLY, 301));
    match target {
        Some(template) => {
            let url = interpolate_return_target(
                template,
                request.host().unwrap_or(""),
                request.uri(),
                scheme,
                &request.path,
            );
            if (300..400).contains(&code) {
                Response::new()
                    .status(status)
                    .header("Location", &url)
                    .body_str("")
            } else {
                // Non-3xx return with a second argument: use it as the body.
                Response::new().status(status).body_str(&url)
            }
        }
        None => Response::new().status(status).body_str(""),
    }
}

/// Expand nginx-style `$name` variables in `template` in a single pass.
///
/// A variable name is the longest run of ASCII alphanumerics/underscores
/// after `$`, as in nginx, so `$hostname` is not read as `$host` + "name".
/// Unknown names (and a bare `$`) are copied through unchanged. Substituted
/// values are never re-scanned, so request data that itself contains `$uri`
/// or `$scheme` is emitted verbatim instead of being expanded again.
fn interpolate_return_target(
    template: &str,
    host: &str,
    request_uri: &str,
    scheme: &str,
    uri: &str,
) -> String {
    let mut out = String::with_capacity(template.len() + host.len() + request_uri.len());
    let mut rest = template;
    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];
        let name_len = after
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(after.len());
        let name = &after[..name_len];
        match name {
            "host" => out.push_str(host),
            "request_uri" => out.push_str(request_uri),
            "scheme" => out.push_str(scheme),
            "uri" => out.push_str(uri),
            _ => {
                out.push('$');
                out.push_str(name);
            }
        }
        rest = &after[name_len..];
    }
    out.push_str(rest);
    out
}
/// Status codes for which nginx applies `add_header` and `expires` when the
/// `always` flag is not used.
const LOCATION_HEADER_STATUSES: [u16; 10] = [200, 201, 204, 206, 301, 302, 303, 304, 307, 308];

/// Apply a location's `add_header` and `expires` directives to a response.
///
/// Mirrors nginx: headers are only added for the codes in
/// [`LOCATION_HEADER_STATUSES`]. That list includes 206 Partial Content for
/// ranged static responses. Error responses such as 404 never pick up
/// caching headers.
fn apply_location_headers(resp: &mut Response, location: &LocationConfig) {
    let eligible = LOCATION_HEADER_STATUSES.iter().any(|&code| {
        HttpStatusCode::from_u16(code).map_or(false, |status| status == resp.status)
    });
    if !eligible {
        return;
    }
    if let Some(secs) = location.expires {
        resp.set_header("Cache-Control", &format!("max-age={secs}"));
    }
    // nginx `add_header` appends, so a directive can coexist with the
    // Cache-Control emitted by `expires` (both lines reach the client).
    for (name, value) in &location.add_headers {
        resp.headers.append(name.as_str(), value.as_str());
    }
}
/// Match the small subset of regex used by typical nginx `location ~` blocks.
///
/// Two shapes are recognised:
/// * Extension alternation `\.(ext1|ext2|...)$`, optionally written as a
///   non-capturing `(?:...)` group. Each alternative may contain `x?`
///   optional characters, e.g. `woff2?`. Anything before `\.` is ignored, so
///   `^/static/.*\.(css)$` behaves like `\.(css)$`.
/// * Anything else is treated as a literal. A leading `^` and an unescaped
///   trailing `$` are honoured as anchors. Escaped punctuation is unescaped,
///   so `\.php$` matches `/index.php` and `/\.ht` matches `/.htaccess`.
///   Other metacharacters are still compared literally.
///
/// With `case_insensitive` (`~*`) both sides are ASCII-lowercased.
fn regex_match(pattern: &str, path: &str, case_insensitive: bool) -> bool {
    let fold = |s: &str| {
        if case_insensitive {
            s.to_ascii_lowercase()
        } else {
            s.to_string()
        }
    };
    let hay = fold(path);

    if let Some(alternatives) = extension_alternatives(pattern) {
        return alternatives
            .split('|')
            .flat_map(expand_optional)
            .any(|ext| hay.ends_with(&format!(".{}", fold(ext.as_str()))));
    }

    // Literal fallback with anchor handling.
    let (anchor_start, rest) = match pattern.strip_prefix('^') {
        Some(rest) => (true, rest),
        None => (false, pattern),
    };
    let (anchor_end, body) = match rest.strip_suffix('$') {
        // An odd number of backslashes before `$` means it is escaped (`\$`),
        // i.e. a literal dollar rather than an end anchor.
        Some(body) if body.chars().rev().take_while(|&c| c == '\\').count() % 2 == 0 => {
            (true, body)
        }
        _ => (false, rest),
    };
    let unescaped = unescape_literal(body);
    let needle = fold(unescaped.as_str());
    match (anchor_start, anchor_end) {
        (true, true) => hay == needle,
        (true, false) => hay.starts_with(&needle),
        (false, true) => hay.ends_with(&needle),
        (false, false) => hay.contains(&needle),
    }
}

/// If `pattern` has the shape `...\.(alts)$` or `...\.(?:alts)$`, return the
/// `alts` text inside the group.
fn extension_alternatives(pattern: &str) -> Option<&str> {
    let open = pattern.find('(')?;
    // Look for the closing paren *after* the opening one; a `)` earlier in
    // the pattern must not produce an inverted (panicking) slice.
    let close = open + pattern[open..].find(')')?;
    if !pattern[..open].ends_with("\\.") || !pattern[close..].starts_with(")$") {
        return None;
    }
    let inner = &pattern[open + 1..close];
    Some(inner.strip_prefix("?:").unwrap_or(inner))
}

/// Remove regex escapes from escaped punctuation (`\.` -> `.`, `\\` -> `\`).
///
/// Escaped alphanumerics (`\d`, `\w`, ...) denote classes, not literals, so
/// they are kept as written. They will not match a real path.
fn unescape_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some(next) if !next.is_ascii_alphanumeric() => out.push(next),
            Some(next) => {
                out.push('\\');
                out.push(next);
            }
            // A lone trailing backslash is kept as-is.
            None => out.push('\\'),
        }
    }
    out
}

/// Expand `?`-optional characters in one alternative of an extension group,
/// e.g. `woff2?` -> `["woff2", "woff"]` and `jpe?g` -> `["jpeg", "jpg"]`.
///
/// Every `?` is expanded, so `a?b?` yields all four combinations. A `?` with
/// nothing before it is left untouched. The optional character may be any
/// `char`, including a multi-byte one.
fn expand_optional(tok: &str) -> Vec<String> {
    let q = match tok.find('?') {
        Some(q) if q > 0 => q,
        _ => return vec![tok.to_string()],
    };
    // Byte offset of the char the `?` applies to, found on a char boundary
    // (a plain `q - 1` would split a multi-byte character and panic).
    let opt_start = tok[..q].char_indices().next_back().map_or(0, |(i, _)| i);
    let with = &tok[..q];
    let without = &tok[..opt_start];
    let tails = expand_optional(&tok[q + 1..]);
    tails
        .iter()
        .map(|tail| format!("{with}{tail}"))
        .chain(tails.iter().map(|tail| format!("{without}{tail}")))
        .collect()
}