Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / magpie / src / tunnel.rs
//! The live tunnel: creates the wintun adapter, configures IP/routes, and runs
//! three worker threads that move packets between the OS and the WireGuard peer.
//!
//!   adapter ──read──► encapsulate (boringtun) ──► UDP send  ► peer
//!   peer ► UDP recv ──► decapsulate (boringtun) ──write──► adapter
//!   timer tick ──► update_timers ──► UDP send (keepalive / rekey)

use crate::config::Config;
use boringtun::noise::{Tunn, TunnResult};
use ipnet::IpNet;
use std::net::UdpSocket;
use std::process::Command;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

const ADAPTER_NAME: &str = "Magpie";
const BUF: usize = 65536;

/// Push any bytes still sitting in stdout's buffer out to the terminal/file.
///
/// When stdout is redirected to a file or pipe, Rust buffers it in blocks,
/// so progress lines would otherwise appear late or all at once. Flushing
/// is best-effort: any error is deliberately ignored.
fn flush() {
    use std::io::Write as _;
    let stdout = std::io::stdout();
    let mut handle = stdout.lock();
    let _ignored = handle.flush();
}

/// Bring the tunnel up and keep it running until the process is terminated.
///
/// Steps, in order:
/// 1. Bind a UDP socket and connect it to `cfg.endpoint`.
/// 2. Load `wintun.dll`, create the adapter and start a session on it.
/// 3. Configure addresses/DNS/routes via `configure_network`.
/// 4. Build the boringtun `Tunn` state machine, which consumes the private key.
/// 5. Spawn the three worker threads described in the module docs, then park
///    the calling thread forever so the adapter stays alive.
///
/// # Errors
///
/// Returns a user-facing (Chinese) message if any setup step fails:
/// - UDP bind or connect;
/// - loading wintun;
/// - creating the adapter, which the message attributes to missing
///   administrator rights;
/// - starting the session;
/// - initialising WireGuard.
///
/// Once the workers are running this function never returns.
pub fn run(cfg: Config) -> Result<(), String> {
    // --- UDP socket to the peer ----------------------------------------
    let sock = UdpSocket::bind("0.0.0.0:0").map_err(|e| format!("绑定 UDP 失败: {e}"))?;
    sock.connect(cfg.endpoint)
        .map_err(|e| format!("连接 Endpoint 失败: {e}"))?;
    let sock = Arc::new(sock);
    println!("● UDP 已就绪,对端 {} ({})", cfg.endpoint, cfg.endpoint_host);

    // --- wintun virtual adapter ----------------------------------------
    let wintun = unsafe { wintun::load_from_path("wintun.dll") }
        .map_err(|e| format!("加载 wintun.dll 失败 (确认它在 exe 同目录): {e}"))?;
    let adapter = wintun::Adapter::create(&wintun, ADAPTER_NAME, "Magpie WireGuard Tunnel", None)
        .map_err(|e| {
            format!(
                "创建虚拟网卡失败: {e}\n  >> 这一步需要【管理员权限】,请用管理员身份运行。"
            )
        })?;
    let session = Arc::new(
        adapter
            .start_session(wintun::MAX_RING_CAPACITY)
            .map_err(|e| format!("启动网卡会话失败: {e}"))?,
    );
    println!("● 虚拟网卡 \"{ADAPTER_NAME}\" 已创建");

    // --- IP / MTU / routes ---------------------------------------------
    configure_network(&cfg);

    // --- WireGuard state machine (consumes the private key) ------------
    let tunn = Tunn::new(
        cfg.private_key,
        cfg.peer_public,
        cfg.preshared,
        cfg.keepalive,
        0,
        None,
    )
    .map_err(|e| format!("初始化 WireGuard 失败: {e}"))?;
    let tunn = Arc::new(Mutex::new(tunn));

    // --- worker threads ------------------------------------------------
    // Set once the first decrypted packet is written to the adapter.
    let connected = Arc::new(AtomicBool::new(false));

    // Thread A: adapter -> encapsulate -> UDP
    {
        let tunn = tunn.clone();
        let sock = sock.clone();
        let session = session.clone();
        std::thread::spawn(move || {
            let mut out = vec![0u8; BUF];
            loop {
                match session.receive_blocking() {
                    Ok(packet) => {
                        let res = tunn.lock().unwrap().encapsulate(packet.bytes(), &mut out);
                        if let TunnResult::WriteToNetwork(b) = res {
                            let _ = sock.send(b);
                        }
                    }
                    Err(e) => {
                        eprintln!("网卡读取结束: {e}");
                        break;
                    }
                }
            }
        });
    }

    // Thread B: UDP -> decapsulate -> adapter
    {
        let tunn = tunn.clone();
        let sock = sock.clone();
        let session = session.clone();
        let connected = connected.clone();
        std::thread::spawn(move || {
            let mut udp = vec![0u8; BUF];
            // Reused for every decapsulate call. Allocating and zeroing a fresh
            // 64 KiB buffer per datagram was pure overhead on the hot receive path.
            let mut out = vec![0u8; BUF];
            loop {
                let n = match sock.recv(&mut udp) {
                    Ok(n) => n,
                    Err(e) => {
                        eprintln!("UDP 读取错误: {e}");
                        continue;
                    }
                };
                // A single datagram may produce a network reply (handshake) and
                // then several queued tunnel packets. After a WriteToNetwork,
                // call again with an empty input to drain them until no more
                // output is produced.
                let mut input: &[u8] = &udp[..n];
                loop {
                    let res = tunn.lock().unwrap().decapsulate(None, input, &mut out);
                    match res {
                        TunnResult::WriteToNetwork(b) => {
                            let _ = sock.send(b);
                            input = &[]; // drain queued packets
                        }
                        TunnResult::WriteToTunnelV4(b, _) | TunnResult::WriteToTunnelV6(b, _) => {
                            if !connected.swap(true, Ordering::SeqCst) {
                                println!("✓ 隧道已连通,开始转发数据流量。");
                                flush();
                            }
                            if let Ok(mut p) = session.allocate_send_packet(b.len() as u16) {
                                p.bytes_mut().copy_from_slice(b);
                                session.send_packet(p);
                            }
                            break;
                        }
                        _ => break,
                    }
                }
            }
        });
    }

    // Thread C: periodic timers (keepalive / rekey) + connection watchdog
    {
        let tunn = tunn.clone();
        let sock = sock.clone();
        let connected = connected.clone();
        std::thread::spawn(move || {
            // Kick off the handshake immediately so the user sees activity.
            kick_handshake(&tunn, &sock);
            // Reused across ticks instead of allocating 64 KiB every 250 ms.
            let mut out = vec![0u8; BUF];
            // Only this thread reads or writes it, so a plain bool suffices.
            let mut handshook = false;
            let mut ticks: u32 = 0;
            loop {
                std::thread::sleep(Duration::from_millis(250));
                let res = tunn.lock().unwrap().update_timers(&mut out);
                if let TunnResult::WriteToNetwork(b) = res {
                    let _ = sock.send(b);
                }
                // Announce as soon as the Noise handshake itself completes,
                // even before any user traffic flows.
                let session_up = tunn.lock().unwrap().time_since_last_handshake().is_some();
                if session_up && !handshook {
                    handshook = true;
                    println!("✓ 与服务器握手成功,安全会话已建立。");
                    flush();
                }
                ticks = ticks.wrapping_add(1);
                // Watchdog: every ~3s (12 x 250 ms), re-send a handshake init
                // while there is neither a live session nor any decrypted
                // traffic. Once a session exists, `update_timers` above drives
                // rekeying. Kicking anyway would start a brand-new handshake
                // every 3s on an idle-but-healthy tunnel, because
                // `format_handshake_initiation(_, false)` only holds back while
                // a handshake is still in progress.
                if ticks % 12 == 0 && !session_up && !connected.load(Ordering::SeqCst) {
                    kick_handshake(&tunn, &sock);
                }
            }
        });
    }

    println!("\n鹊桥已启动。按 Ctrl+C 断开。\n");
    flush();

    // Keep the adapter (and process) alive until interrupted.
    loop {
        std::thread::sleep(Duration::from_secs(3600));
    }
}

