Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / tls / mod.rs
pub mod config;

pub use config::{TlsConfig, TlsError};

use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;

use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;

/// One loaded certificate chain + private key, together with the `server_name`s
/// it should be served for.
struct NamedCert {
    /// Host names this certificate answers for, lowercased by `TlsManager::add_cert`;
    /// entries starting with `*.` are treated as wildcards when the resolver is built.
    names: Vec<String>,
    /// The ready-to-serve chain and signing key handed to rustls.
    certified: Arc<CertifiedKey>,
}

/// Certificate resolver keyed on the TLS Server Name Indication (SNI).
///
/// Like nginx serving several HTTPS virtual hosts from one `:443` listener, each
/// virtual host presents its own certificate, picked from the SNI host name in the
/// ClientHello.  Lookup tries an exact name first, then the wildcard entries
/// (`*.example.com`).  It falls back to `default` when the client sends no SNI or
/// a name that nothing matches.
#[derive(Debug)]
pub struct SniResolver {
    /// Exact (lowercased) host name -> certificate.
    by_name: HashMap<String, Arc<CertifiedKey>>,
    /// `(suffix, certificate)` pairs for wildcard names, in registration order.
    /// The suffix keeps its leading dot, e.g. `*.example.com` is stored as `.example.com`.
    wildcards: Vec<(String, Arc<CertifiedKey>)>,
    /// Certificate used when no SNI is sent or nothing matches: the first one configured.
    default: Arc<CertifiedKey>,
}

/// Debug output for [`NamedCert`] that shows only the `names`.
///
/// The `certified` field (certificate chain + signing key) is not printed.
/// `finish_non_exhaustive` renders it as a trailing `..` so readers can see that
/// a field was left out.
impl std::fmt::Debug for NamedCert {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NamedCert")
            .field("names", &self.names)
            .finish_non_exhaustive()
    }
}

impl ResolvesServerCert for SniResolver {
    /// Choose the certificate to present for one handshake.
    ///
    /// The SNI host name (ASCII-lowercased) is matched in this order:
    ///
    /// 1. an exact entry in `by_name`;
    /// 2. the most specific wildcard, i.e. the longest stored suffix the name ends with.
    ///    So `*.api.example.com` wins over `*.example.com` for `v1.api.example.com`
    ///    whatever order they were configured in (nginx likewise prefers the longest
    ///    matching wildcard name).  Equally long suffixes keep registration order;
    /// 3. `default`, which is also used when the client sends no SNI at all.
    ///
    /// Wildcard suffixes are stored with their leading dot (`.example.com`).
    /// A wildcard therefore also covers deeper names (`a.b.example.com`), but not
    /// the bare parent domain `example.com`.
    ///
    /// Always returns `Some`, so the handshake is never refused here.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        if let Some(sni) = client_hello.server_name() {
            let sni = sni.to_ascii_lowercase();
            if let Some(ck) = self.by_name.get(&sni) {
                return Some(Arc::clone(ck));
            }
            let best_wildcard = self
                .wildcards
                .iter()
                .filter(|(suffix, _)| sni.ends_with(suffix.as_str()))
                // `min_by_key` returns the first of several equal minima, so among
                // equally specific suffixes the earliest registration wins.
                .min_by_key(|(suffix, _)| std::cmp::Reverse(suffix.len()));
            if let Some((_, ck)) = best_wildcard {
                return Some(Arc::clone(ck));
            }
        }
        Some(Arc::clone(&self.default))
    }
}

/// Collects per-`server_name` certificates and turns them into one SNI-aware
/// [`TlsAcceptor`] that serves every HTTPS virtual host.
pub struct TlsManager {
    /// Registered certificates, in registration order.  The first entry becomes the
    /// fallback certificate for clients with no/unknown SNI.
    certs: Vec<NamedCert>,
}

impl TlsManager {
    /// Create a manager with no certificates registered.
    pub fn new() -> Self {
        Self { certs: Vec::new() }
    }

    /// Register a certificate/key for one or more server names.
    ///
    /// `names` are stored ASCII-lowercased.  Entries of the form `*.example.com`
    /// become wildcard entries in the resolver built by [`Self::build_acceptor`].
    /// The first certificate ever registered is the fallback for clients that send
    /// no (or an unknown) SNI.
    ///
    /// # Errors
    ///
    /// Returns the [`TlsError`] produced while loading the PEM chain at
    /// `cert_path` or the private key at `key_path`.
    pub fn add_cert(
        &mut self,
        names: &[String],
        cert_path: &Path,
        key_path: &Path,
    ) -> Result<(), TlsError> {
        let certified = build_certified_key(cert_path, key_path)?;
        self.certs.push(NamedCert {
            names: names.iter().map(|n| n.to_ascii_lowercase()).collect(),
            certified: Arc::new(certified),
        });
        Ok(())
    }

