Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / core / worker.rs
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tracing::{debug, info, warn};

use crate::config::ParsedConfig;
use crate::core::pipeline::PipelineProcessor;
use crate::util::signal::ServerSignal;

/// Per-worker connection and request counters.
///
/// All counters start at zero. `active_connections` is wrapped in an [`Arc`]
/// so that each spawned connection task can hold its own handle to it.
#[derive(Debug, Default)]
pub struct WorkerStats {
    /// Number of connections currently being served.
    pub active_connections: Arc<AtomicUsize>,
    /// Number of connections accepted since the worker was created.
    pub total_connections: AtomicU64,
    /// Number of requests handled.
    /// NOTE(review): nothing in worker.rs as written increments this — confirm
    /// whether another module is expected to update it.
    pub requests_handled: AtomicU64,
}

impl WorkerStats {
    /// Creates a fresh set of counters, all initialised to zero.
    pub fn new() -> Self {
        Self::default()
    }
}

/// A single worker: accepts TCP connections from a listener and serves each
/// one on its own task through a shared request pipeline.
pub struct Worker {
    /// Worker identifier; included in this worker's log messages.
    pub id: usize,
    /// Request pipeline; every connection task receives a clone of this `Arc`.
    pipeline: Arc<PipelineProcessor>,
    /// Connection counters owned by this worker.
    stats: Arc<WorkerStats>,
    /// Source of per-connection ids, bumped once for every accepted connection.
    connection_id: AtomicU64,
}

impl Worker {
    pub fn new(id: usize, config: Arc<ParsedConfig>) -> Self {
        let pipeline = Arc::new(PipelineProcessor::new(config));
        Self {
            id,
            pipeline,
            stats: Arc::new(WorkerStats::new()),
            connection_id: AtomicU64::new(0),
        }
    }

    /// Run the worker event loop
    pub async fn run(
        &self,
        listener: Arc<TcpListener>,
        mut shutdown_rx: broadcast::Receiver<ServerSignal>,
        mut reload_rx: broadcast::Receiver<ServerSignal>,
    ) {
        info!("Worker {} started", self.id);

        loop {
            tokio::select! {
                // Accept new connections
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, addr)) => {
                            debug!("Worker {} accepted connection from {}", self.id, addr);

                            let conn_id = self.connection_id.fetch_add(1, Ordering::Relaxed);
                            self.stats.total_connections.fetch_add(1, Ordering::Relaxed);
                            self.stats.active_connections.fetch_add(1, Ordering::Relaxed);

                            let pipeline = self.pipeline.clone();
                            let active = Arc::clone(&self.stats.active_connections);

                            tokio::spawn(async move {
                                handle_connection(conn_id, stream, addr, pipeline).await;
                                active.fetch_sub(1, Ordering::Relaxed);
                            });
                        }
                        Err(e) => {
                            // Handle accept errors gracefully
                            if e.kind() == std::io::ErrorKind::ConnectionAborted
                                || e.kind() == std::io::ErrorKind::ConnectionReset
                            {
                                continue;
                            }
                            warn!("Worker {} accept error: {}", self.id, e);
                            // Brief pause to avoid tight loop on persistent errors
                            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                        }
                    }
                }

                // Handle shutdown signal
                signal = shutdown_rx.recv() => {
                    match signal {
                        Ok(ServerSignal::Shutdown) => {
                            info!("Worker {} received shutdown signal", self.id);
                            break;
                        }
                        Ok(_) => {}
                        Err(broadcast::error::RecvError::Lagged(_)) => {}
                        Err(broadcast::error::RecvError::Closed) => break,
                    }
                }

                // Handle reload signal
                signal = reload_rx.recv() => {
                    match signal {
                        Ok(ServerSignal::Reload) => {
                            info!("Worker {} received reload signal", self.id);
                            // In a full implementation, we'd reload the config here
                        }
                        Ok(_) => {}
                        Err(broadcast::error::RecvError::Lagged(_)) => {}
                        Err(broadcast::error::RecvError::Closed) => {}
                    }
                }
            }
        }

        // Wait for active connections to drain
        let active = self.stats.active_connections.load(Ordering::Relaxed);
        if active > 0 {
            info!(
                "Worker {} waiting for {} active connections to drain",
                self.id, active
            );
            let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(30);
            while self.stats.active_connections.load(Ordering::Relaxed) > 0 {
                if tokio::time::Instant::now() >= deadline {
                    warn!(
                        "Worker {} shutdown timeout, {} connections still active",
                        self.id,
                        self.stats.active_connections.load(Ordering::Relaxed)
                    );
                    break;
                }
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        }

        info!("Worker {} stopped", self.id);
    }
}

/// Serves HTTP requests on one accepted TCP connection until it should close.
///
/// Incoming bytes are read into an 8 KiB buffer and fed to an `HttpParser`.
/// Each complete request is passed to `pipeline.process`, with scheme `"http"`
/// and port `80` hard-coded, and the serialized response is written back.
///
/// The function returns, dropping `stream`, when:
/// - the peer closes the connection, a read fails, or a single read waits
///   longer than 60 seconds;
/// - a response could not be written;
/// - the request just served is not keep-alive;
/// - the parser reports an error (a 400 Bad Request is sent first, on a
///   best-effort basis).
///
/// `_id` and `_addr` are accepted but currently unused.
async fn handle_connection(
    _id: u64,
    mut stream: tokio::net::TcpStream,
    _addr: std::net::SocketAddr,
    pipeline: std::sync::Arc<PipelineProcessor>,
) {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Size of the per-connection read buffer.
    const READ_BUF_LEN: usize = 8192;
    /// Longest time to wait on any single read before dropping the connection.
    const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

    let mut buf = vec![0u8; READ_BUF_LEN];
    let mut parser = crate::http::parser::HttpParser::new();

    loop {
        let n = match tokio::time::timeout(READ_TIMEOUT, stream.read(&mut buf)).await {
            // EOF, I/O error, or timeout: nothing more can be served.
            Ok(Ok(0)) | Ok(Err(_)) | Err(_) => return,
            Ok(Ok(n)) => n,
        };

        match parser.feed(&buf[..n]) {
            crate::http::parser::ParseResult::Complete(_) => {
                if let Some(request) = parser.take_request() {
                    let keep_alive = request.is_keep_alive();
                    let response = pipeline.process(&request, "http", 80).await;
                    let response_bytes = response.to_bytes();

                    // A failed write means the socket is broken; stop here rather
                    // than going back to read from it.
                    if stream.write_all(&response_bytes).await.is_err() || !keep_alive {
                        return;
                    }
                    parser.reset();
                }
            }
            crate::http::parser::ParseResult::NeedMore => continue,
            crate::http::parser::ParseResult::Error(_) => {
                let response = crate::http::response::Response::bad_request();
                // Best effort: the connection is closed whether or not this succeeds.
                let _ = stream.write_all(&response.to_bytes()).await;
                return;
            }
        }
    }
}