Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / config / directives.rs
use super::ast::{Block, Value};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;

/// Interpreted server configuration
///
/// The top-level result of [`interpret`]: main-context settings plus the
/// interpreted `events` and `http` blocks.
#[derive(Debug, Clone)]
pub struct ParsedConfig {
    /// `worker_processes`; `auto` or an absent directive resolves to the CPU
    /// count.
    pub worker_processes: usize,
    /// `error_log PATH [LEVEL]` as `(path, level)`.  The level defaults to
    /// `"error"`; an absent directive yields `logs/error.log` / `"error"`.
    pub error_log: (PathBuf, String),
    /// `pid PATH`; defaults to `logs/veld.pid`.
    pub pid_file: PathBuf,
    /// Settings from the `events { ... }` block.
    pub events: EventsConfig,
    /// Settings from the `http { ... }` block.
    pub http: HttpConfig,
}

/// Settings from the `events { ... }` block.
#[derive(Debug, Clone)]
pub struct EventsConfig {
    /// `worker_connections`; defaults to 1024.
    pub worker_connections: usize,
    /// `multi_accept on|off`; defaults to on.
    pub multi_accept: bool,
    /// Whether the epoll connection-processing method is in use (derived from
    /// the `use` directive).
    pub use_epoll: bool,
}

/// Settings from the `http { ... }` block.
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Every `server { ... }` block, in declaration order.
    pub servers: Vec<ServerConfig>,
    /// `upstream NAME { ... }` blocks keyed by NAME.
    pub upstreams: HashMap<String, UpstreamConfig>,
    /// `default_type`; defaults to `application/octet-stream`.
    pub default_type: String,
    /// `sendfile on|off`; defaults to on.
    pub sendfile: bool,
    /// `tcp_nopush on|off`; defaults to on.
    pub tcp_nopush: bool,
    /// `tcp_nodelay on|off`; defaults to on.
    pub tcp_nodelay: bool,
    /// `keepalive_timeout`; defaults to 65 seconds.
    pub keepalive_timeout: Duration,
    /// `client_max_body_size`; defaults to 1048576 (presumably bytes).
    pub client_max_body_size: u64,
    /// Http-level gzip settings.
    pub gzip: GzipConfig,
    /// `access_log PATH [FORMAT]` as `(path, format name)`; the format name
    /// defaults to `"combined"`.  `None` when the directive is absent.
    pub access_log: Option<(PathBuf, String)>,
    /// Named `log_format` definitions (name -> format string).
    pub log_formats: HashMap<String, String>,
}

/// One `server { ... }` block.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Every `listen` directive of the block.
    pub listen: Vec<ListenAddr>,
    /// `server_name` values split on whitespace; `[""]` when the directive
    /// is absent.
    pub server_name: Vec<String>,
    /// `root`; defaults to `html`.
    pub root: PathBuf,
    /// `index` file names; defaults to `["index.html"]`.
    pub index: Vec<String>,
    /// Nested `location` blocks, in declaration order.
    pub locations: Vec<LocationConfig>,
    /// TLS settings; present only when `ssl_certificate` is configured.
    pub ssl: Option<SslConfig>,
    /// `error_page` mappings from status code to page.
    pub error_pages: HashMap<u16, String>,
    /// A server-level `return <code> [target];` directive.  Used for the
    /// canonical HTTP->HTTPS redirect server block.  Applied to every
    /// request before location matching.
    pub return_directive: Option<(u16, Option<String>)>,
}

/// One `listen` directive.
#[derive(Debug, Clone)]
pub struct ListenAddr {
    /// Host part; `0.0.0.0` when only a port is given.
    pub addr: String,
    /// Port; 80 when none, or an unparsable one, is given.
    pub port: u16,
    /// The `ssl` flag appears among the directive's arguments.
    pub ssl: bool,
    /// The `default_server` flag appears among the directive's arguments.
    pub default_server: bool,
}

