Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / magpie / src / config.rs
//! Parser for the standard WireGuard `.conf` format, for example:
//!
//! ```ini
//! [Interface]
//! PrivateKey = <base64>
//! Address    = 10.7.0.2/24
//! DNS        = 1.1.1.1
//! MTU        = 1420
//!
//! [Peer]
//! PublicKey           = <base64>
//! PresharedKey        = <base64>     # optional
//! Endpoint            = vpn.example.com:51820
//! AllowedIPs          = 0.0.0.0/0
//! PersistentKeepalive = 25
//! ```

use base64::{engine::general_purpose::STANDARD, Engine};
use boringtun::x25519::{PublicKey, StaticSecret};
use ipnet::IpNet;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::str::FromStr;

/// Settings produced by [`parse`] from a WireGuard `.conf` file: one
/// `[Interface]` plus a single `[Peer]`.
pub struct Config {
    /// `[Interface] PrivateKey`, decoded from base64.
    pub private_key: StaticSecret,
    /// `[Interface] Address` entries; `parse` requires at least one.
    pub addresses: Vec<IpNet>,
    /// `[Interface] DNS` entries that parse as IP addresses; `parse` skips the rest.
    pub dns: Vec<IpAddr>,
    /// `[Interface] MTU`; `parse` uses 1420 when the key is absent.
    pub mtu: u16,
    /// `[Peer] PublicKey`, decoded from base64.
    pub peer_public: PublicKey,
    /// `[Peer] PresharedKey` (32 raw bytes), if present.
    pub preshared: Option<[u8; 32]>,
    /// `[Peer] Endpoint`, resolved to the first socket address the lookup returned.
    pub endpoint: SocketAddr,
    /// Host part of `[Peer] Endpoint` as written (IPv6 brackets removed), before resolution.
    pub endpoint_host: String,
    /// `[Peer] AllowedIPs` entries; may be empty.
    pub allowed_ips: Vec<IpNet>,
    /// `[Peer] PersistentKeepalive` interval (seconds, per WireGuard convention);
    /// `None` when absent or disabled with 0.
    pub keepalive: Option<u16>,
}

/// Decodes a key written in base64 (the `STANDARD` engine; surrounding
/// whitespace is ignored) into exactly 32 raw bytes.
///
/// # Errors
/// Returns a message if the text is not valid base64, or if it does not
/// decode to exactly 32 bytes.
fn decode_key(s: &str) -> Result<[u8; 32], String> {
    let decoded = match STANDARD.decode(s.trim()) {
        Ok(bytes) => bytes,
        Err(e) => return Err(format!("base64 解码失败: {e}")),
    };
    if decoded.len() == 32 {
        let mut key = [0u8; 32];
        key.copy_from_slice(&decoded);
        Ok(key)
    } else {
        Err(format!("密钥长度 {} 字节,应为 32 字节", decoded.len()))
    }
}

pub fn parse(text: &str) -> Result<Config, String> {
    // Tolerate a UTF-8 BOM (Notepad / PowerShell often prepend one).
    let text = text.strip_prefix('\u{feff}').unwrap_or(text);
    let mut section = String::new();

    let mut private_key: Option<StaticSecret> = None;
    let mut addresses: Vec<IpNet> = Vec::new();
    let mut dns: Vec<IpAddr> = Vec::new();
    let mut mtu: u16 = 1420;
    let mut peer_public: Option<PublicKey> = None;
    let mut preshared: Option<[u8; 32]> = None;
    let mut endpoint_raw: Option<String> = None;
    let mut allowed_ips: Vec<IpNet> = Vec::new();
    let mut keepalive: Option<u16> = None;

    for (i, raw) in text.lines().enumerate() {
        let line = raw.split('#').next().unwrap_or("").trim();
        if line.is_empty() {
            continue;
        }
        if line.starts_with('[') && line.ends_with(']') {
            section = line[1..line.len() - 1].trim().to_lowercase();
            continue;
        }
        let (key, val) = line
            .split_once('=')
            .ok_or_else(|| format!("第 {} 行格式错误: `{line}`", i + 1))?;
        let key = key.trim().to_lowercase();
        let val = val.trim();

        match (section.as_str(), key.as_str()) {
            ("interface", "privatekey") => {
                private_key = Some(StaticSecret::from(decode_key(val)?));
            }
            ("interface", "address") => {
                for part in val.split(',') {
                    let p = part.trim();
                    if p.is_empty() {
                        continue;
                    }
                    addresses.push(IpNet::from_str(p).map_err(|e| format!("Address `{p}`: {e}"))?);
                }
            }
            ("interface", "dns") => {
                for part in val.split(',') {
                    let p = part.trim();
                    if p.is_empty() {
                        continue;
                    }
                    if let Ok(ip) = IpAddr::from_str(p) {
                        dns.push(ip);
                    }
                }
            }
            ("interface", "mtu") => {
                mtu = val.parse().map_err(|_| format!("MTU 无效: {val}"))?;
            }
            ("peer", "publickey") => {
                peer_public = Some(PublicKey::from(decode_key(val)?));
            }
            ("peer", "presharedkey") => {
                preshared = Some(decode_key(val)?);
            }
            ("peer", "endpoint") => {
                endpoint_raw = Some(val.to_string());
            }
            ("peer", "allowedips") => {
                for part in val.split(',') {
                    let p = part.trim();
                    if p.is_empty() {
                        continue;
                    }
                    allowed_ips
                        .push(IpNet::from_str(p).map_err(|e| format!("AllowedIPs `{p}`: {e}"))?);
                }
            }
            ("peer", "persistentkeepalive") => {
                let k: u16 = val.parse().map_err(|_| format!("Keepalive 无效: {val}"))?;
                if k > 0 {
                    keepalive = Some(k);
                }
            }
            _ => { /* ignore unknown keys for forward-compat */ }
        }
    }

    let private_key = private_key.ok_or("缺少 [Interface] PrivateKey")?;
    let peer_public = peer_public.ok_or("缺少 [Peer] PublicKey")?;
    let endpoint_raw = endpoint_raw.ok_or("缺少 [Peer] Endpoint")?;
    if addresses.is_empty() {
        return Err("缺少 [Interface] Address".into());
    }

    // Split host:port (supports bare IPv4 host and [v6]:port).
    let (host, port) = if let Some(rest) = endpoint_raw.strip_prefix('[') {
        let (h, p) = rest
            .split_once("]:")
            .ok_or("Endpoint IPv6 格式应为 [addr]:port")?;
        (h.to_string(), p)
    } else {
        let (h, p) = endpoint_raw
            .rsplit_once(':')
            .ok_or("Endpoint 应为 host:port")?;
        (h.to_string(), p)
    };
    let port: u16 = port.parse().map_err(|_| format!("端口无效: {port}"))?;
    let endpoint = (host.as_str(), port)
        .to_socket_addrs()
        .map_err(|e| format!("无法解析 Endpoint `{host}:{port}`: {e}"))?
        .next()
        .ok_or_else(|| format!("Endpoint `{host}:{port}` 未解析到任何地址"))?;

    Ok(Config {
        private_key,
        addresses,
        dns,
        mtu,
        peer_public,
        preshared,
        endpoint,
        endpoint_host: host,
        allowed_ips,
        keepalive,
    })
}