Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / tls / config.rs
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};

use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;

/// Errors produced while loading TLS certificates/keys or building a rustls
/// server configuration.
#[derive(Debug)]
pub enum TlsError {
    /// Opening or reading the certificate file failed.
    CertReadError {
        /// Certificate file that was being read.
        path: PathBuf,
        /// The underlying I/O error (also exposed via `Error::source`).
        source: std::io::Error,
    },
    /// Opening or reading the private key file failed.
    KeyReadError {
        /// Key file that was being read.
        path: PathBuf,
        /// The underlying I/O error (also exposed via `Error::source`).
        source: std::io::Error,
    },
    /// The certificate file contained no certificates.
    NoCertificatesFound {
        /// Certificate file that was searched.
        path: PathBuf,
    },
    /// The key file contained no private key.
    NoPrivateKeyFound {
        /// Key file that was searched.
        path: PathBuf,
    },
    /// A private key in `path` was rejected; `reason` describes why.
    InvalidPrivateKey {
        /// Key file holding the rejected key.
        path: PathBuf,
        /// Human-readable explanation.
        reason: String,
    },
    /// Any other failure while assembling the TLS configuration, carried as text.
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertReadError { path, source } => {
                write!(f, "failed to read cert `{}`: {}", path.display(), source)
            }
            TlsError::KeyReadError { path, source } => {
                write!(f, "failed to read key `{}`: {}", path.display(), source)
            }
            TlsError::NoCertificatesFound { path } => {
                write!(f, "no certificates found in `{}`", path.display())
            }
            TlsError::NoPrivateKeyFound { path } => {
                write!(f, "no private key found in `{}`", path.display())
            }
            TlsError::InvalidPrivateKey { path, reason } => {
                write!(f, "invalid key in `{}`: {}", path.display(), reason)
            }
            TlsError::ConfigError(msg) => write!(f, "TLS config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Returns the wrapped `io::Error` for read failures so that generic
    /// error-chain walkers and `downcast_ref::<io::Error>()` can reach the
    /// root cause; every other variant has no underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            TlsError::CertReadError { source, .. } | TlsError::KeyReadError { source, .. } => {
                Some(source)
            }
            TlsError::NoCertificatesFound { .. }
            | TlsError::NoPrivateKeyFound { .. }
            | TlsError::InvalidPrivateKey { .. }
            | TlsError::ConfigError(_) => None,
        }
    }
}

/// Settings needed to terminate TLS for one server.
///
/// Construct with [`TlsConfig::new`] to get the default protocol list, then
/// call [`TlsConfig::build_server_config`] to obtain a rustls `ServerConfig`.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// PEM file holding the certificate chain.
    pub cert_path: PathBuf,
    /// PEM file holding the private key.
    pub key_path: PathBuf,
    /// Protocol version names such as `"TLSv1.2"` and `"TLSv1.3"` (the values
    /// `new` fills in by default).
    pub protocols: Vec<String>,
    /// Whether the server's cipher-suite preference should win over the client's.
    pub prefer_server_ciphers: bool,
}

impl TlsConfig {
    pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
            protocols: vec!["TLSv1.2".to_string(), "TLSv1.3".to_string()],
            prefer_server_ciphers: true,
        }
    }

    /// Build a rustls ServerConfig from this TLS configuration
    pub fn build_server_config(&self) -> Result<ServerConfig, TlsError> {
        let certs = self.load_certs(&self.cert_path)?;
        let key = self.load_key(&self.key_path)?;

        let mut config = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .map_err(|e| TlsError::ConfigError(e.to_string()))?;

        // Configure ALPN protocols
        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        Ok(config)
    }

    /// Load certificates from a PEM file
    fn load_certs(&self, path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
        let file = File::open(path).map_err(|e| TlsError::CertReadError {
            path: path.to_path_buf(),
            source: e,
        })?;

        let mut reader = BufReader::new(file);
        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| TlsError::CertReadError {
                path: path.to_path_buf(),
                source: e,
            })?;

        if certs.is_empty() {
            return Err(TlsError::NoCertificatesFound {
                path: path.to_path_buf(),
            });
        }

        Ok(certs)
    }

    /// Load a private key from a PEM file
    fn load_key(&self, path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
        let file = File::open(path).map_err(|e| TlsError::KeyReadError {
            path: path.to_path_buf(),
            source: e,
        })?;

        let mut reader = BufReader::new(file);

        // Try PKCS8 first, then RSA, then EC
        if let Ok(Some(key)) = rustls_pemfile::pkcs8_private_keys(&mut reader)
            .next()
            .transpose()
        {
            return Ok(key.into());
        }

        // Reset reader
        let file = File::open(path).map_err(|e| TlsError::KeyReadError {
            path: path.to_path_buf(),
            source: e,
        })?;
        let mut reader = BufReader::new(file);

        if let Ok(Some(key)) = rustls_pemfile::rsa_private_keys(&mut reader)
            .next()
            .transpose()
        {
            return Ok(key.into());
        }

        Err(TlsError::NoPrivateKeyFound {
            path: path.to_path_buf(),
        })
    }
}