/// One `location SPEC { ... }` block.  Unset optional directives are `None`
/// or empty so the enclosing server's settings can apply.
#[derive(Debug, Clone)]
pub struct LocationConfig {
    /// The location path or pattern with any match modifier stripped.
    pub path: String,
    /// How `path` is matched against the request URI.
    pub match_type: LocationMatch,
    /// Location-level `root`, if set.
    pub root: Option<PathBuf>,
    /// `alias`, if set.
    pub alias: Option<PathBuf>,
    /// Location-level `index` file names; empty when unset.
    pub index: Vec<String>,
    /// `proxy_pass` target, if set.
    pub proxy_pass: Option<String>,
    /// `try_files` candidates; empty when unset.
    pub try_files: Vec<String>,
    /// `autoindex on|off`; defaults to off.
    pub autoindex: bool,
    /// Location-level `gzip on|off` override, if set.
    pub gzip: Option<bool>,
    /// The `add_header` pairs collected into a map.
    pub headers: HashMap<String, String>,
    /// `add_header NAME VALUE;` directives, in declaration order (the map
    /// above is kept for backward compatibility but loses ordering).
    pub add_headers: Vec<(String, String)>,
    /// `expires <time>;` resolved to **seconds**, emitted as
    /// `Cache-Control: max-age=<n>` on static responses.
    pub expires: Option<u64>,
    /// `rewrite` rules.
    pub rewrite: Vec<RewriteRule>,
    /// Location-level `return <code> [body|target];` as `(code, argument)`.
    pub return_directive: Option<(u16, Option<String>)>,
}

/// Match modifier of a `location` block.
#[derive(Debug, Clone)]
pub enum LocationMatch {
    Prefix,      // location /path
    Exact,       // location = /path
    Regex,       // location ~ pattern (case-sensitive)
    RegexNoCase, // location ~* pattern (case-insensitive)
}

/// TLS settings of a server block.
#[derive(Debug, Clone)]
pub struct SslConfig {
    /// `ssl_certificate` path.
    pub certificate: PathBuf,
    /// `ssl_certificate_key` path; empty when that directive is missing.
    pub certificate_key: PathBuf,
    /// Enabled TLS protocol names (e.g. `TLSv1.2`).
    pub protocols: Vec<String>,
}

/// One `upstream NAME { ... }` block.
#[derive(Debug, Clone)]
pub struct UpstreamConfig {
    /// The upstream NAME (also its key in `HttpConfig::upstreams`).
    pub name: String,
    /// The block's `server` entries.
    pub servers: Vec<UpstreamServer>,
    /// `algorithm`; defaults to `"round_robin"`.
    pub algorithm: String,
}

/// One `server ADDRESS [PARAMS];` entry inside an `upstream` block.
#[derive(Debug, Clone)]
pub struct UpstreamServer {
    /// Host part of ADDRESS.
    pub addr: String,
    /// Port part of ADDRESS; 80 when absent or unparsable.
    pub port: u16,
    /// Relative weight for load balancing.
    pub weight: u32,
    /// Failure count associated with `fail_timeout`.
    pub max_fails: u32,
    /// Window associated with `max_fails`.
    pub fail_timeout: Duration,
    /// Whether this is a backup server.
    pub backup: bool,
}

/// Gzip compression settings.
#[derive(Debug, Clone)]
pub struct GzipConfig {
    /// `gzip on|off`; defaults to off.
    pub enabled: bool,
    /// `gzip_min_length`; defaults to 1024.
    pub min_length: usize,
    /// `gzip_types` MIME types, split on whitespace.
    pub types: Vec<String>,
    /// Compression level.
    pub comp_level: u32,
}

/// One `rewrite PATTERN REPLACEMENT [FLAG];` rule.
#[derive(Debug, Clone)]
pub struct RewriteRule {
    /// Regex matched against the request URI.
    pub pattern: String,
    /// Replacement URI.
    pub replacement: String,
    /// NOTE(review): presumably set for nginx's `redirect`/`permanent`
    /// flags — confirm against the consumer.
    pub redirect: bool,
    /// NOTE(review): presumably set for nginx's `last` flag — confirm.
    pub last: bool,
}