/// Ask boringtun for a handshake-initiation message and, if one is produced,
/// send it to the peer over `sock`.
///
/// `force_resend` is `false`, so boringtun may decline (returning something
/// other than `WriteToNetwork`), in which case nothing is sent. Send errors
/// are ignored: the caller's periodic watchdog retries the kick.
fn kick_handshake(tunn: &Arc<Mutex<Tunn>>, sock: &UdpSocket) {
    let mut scratch = vec![0u8; BUF];
    // Scope the guard so the mutex is released before the socket send.
    let outcome = {
        let mut wg = tunn.lock().unwrap();
        wg.format_handshake_initiation(&mut scratch, false)
    };
    if let TunnResult::WriteToNetwork(datagram) = outcome {
        let _ = sock.send(datagram);
    }
}

/// Assign the configured addresses, MTU, DNS server and AllowedIPs routes to
/// the Magpie adapter by shelling out to `netsh` / `route`.
///
/// Every step is best-effort. Each command's outcome is printed by `run_cmd`,
/// and a failure does not stop the remaining steps.
fn configure_network(cfg: &Config) {
    println!("● 配置网卡地址与路由:");

    // IP addresses. `set address ... static` replaces the adapter's whole IPv4
    // configuration, so it may only be used for the first IPv4 address. Any
    // further IPv4 addresses must be *added*, otherwise each would overwrite
    // the previous one and only the last would survive.
    let mut has_v4 = false;
    let mut has_v6 = false;
    for net in &cfg.addresses {
        match net {
            IpNet::V4(n) => {
                let addr = n.addr().to_string();
                let mask = n.netmask().to_string();
                if has_v4 {
                    run_cmd(
                        "netsh",
                        &[
                            "interface",
                            "ipv4",
                            "add",
                            "address",
                            &format!("name={ADAPTER_NAME}"),
                            &format!("address={addr}"),
                            &format!("mask={mask}"),
                        ],
                    );
                } else {
                    run_cmd(
                        "netsh",
                        &[
                            "interface",
                            "ipv4",
                            "set",
                            "address",
                            &format!("name={ADAPTER_NAME}"),
                            "static",
                            &addr,
                            &mask,
                        ],
                    );
                }
                has_v4 = true;
            }
            IpNet::V6(n) => {
                run_cmd(
                    "netsh",
                    &[
                        "interface",
                        "ipv6",
                        "set",
                        "address",
                        &format!("interface={ADAPTER_NAME}"),
                        &format!("address={}/{}", n.addr(), n.prefix_len()),
                    ],
                );
                has_v6 = true;
            }
        }
    }

    // MTU: once per address family that actually has an address. It used to
    // be re-applied for every IPv4 address and never set for IPv6.
    let mtu = format!("mtu={}", cfg.mtu);
    for (family, present) in [("ipv4", has_v4), ("ipv6", has_v6)] {
        if present {
            run_cmd(
                "netsh",
                &[
                    "interface",
                    family,
                    "set",
                    "subinterface",
                    ADAPTER_NAME,
                    &mtu,
                    "store=active",
                ],
            );
        }
    }

    // DNS (best effort): only the first IPv4 server is applied.
    if let Some(first) = cfg.dns.iter().find(|ip| ip.is_ipv4()) {
        run_cmd(
            "netsh",
            &[
                "interface",
                "ipv4",
                "set",
                "dnsservers",
                &format!("name={ADAPTER_NAME}"),
                "static",
                &first.to_string(),
                "primary",
            ],
        );
    }

    // Routes for AllowedIPs. The default-gateway lookup is only needed for a
    // full tunnel (0.0.0.0/0), so compute it lazily. It must never block the
    // common split-tunnel path.
    let need_gw = cfg
        .allowed_ips
        .iter()
        .any(|n| matches!(n, IpNet::V4(x) if x.prefix_len() == 0));
    let default_gw = if need_gw { default_gateway() } else { None };
    for net in &cfg.allowed_ips {
        match net {
            IpNet::V4(n) if n.prefix_len() == 0 => {
                // Full tunnel: pin the encrypted Endpoint via the real gateway
                // to avoid a routing loop, then split the default route in two.
                // The two /1 halves are more specific than the existing /0,
                // so they take precedence without deleting it.
                if let (Some(gw), std::net::IpAddr::V4(ep)) = (&default_gw, cfg.endpoint.ip()) {
                    run_cmd(
                        "route",
                        &["add", &ep.to_string(), "mask", "255.255.255.255", gw],
                    );
                } else {
                    println!("  ! 未找到默认网关,全隧道路由可能形成环路。");
                }
                add_tun_route("0.0.0.0/1");
                add_tun_route("128.0.0.0/1");
            }
            _ => add_tun_route(&net.to_string()),
        }
    }
}

