pub mod config;
pub use config::{TlsConfig, TlsError};
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;
/// A certificate chain + private key pair registered for one or more
/// server names (see `TlsManager::add_cert`).
struct NamedCert {
    /// Server names this certificate answers for, stored lowercased. Entries
    /// starting with `*.` are treated as wildcards when the resolver is built.
    names: Vec<String>,
    /// The loaded chain plus signing key, shared (via `Arc`) with the resolver.
    certified: Arc<CertifiedKey>,
}
/// SNI-based certificate resolver.
///
/// nginx terminates TLS for several virtual hosts on the same `:443`
/// listener, each presenting its own certificate selected from the TLS
/// ClientHello's Server Name Indication. This resolver reproduces that:
/// it matches the requested SNI host against the configured certificates
/// (exact match first, then wildcard `*.example.com`) and falls back to the
/// first configured certificate when the client sends no/unknown SNI.
#[derive(Debug)]
pub struct SniResolver {
    /// Exact (non-wildcard) server names, lowercased, mapped to their certificate.
    by_name: HashMap<String, Arc<CertifiedKey>>,
    /// Wildcard entries as `(suffix, cert)`: a configured `*.example.com` is
    /// stored as the suffix `.example.com`. `TlsManager::build_acceptor`
    /// fills this in registration order.
    wildcards: Vec<(String, Arc<CertifiedKey>)>,
    /// Certificate returned when the client sends no SNI or an unmatched one.
    default: Arc<CertifiedKey>,
}
impl std::fmt::Debug for NamedCert {
    /// Formats only the server names.
    ///
    /// The `certified` field (chain + signing key) is deliberately omitted.
    /// `finish_non_exhaustive` renders a trailing `..`, so the output does not
    /// pretend `names` is the only field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NamedCert")
            .field("names", &self.names)
            .finish_non_exhaustive()
    }
}
impl ResolvesServerCert for SniResolver {
    /// Pick the certificate to present for one incoming handshake.
    ///
    /// Precedence, mirroring nginx `server_name` matching:
    /// 1. an exact, case-insensitive match on the SNI host;
    /// 2. the *longest* matching `*.suffix` wildcard. `*.api.example.com`
    ///    therefore beats `*.example.com` for `v1.api.example.com` whatever
    ///    the registration order. Equal-length ties keep the first-registered
    ///    entry;
    /// 3. the default certificate, also used when the client sends no SNI.
    ///
    /// Always returns `Some`, so a handshake is never refused for lack of a
    /// certificate.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        if let Some(sni) = client_hello.server_name() {
            // Host names are case-insensitive; configured names are stored
            // lowercased, so normalise the requested one the same way.
            let sni = sni.to_ascii_lowercase();
            if let Some(ck) = self.by_name.get(&sni) {
                return Some(Arc::clone(ck));
            }

            // Wildcard match: the suffix `.example.com` (from `*.example.com`)
            // covers `foo.example.com` but not the bare `example.com`.
            // Strict `>` keeps the first-registered entry on length ties.
            // Every stored suffix has length >= 1, so the first hit always
            // beats the initial 0.
            let mut best: Option<&Arc<CertifiedKey>> = None;
            let mut best_len = 0;
            for (suffix, ck) in &self.wildcards {
                if suffix.len() > best_len && sni.ends_with(suffix.as_str()) {
                    best = Some(ck);
                    best_len = suffix.len();
                }
            }
            if let Some(ck) = best {
                return Some(Arc::clone(ck));
            }
        }
        Some(Arc::clone(&self.default))
    }
}
/// TLS manager that builds a single SNI-aware [`TlsAcceptor`] covering every
/// HTTPS virtual host.
///
/// Certificates are registered with [`TlsManager::add_cert`] and turned into
/// an acceptor by [`TlsManager::build_acceptor`].
pub struct TlsManager {
    /// Registered certificates, in registration order. The first entry is the
    /// fallback certificate for clients that send no or an unknown SNI.
    certs: Vec<NamedCert>,
}
impl TlsManager {
    /// Create a manager with no certificates registered.
    pub fn new() -> Self {
        Self { certs: Vec::new() }
    }

    /// Register a certificate/key for one or more server names.
    ///
    /// `names` are stored lowercased. A name of the form `*.example.com` is
    /// treated as a wildcard when the acceptor is built. The first
    /// certificate ever registered becomes the fallback for clients with no
    /// or unknown SNI.
    ///
    /// # Errors
    /// Propagates any [`TlsError`] raised while reading the PEM files or
    /// while turning the private key into a signing key.
    pub fn add_cert(
        &mut self,
        names: &[String],
        cert_path: &Path,
        key_path: &Path,
    ) -> Result<(), TlsError> {
        let certified = build_certified_key(cert_path, key_path)?;
        self.certs.push(NamedCert {
            names: names.iter().map(|n| n.to_ascii_lowercase()).collect(),
            certified: Arc::new(certified),
        });
        Ok(())
    }