/// Interpret a parsed config AST into structured configuration.
///
/// Reads the main-context directives (`worker_processes`, `error_log`,
/// `pid`) and delegates the `events` and `http` blocks to
/// [`interpret_events`] and [`interpret_http`].  Missing directives fall back
/// to built-in defaults; no error is currently produced.
pub fn interpret(block: &Block) -> Result<ParsedConfig, super::parser::ConfigError> {
    let worker_processes = match block.get("worker_processes") {
        Some(d) => {
            // A numeric count (`worker_processes 4;`) lexes as a Number, for
            // which first_arg_str() yields "", so read it via to_string_lossy.
            let val = d
                .first_arg()
                .map(|v| v.to_string_lossy())
                .unwrap_or_default();
            if val == "auto" {
                num_cpus::get()
            } else {
                // Unparsable or zero counts fall back to a single worker.
                val.parse::<usize>()
                    .ok()
                    .filter(|&n| n > 0)
                    .unwrap_or(1)
            }
        }
        None => num_cpus::get(),
    };

    let error_log = match block.get("error_log") {
        Some(d) => {
            let path = PathBuf::from(d.first_arg_str());
            let level = d
                .args
                .get(1)
                .map(|v| v.as_str().to_string())
                .unwrap_or_else(|| "error".to_string());
            (path, level)
        }
        None => (PathBuf::from("logs/error.log"), "error".to_string()),
    };

    let pid_file = block
        .get_str("pid")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("logs/veld.pid"));

    let events = interpret_events(block.get("events").and_then(|d| d.block.as_ref()));

    let http_block = block.get("http").and_then(|d| d.block.as_ref());
    let http = interpret_http(http_block);

    Ok(ParsedConfig {
        worker_processes,
        error_log,
        pid_file,
        events,
        http,
    })
}

/// Build the [`EventsConfig`] from the optional `events { ... }` block.
///
/// Defaults (for the whole block when absent, per directive otherwise):
/// 1024 `worker_connections`, `multi_accept` on, epoll enabled.
fn interpret_events(block: Option<&Block>) -> EventsConfig {
    let block = match block {
        Some(b) => b,
        None => {
            return EventsConfig {
                worker_connections: 1024,
                multi_accept: true,
                use_epoll: true,
            }
        }
    };

    EventsConfig {
        worker_connections: block.get_u64("worker_connections").unwrap_or(1024) as usize,
        multi_accept: block.get_bool("multi_accept").unwrap_or(true),
        // `use <method>;` names the connection-processing method; epoll stays
        // enabled unless a different method is named explicitly.
        use_epoll: block
            .get_str("use")
            .map_or(true, |method| method == "epoll"),
    }
}

fn interpret_http(block: Option<&Block>) -> HttpConfig {
    let block = match block {
        Some(b) => b,
        None => return HttpConfig::default(),
    };

    let server_blocks = block.get_blocks("server");
    let servers = server_blocks.into_iter().map(interpret_server).collect();

    let mut upstreams = HashMap::new();
    for directive in block.get_all("upstream") {
        if let Some(ref upstream_block) = directive.block {
            let name = directive.first_arg_str().to_string();
            let algorithm = upstream_block
                .get_str("algorithm")
                .unwrap_or("round_robin")
                .to_string();
            let servers = upstream_block
                .get_all("server")
                .iter()
                .map(|d| {
                    let addr_str = d.first_arg_str();
                    let (host, port) = parse_addr_port(addr_str);
                    UpstreamServer {
                        addr: host,
                        port,
                        weight: 1,
                        max_fails: 1,
                        fail_timeout: Duration::from_secs(10),
                        backup: false,
                    }
                })
                .collect();
            upstreams.insert(
                name,
                UpstreamConfig {
                    name: directive.first_arg_str().to_string(),
                    servers,
                    algorithm,
                },
            );
        }
    }

    HttpConfig {
        servers,
        upstreams,
        default_type: block
            .get_str("default_type")
            .unwrap_or("application/octet-stream")
            .to_string(),
        sendfile: block.get_bool("sendfile").unwrap_or(true),
        tcp_nopush: block.get_bool("tcp_nopush").unwrap_or(true),
        tcp_nodelay: block.get_bool("tcp_nodelay").unwrap_or(true),
        keepalive_timeout: Duration::from_secs(block.get_u64("keepalive_timeout").unwrap_or(65)),
        client_max_body_size: block.get_u64("client_max_body_size").unwrap_or(1048576),
        gzip: interpret_gzip(block),
        access_log: block.get("access_log").map(|d| {
            (
                PathBuf::from(d.first_arg_str()),
                d.args
                    .get(1)
                    .map(|v| v.as_str().to_string())
                    .unwrap_or_else(|| "combined".to_string()),
            )
        }),
        log_formats: HashMap::new(),
    }
}