/// Add a route for `prefix` through the Magpie adapter.
///
/// `prefix` is CIDR text such as `10.0.0.0/24` or `fd00::/64`, and the route
/// is created with `store=active`. The netsh context (`ipv4` or `ipv6`) is
/// picked from the prefix itself: IPv6 text always contains `:` and IPv4 text
/// never does. Previously IPv6 AllowedIPs were sent to the `ipv4` context and
/// always failed.
fn add_tun_route(prefix: &str) {
    let family = if prefix.contains(':') { "ipv6" } else { "ipv4" };
    run_cmd(
        "netsh",
        &[
            "interface",
            family,
            "add",
            "route",
            &format!("prefix={prefix}"),
            &format!("interface={ADAPTER_NAME}"),
            "store=active",
        ],
    );
}

/// Find the current IPv4 default gateway by running `route print 0.0.0.0`
/// (fast, no WMI) and parsing its output with `parse_default_gateway`.
///
/// Returns `None` if the command cannot be spawned or no usable default route
/// is listed.
fn default_gateway() -> Option<String> {
    let out = Command::new("route").args(["print", "0.0.0.0"]).output().ok()?;
    parse_default_gateway(&String::from_utf8_lossy(&out.stdout))
}

/// Pick the preferred IPv4 default gateway from `route print` output text.
///
/// Only lines of the form `0.0.0.0  0.0.0.0  <gateway> ...` are considered,
/// and the gateway must parse as an IPv4 address (this also rejects
/// `On-link`). On active-route lines the metric is the fifth column. When
/// several default routes exist, the one with the lowest metric wins, since
/// that is the route Windows uses rather than merely the first one printed.
/// Lines with no numeric fifth column (e.g. persistent-route entries) rank
/// last. Ties keep the first occurrence.
fn parse_default_gateway(text: &str) -> Option<String> {
    text.lines()
        .filter_map(|line| {
            let cols: Vec<&str> = line.split_whitespace().collect();
            match cols.as_slice() {
                ["0.0.0.0", "0.0.0.0", gw, rest @ ..]
                    if gw.parse::<std::net::Ipv4Addr>().is_ok() =>
                {
                    // rest = [interface, metric, ...] on active-route lines.
                    let metric = rest
                        .get(1)
                        .and_then(|m| m.parse::<u32>().ok())
                        .unwrap_or(u32::MAX);
                    Some((metric, (*gw).to_owned()))
                }
                _ => None,
            }
        })
        .min_by_key(|(metric, _)| *metric)
        .map(|(_, gw)| gw)
}