    /// Returns `true` if at least one certificate has been registered.
    #[must_use]
    pub fn has_certs(&self) -> bool {
        !self.certs.is_empty()
    }

    /// Build the SNI resolver and wrap it in a `TlsAcceptor`.
    ///
    /// If the same exact name is registered more than once, the FIRST
    /// registration wins. This matches the fallback certificate also being
    /// the first one, and matches nginx, which ignores a later conflicting
    /// `server_name`.
    ///
    /// # Errors
    /// `TlsError::ConfigError` if no certificate has been registered.
    pub fn build_acceptor(&self) -> Result<TlsAcceptor, TlsError> {
        let default = self
            .certs
            .first()
            .map(|c| Arc::clone(&c.certified))
            .ok_or_else(|| TlsError::ConfigError("no TLS certificates configured".to_string()))?;
        let mut by_name = HashMap::new();
        let mut wildcards = Vec::new();
        for nc in &self.certs {
            for name in &nc.names {
                if let Some(rest) = name.strip_prefix("*.") {
                    // Store the matchable suffix including the dot, e.g.
                    // `*.example.com` -> `.example.com`.
                    wildcards.push((format!(".{rest}"), Arc::clone(&nc.certified)));
                } else {
                    // `or_insert_with` keeps an earlier registration instead of
                    // letting a later duplicate silently overwrite it.
                    by_name
                        .entry(name.clone())
                        .or_insert_with(|| Arc::clone(&nc.certified));
                }
            }
        }
        let resolver = SniResolver {
            by_name,
            wildcards,
            default,
        };
        let mut config = ServerConfig::builder()
            .with_no_client_auth()
            .with_cert_resolver(Arc::new(resolver));
        // Only advertise HTTP/1.1: this server speaks HTTP/1.1 on the wire,
        // and advertising `h2` would let a browser negotiate a protocol we
        // do not implement.
        config.alpn_protocols = vec![b"http/1.1".to_vec()];
        Ok(TlsAcceptor::from(Arc::new(config)))
    }
}
impl Default for TlsManager {
fn default() -> Self {
Self::new()
}
}
/// Assemble the rustls [`CertifiedKey`] for one certificate/key pair.
///
/// The certificate chain is read from `cert_path` first, then the private key
/// from `key_path`. The key is converted into a signing key by the `ring`
/// crypto backend. The result is what the SNI resolver hands back per
/// connection.
///
/// # Errors
/// Any error from loading either PEM file. If `ring` does not support the
/// key, the result is `TlsError::InvalidPrivateKey` naming `key_path`.
fn build_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
    let chain = load_certs(cert_path)?;
    let private_key = load_key(key_path)?;
    let signer = match rustls::crypto::ring::sign::any_supported_type(&private_key) {
        Ok(signer) => signer,
        Err(err) => {
            return Err(TlsError::InvalidPrivateKey {
                path: key_path.to_path_buf(),
                reason: err.to_string(),
            });
        }
    };
    // NOTE(review): `CertifiedKey::new` does not appear to check that the key
    // belongs to the leaf certificate. A mismatched pair would only surface at
    // handshake time — confirm whether startup validation is wanted.
    Ok(CertifiedKey::new(chain, signer))
}
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
let file = File::open(path).map_err(|e| TlsError::CertReadError {
path: path.to_path_buf(),
source: e,
})?;
let mut reader = BufReader::new(file);
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| TlsError::CertReadError {
path: path.to_path_buf(),
source: e,
})?;
if certs.is_empty() {
return Err(TlsError::NoCertificatesFound {
path: path.to_path_buf(),
});
}
Ok(certs)
}
fn load_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
// PKCS#8 first, then PKCS#1 (RSA), then SEC1 (EC).
let open = || -> Result<BufReader<File>, TlsError> {
File::open(path)
.map(BufReader::new)
.map_err(|e| TlsError::KeyReadError {
path: path.to_path_buf(),
source: e,
})
};
if let Ok(Some(key)) = rustls_pemfile::pkcs8_private_keys(&mut open()?)
.next()
.transpose()
{
return Ok(key.into());
}
if let Ok(Some(key)) = rustls_pemfile::rsa_private_keys(&mut open()?)
.next()
.transpose()
{
return Ok(key.into());
}
if let Ok(Some(key)) = rustls_pemfile::ec_private_keys(&mut open()?)
.next()
.transpose()
{
return Ok(key.into());
}
Err(TlsError::NoPrivateKeyFound {
path: path.to_path_buf(),
})
}