/// Build a [`ServerConfig`] from one `server { ... }` block.
///
/// Reads `listen`, `server_name`, `root`, `index`, the nested `location`
/// blocks, the `ssl_*` directives, `error_page` and a server-level `return`.
/// Absent directives fall back to `root html`, `index index.html` and an
/// empty server name.
fn interpret_server(block: &Block) -> ServerConfig {
    let listen = block
        .get_all("listen")
        .iter()
        .map(|d| {
            // A bare port may lex as a Number, so read via to_string_lossy.
            let addr_str = d
                .first_arg()
                .map(|v| v.to_string_lossy())
                .unwrap_or_default();
            let (host, port) = parse_listen_addr(&addr_str);
            ListenAddr {
                addr: host,
                port,
                ssl: d.args.iter().any(|a| a.as_str() == "ssl"),
                default_server: d.args.iter().any(|a| a.as_str() == "default_server"),
            }
        })
        .collect();

    let server_name = block
        .get_str("server_name")
        .map(|s| s.split_whitespace().map(String::from).collect())
        .unwrap_or_else(|| vec!["".to_string()]);

    let root = block
        .get_str("root")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("html"));
    let index = block
        .get_str("index")
        .map(|s| s.split_whitespace().map(String::from).collect())
        .unwrap_or_else(|| vec!["index.html".to_string()]);

    let locations = block
        .directives
        .iter()
        .filter(|d| d.name == "location")
        .filter_map(|d| d.block.as_ref().map(|b| (d, b)))
        .map(|(d, b)| {
            // A location spec can be one arg (`/api/`) or two (`~* PATTERN`,
            // `= /exact`).  Join them so the modifier and pattern survive.
            let spec = d
                .args
                .iter()
                .map(|v| v.to_string_lossy())
                .collect::<Vec<_>>()
                .join(" ");
            interpret_location(&spec, b)
        })
        .collect();

    let ssl = block.get("ssl_certificate").map(|d| SslConfig {
        certificate: PathBuf::from(d.first_arg_str()),
        certificate_key: PathBuf::from(block.get_str("ssl_certificate_key").unwrap_or("")),
        // `ssl_protocols TLSv1.2 TLSv1.3;`; TLS 1.2 + 1.3 when unset or empty.
        protocols: block
            .get("ssl_protocols")
            .map(|p| {
                p.args
                    .iter()
                    .map(|v| v.to_string_lossy())
                    .collect::<Vec<_>>()
            })
            .filter(|p| !p.is_empty())
            .unwrap_or_else(|| vec!["TLSv1.2".to_string(), "TLSv1.3".to_string()]),
    });

    // `error_page CODE... URI;` — every numeric argument before the final URI
    // is a status code mapped to that URI; an `=RESPONSE` override does not
    // parse as a code and is skipped.
    let mut error_pages = HashMap::new();
    for d in block.get_all("error_page") {
        let args: Vec<String> = d.args.iter().map(|v| v.to_string_lossy()).collect();
        if let Some((uri, codes)) = args.split_last() {
            for code in codes.iter().filter_map(|c| c.parse::<u16>().ok()) {
                error_pages.insert(code, uri.clone());
            }
        }
    }

    // Server-level `return <code> [target];` or `return URL;` (used by the
    // HTTP->HTTPS redirect block).  The status code lexes as a Number, so
    // read it via to_string_lossy (first_arg_str() returns "" for
    // non-string values).
    let return_directive = block.get("return").map(|d| {
        let first = d
            .first_arg()
            .map(|v| v.to_string_lossy())
            .unwrap_or_default();
        let target = d.args.get(1).map(|v| v.to_string_lossy());
        let is_redirect_url =
            |s: &str| ["http://", "https://", "$scheme"].iter().any(|p| s.starts_with(*p));
        match first.parse::<u16>() {
            Ok(code) => (code, target),
            // nginx's single-argument form `return URL;` is a 302 to URL.
            Err(_) if is_redirect_url(&first) => (302, Some(first)),
            Err(_) => (302, target),
        }
    });

    ServerConfig {
        listen,
        server_name,
        root,
        index,
        locations,
        ssl,
        error_pages,
        return_directive,
    }
}