fn run_cmd(prog: &str, args: &[&str]) -> bool {
    let r = run_cmd_inner(prog, args);
    flush();
    r
}

/// Spawn `prog` with `args`, wait for it, print a one-line ✓/✗ report, and
/// return `true` on success.
///
/// A non-zero exit still counts as success when the tool's message says the
/// object "already exists" (English or Chinese). Windows auto-creates the
/// on-link route for the interface subnet, so that is the desired end state.
/// A spawn failure is reported and returns `false`.
fn run_cmd_inner(prog: &str, args: &[&str]) -> bool {
    let shown = args.join(" ");

    let output = match Command::new(prog).args(args).output() {
        Ok(output) => output,
        Err(err) => {
            println!("  ✗ {prog} {shown} -> {err}");
            return false;
        }
    };

    if output.status.success() {
        println!("  ✓ {prog} {shown}");
        return true;
    }

    // Prefer stderr; fall back to stdout when stderr is blank.
    // NOTE(review): output is decoded as UTF-8. If the tool writes in the
    // console's OEM code page (e.g. GBK), the Chinese needle below may never
    // match. Confirm on a localized Windows install.
    let stderr = String::from_utf8_lossy(&output.stderr);
    let detail = match stderr.trim() {
        "" => String::from_utf8_lossy(&output.stdout).trim().to_string(),
        text => text.to_string(),
    };

    let already_there = ["already exists", "已存在"]
        .iter()
        .any(|needle| detail.contains(needle));
    if already_there {
        println!("  ✓ {prog} {shown} (路由已存在)");
    } else {
        println!("  ✗ {prog} {shown} -> {detail}");
    }
    already_there
}