    /// Whether at least one certificate has been registered.
    pub fn has_certs(&self) -> bool {
        !self.certs.is_empty()
    }

    /// Build the SNI resolver and wrap it in a `TlsAcceptor`.
    ///
    /// If the same exact name is registered by more than one certificate, the
    /// first registration wins.  nginx likewise keeps the first `server` block and
    /// ignores later conflicting `server_name`s; this is also consistent with the
    /// first certificate being the default.  Every wildcard entry is kept; the
    /// resolver decides which one is served.
    ///
    /// # Errors
    ///
    /// [`TlsError::ConfigError`] if no certificate has been registered.
    pub fn build_acceptor(&self) -> Result<TlsAcceptor, TlsError> {
        let default = self
            .certs
            .first()
            .map(|c| Arc::clone(&c.certified))
            .ok_or_else(|| TlsError::ConfigError("no TLS certificates configured".to_string()))?;

        let mut by_name = HashMap::new();
        let mut wildcards = Vec::new();
        for nc in &self.certs {
            for name in &nc.names {
                if let Some(rest) = name.strip_prefix("*.") {
                    // Store the matchable suffix including the dot, e.g.
                    // `*.example.com` -> `.example.com`.
                    wildcards.push((format!(".{rest}"), Arc::clone(&nc.certified)));
                } else {
                    // Keep the first registration for a duplicated name rather than
                    // letting a later certificate silently overwrite it.
                    by_name
                        .entry(name.clone())
                        .or_insert_with(|| Arc::clone(&nc.certified));
                }
            }
        }

        let resolver = SniResolver {
            by_name,
            wildcards,
            default,
        };

        // NOTE(review): `ServerConfig::builder()` uses the process-default rustls
        // `CryptoProvider`, while `build_certified_key` creates signing keys with the
        // `ring` provider explicitly.  In rustls 0.23 (if that is the version in use),
        // `builder()` panics when several provider features are enabled and no default
        // was installed.  Confirm the startup code installs one, or switch to
        // `builder_with_provider`.
        let mut config = ServerConfig::builder()
            .with_no_client_auth()
            .with_cert_resolver(Arc::new(resolver));

        // Only advertise HTTP/1.1: this server speaks HTTP/1.1 on the wire,
        // and advertising `h2` would let a browser negotiate a protocol we
        // do not implement.
        config.alpn_protocols = vec![b"http/1.1".to_vec()];

        Ok(TlsAcceptor::from(Arc::new(config)))
    }
}

impl Default for TlsManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Read the PEM certificate chain at `cert_path` and the PEM private key at
/// `key_path`, and combine them into a rustls [`CertifiedKey`] (what the SNI
/// resolver returns for each connection).
///
/// # Errors
///
/// - whatever `load_certs` / `load_key` report for the two files;
/// - [`TlsError::InvalidPrivateKey`] when the `ring` provider cannot turn the
///   loaded key into a signing key.
fn build_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
    let chain = load_certs(cert_path)?;
    let private_key = load_key(key_path)?;
    match rustls::crypto::ring::sign::any_supported_type(&private_key) {
        Ok(signer) => Ok(CertifiedKey::new(chain, signer)),
        Err(err) => Err(TlsError::InvalidPrivateKey {
            path: key_path.to_path_buf(),
            reason: err.to_string(),
        }),
    }
}

fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
    let file = File::open(path).map_err(|e| TlsError::CertReadError {
        path: path.to_path_buf(),
        source: e,
    })?;
    let mut reader = BufReader::new(file);
    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| TlsError::CertReadError {
            path: path.to_path_buf(),
            source: e,
        })?;
    if certs.is_empty() {
        return Err(TlsError::NoCertificatesFound {
            path: path.to_path_buf(),
        });
    }
    Ok(certs)
}

fn load_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
    // PKCS#8 first, then PKCS#1 (RSA), then SEC1 (EC).
    let open = || -> Result<BufReader<File>, TlsError> {
        File::open(path)
            .map(BufReader::new)
            .map_err(|e| TlsError::KeyReadError {
                path: path.to_path_buf(),
                source: e,
            })
    };

    if let Ok(Some(key)) = rustls_pemfile::pkcs8_private_keys(&mut open()?)
        .next()
        .transpose()
    {
        return Ok(key.into());
    }
    if let Ok(Some(key)) = rustls_pemfile::rsa_private_keys(&mut open()?)
        .next()
        .transpose()
    {
        return Ok(key.into());
    }
    if let Ok(Some(key)) = rustls_pemfile::ec_private_keys(&mut open()?)
        .next()
        .transpose()
    {
        return Ok(key.into());
    }
    Err(TlsError::NoPrivateKeyFound {
        path: path.to_path_buf(),
    })
}