/// Build a [`LocationConfig`] from a `location SPEC { ... }` block.
///
/// `path` is the location spec with its arguments joined by spaces, e.g.
/// `/api/`, `= /exact`, `~* \.png$` or `^~ /static/`.  A leading modifier
/// selects the [`LocationMatch`] and is stripped from the stored path.
fn interpret_location(path: &str, block: &Block) -> LocationConfig {
    let (match_type, path) = if let Some(rest) = path.strip_prefix("^~") {
        // `^~` is a prefix match that, in nginx, also suppresses regex
        // checks.  There is no dedicated variant, so keep prefix semantics
        // and strip the modifier so the path can match at all.
        (LocationMatch::Prefix, rest.trim().to_string())
    } else if let Some(rest) = path.strip_prefix('=') {
        (LocationMatch::Exact, rest.trim().to_string())
    } else if let Some(rest) = path.strip_prefix("~*") {
        (LocationMatch::RegexNoCase, rest.trim().to_string())
    } else if let Some(rest) = path.strip_prefix('~') {
        (LocationMatch::Regex, rest.trim().to_string())
    } else {
        (LocationMatch::Prefix, path.to_string())
    };

    // `add_header NAME VALUE;` — collect in declaration order; a missing
    // value becomes an empty string.
    let add_headers: Vec<(String, String)> = block
        .get_all("add_header")
        .iter()
        .filter_map(|d| {
            let name = d.args.first().map(|v| v.to_string_lossy())?;
            let value = d
                .args
                .get(1)
                .map(|v| v.to_string_lossy())
                .unwrap_or_default();
            Some((name, value))
        })
        .collect();

    // `return CODE [TEXT|URL];` or `return URL;`.  The code lexes as a
    // Number, so read it via to_string_lossy.
    let return_directive = block.get("return").map(|d| {
        let first = d
            .first_arg()
            .map(|v| v.to_string_lossy())
            .unwrap_or_default();
        let body = d.args.get(1).map(|v| v.to_string_lossy());
        let is_redirect_url =
            |s: &str| ["http://", "https://", "$scheme"].iter().any(|p| s.starts_with(*p));
        match first.parse::<u16>() {
            Ok(code) => (code, body),
            // nginx's single-argument form `return URL;` is a 302 to URL.
            Err(_) if is_redirect_url(&first) => (302, Some(first)),
            Err(_) => (200, body),
        }
    });

    LocationConfig {
        path,
        match_type,
        root: block.get_str("root").map(PathBuf::from),
        alias: block.get_str("alias").map(PathBuf::from),
        index: block
            .get_str("index")
            .map(|s| s.split_whitespace().map(String::from).collect())
            .unwrap_or_default(),
        proxy_pass: block.get_str("proxy_pass").map(String::from),
        try_files: block
            .get_str("try_files")
            .map(|s| s.split_whitespace().map(String::from).collect())
            .unwrap_or_default(),
        autoindex: block.get_bool("autoindex").unwrap_or(false),
        gzip: block.get_bool("gzip"),
        headers: add_headers.iter().cloned().collect(),
        add_headers,
        expires: block
            .get("expires")
            .and_then(|d| d.first_arg())
            .and_then(expires_to_secs),
        rewrite: Vec::new(),
        return_directive,
    }
}

/// Convert the argument of `expires` into a max-age in seconds.
///
/// * `Time` values arrive from the lexer in milliseconds and are truncated
///   to whole seconds (e.g. `7d`).
/// * A plain `Number` counts seconds; negative numbers clamp to 0.
/// * A `Size` value is returned unchanged.  NOTE(review): confirm how the
///   lexer tokenizes an ambiguous unit such as `30m` for this directive.
/// * The keyword `max` means one year.
///
/// Returns `None` for anything else (`off`, `epoch`, unknown words), meaning
/// no max-age is emitted.
fn expires_to_secs(v: &Value) -> Option<u64> {
    const ONE_YEAR_SECS: u64 = 31_536_000;
    let secs = match *v {
        Value::Time(millis) => millis / 1000,
        Value::Number(n) => n.max(0) as u64,
        Value::Size(n) => n,
        Value::String(ref word) if word.as_str() == "max" => ONE_YEAR_SECS,
        _ => return None,
    };
    Some(secs)
}

/// Build the [`GzipConfig`] from the `gzip*` directives of `block`.
///
/// Defaults: gzip off, `gzip_min_length` 1024, no `gzip_types`, and
/// compression level 1.
fn interpret_gzip(block: &Block) -> GzipConfig {
    GzipConfig {
        enabled: block.get_bool("gzip").unwrap_or(false),
        min_length: block.get_u64("gzip_min_length").unwrap_or(1024) as usize,
        types: block
            .get_str("gzip_types")
            .map(|s| s.split_whitespace().map(String::from).collect())
            .unwrap_or_default(),
        // `gzip_comp_level N;` — nginx accepts 1..=9; clamp out-of-range
        // values instead of passing them on to the compressor.
        comp_level: block
            .get_u64("gzip_comp_level")
            .map_or(1, |n| n.clamp(1, 9) as u32),
    }
}

/// Split `host:port` into its parts, defaulting the port to 80.
///
/// A bracketed IPv6 literal keeps its brackets in the host: `[::1]:8080`
/// gives `("[::1]", 8080)` and `[::1]` alone gives `("[::1]", 80)`.  An
/// unparsable port also falls back to 80.
fn parse_addr_port(s: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 80;
    // Bracketed IPv6: the colons inside the brackets are not a port separator.
    if s.starts_with('[') {
        if let Some(close) = s.find(']') {
            let (host, rest) = s.split_at(close + 1);
            let port = rest
                .strip_prefix(':')
                .and_then(|p| p.parse().ok())
                .unwrap_or(DEFAULT_PORT);
            return (host.to_string(), port);
        }
    }
    match s.rfind(':') {
        Some(colon) => (
            s[..colon].to_string(),
            s[colon + 1..].parse().unwrap_or(DEFAULT_PORT),
        ),
        None => (s.to_string(), DEFAULT_PORT),
    }
}

/// Interpret a `listen` address: `PORT`, `HOST`, `HOST:PORT` or
/// `[IPv6]:PORT`.
///
/// A bare port binds all IPv4 interfaces (`0.0.0.0`).  The nginx wildcard
/// host `*` (as in `listen *:8000;`) is normalized to `0.0.0.0` too, since
/// `*` is not a bindable address.  A host without a port gets port 80.
fn parse_listen_addr(s: &str) -> (String, u16) {
    let (host, port) = if s.contains(':') {
        parse_addr_port(s)
    } else if let Ok(port) = s.parse::<u16>() {
        ("0.0.0.0".to_string(), port)
    } else {
        (s.to_string(), 80)
    };
    if host == "*" {
        ("0.0.0.0".to_string(), port)
    } else {
        (host, port)
    }
}

impl Default for HttpConfig {
    fn default() -> Self {
        Self {
            servers: Vec::new(),
            upstreams: HashMap::new(),
            default_type: "application/octet-stream".to_string(),
            sendfile: true,
            tcp_nopush: true,
            tcp_nodelay: true,
            keepalive_timeout: Duration::from_secs(65),
            client_max_body_size: 1048576,
            gzip: GzipConfig {
                enabled: false,
                min_length: 1024,
                types: Vec::new(),
                comp_level: 1,
            },
            access_log: None,
            log_formats: HashMap::new(),
        }
    }
}