Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / http / parser.rs
use super::headers::HeaderMap;
use super::request::{Method, Request, Version};
use bytes::{Buf, BytesMut};
use std::io;

/// Position of [`HttpParser`] within the request grammar.
///
/// The parser only moves forward through these variants while bytes are
/// fed in. `Complete` and `Error` are terminal: `feed` keeps reporting them
/// until the parser is reset.
#[derive(Debug, Clone, PartialEq)]
pub enum ParseState {
    /// Waiting for the `METHOD URI VERSION` line.
    ParsingRequestLine,
    /// Reading header lines up to the blank line that ends them.
    ParsingHeaders,
    /// Reading a `Content-Length` delimited body; holds the byte count still expected.
    ParsingBody(usize),
    /// Decoding a `Transfer-Encoding: chunked` body.
    ParsingChunked,
    /// A whole request has been read.
    Complete,
    /// Parsing failed; holds the error message that `feed` will keep returning.
    Error(String),
}

/// Outcome of a single call to `HttpParser::feed`.
#[derive(Debug)]
pub enum ParseResult {
    /// The buffered bytes do not yet form a complete request.
    NeedMore,
    /// A full request was parsed; carries the number of bytes consumed.
    Complete(usize),
    /// The input is not a valid request.
    Error(ParseError),
}

/// Error describing why the received bytes are not a valid HTTP request.
#[derive(Debug)]
pub struct ParseError {
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ParseError {
    /// Renders as `HTTP parse error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("HTTP parse error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for ParseError {}

impl From<io::Error> for ParseError {
    /// Wraps an I/O failure, keeping only its rendered text.
    fn from(err: io::Error) -> Self {
        let message = err.to_string();
        Self { message }
    }
}

/// Incremental (streaming) HTTP/1.1 request parser.
///
/// Bytes are handed in piecemeal through `feed`. The parser buffers them and
/// advances through [`ParseState`] until one full request has been read.
pub struct HttpParser {
    /// Where the parser currently is in the request grammar.
    state: ParseState,
    /// Received bytes that have not been consumed yet.
    buffer: BytesMut,
    /// Bytes consumed since construction or the last reset.
    total_parsed: usize,

    // Request-line components, populated once the request line is parsed.
    method: Option<Method>,
    uri: Option<String>,
    version: Option<Version>,

    /// Header fields collected so far.
    headers: HeaderMap,
    /// Body length taken from the `Content-Length` header, when present.
    content_length: Option<usize>,
    /// Chunked mode: size of the chunk being read, or `None` while waiting
    /// for the next chunk-size line.
    chunk_size: Option<usize>,
    /// Accumulated request body (decoded; chunked framing removed).
    body: BytesMut,
}

impl HttpParser {
    pub fn new() -> Self {
        Self {
            state: ParseState::ParsingRequestLine,
            buffer: BytesMut::with_capacity(8192),
            method: None,
            uri: None,
            version: None,
            headers: HeaderMap::new(),
            content_length: None,
            chunk_size: None,
            total_parsed: 0,
            body: BytesMut::new(),
        }
    }

    /// Feed data into the parser. Returns the parse result.
    pub fn feed(&mut self, data: &[u8]) -> ParseResult {
        self.buffer.extend_from_slice(data);

        loop {
            match &self.state {
                ParseState::ParsingRequestLine => {
                    if let Some(pos) = find_crlf(&self.buffer) {
                        // Use split_to to extract the line without copying.
                        // split_to returns a BytesMut containing [0..pos].
                        let line = self.buffer.split_to(pos);
                        self.buffer.advance(2); // skip \r\n
                        self.total_parsed += pos + 2;

                        match parse_request_line(&line) {
                            Ok((method, uri, version)) => {
                                self.method = Some(method);
                                self.uri = Some(uri);
                                self.version = Some(version);
                                self.state = ParseState::ParsingHeaders;
                            }
                            Err(e) => {
                                self.state = ParseState::Error(e.message.clone());
                                return ParseResult::Error(e);
                            }
                        }
                    } else {
                        return ParseResult::NeedMore;
                    }
                }
                ParseState::ParsingHeaders => {
                    if let Some(pos) = find_crlf(&self.buffer) {
                        if pos == 0 {
                            // Empty line = end of headers
                            self.buffer.advance(2);
                            self.total_parsed += 2;

                            self.content_length = self
                                .headers
                                .get("content-length")
                                .and_then(|cl| cl.parse().ok());

                            let transfer_encoding = self
                                .headers
                                .get("transfer-encoding")
                                .map(|s| s.to_lowercase());

                            if transfer_encoding.as_deref() == Some("chunked") {
                                self.state = ParseState::ParsingChunked;
                            } else if let Some(len) = self.content_length {
                                if len > 0 {
                                    self.state = ParseState::ParsingBody(len);
                                } else {
                                    self.state = ParseState::Complete;
                                    return ParseResult::Complete(self.total_parsed);
                                }
                            } else {
                                self.state = ParseState::Complete;
                                return ParseResult::Complete(self.total_parsed);
                            }
                        } else {
                            // Use split_to to extract the line without copying.
                            let line = self.buffer.split_to(pos);
                            self.buffer.advance(2); // skip \r\n
                            self.total_parsed += pos + 2;

                            match parse_header_line(&line) {
                                Ok((name, value)) => {
                                    self.headers.insert(name, value);
                                }
                                Err(e) => {
                                    self.state = ParseState::Error(e.message.clone());
                                    return ParseResult::Error(e);
                                }
                            }
                        }
                    } else {
                        return ParseResult::NeedMore;
                    }
                }
                ParseState::ParsingBody(remaining) => {
                    if self.buffer.len() >= *remaining {
                        self.total_parsed += remaining;
                        let chunk = self.buffer.split_to(*remaining);
                        self.body.extend_from_slice(&chunk);
                        self.state = ParseState::Complete;
                        return ParseResult::Complete(self.total_parsed);
                    } else {
                        return ParseResult::NeedMore;
                    }
                }
                ParseState::ParsingChunked => {
                    // Parse chunked transfer encoding
                    loop {
                        if let Some(size) = self.chunk_size {
                            if size == 0 {
                                // Last chunk
                                if self.buffer.len() >= 2 && self.buffer[..2] == *b"\r\n" {
                                    let _ = self.buffer.split_to(2);
                                    self.total_parsed += 2;
                                    self.state = ParseState::Complete;
                                    return ParseResult::Complete(self.total_parsed);
                                }
                                return ParseResult::NeedMore;
                            }
                            if self.buffer.len() >= size + 2 {
                                let chunk = self.buffer.split_to(size + 2);
                                // Keep the chunk data, drop the trailing CRLF.
                                self.body.extend_from_slice(&chunk[..size]);
                                self.total_parsed += size + 2;
                                self.chunk_size = None;
                            } else {
                                return ParseResult::NeedMore;
                            }
                        } else if let Some(pos) = find_crlf(&self.buffer) {
                            let size_str = std::str::from_utf8(&self.buffer[..pos]).unwrap_or("0");
                            let size = usize::from_str_radix(size_str.trim(), 16).unwrap_or(0);
                            let _ = self.buffer.split_to(pos + 2);
                            self.total_parsed += pos + 2;
                            self.chunk_size = Some(size);
                        } else {
                            return ParseResult::NeedMore;
                        }
                    }
                }
                ParseState::Complete => {
                    return ParseResult::Complete(self.total_parsed);
                }
                ParseState::Error(msg) => {
                    return ParseResult::Error(ParseError {
                        message: msg.clone(),
                    });
                }
            }
        }
    }

    /// Check if parsing is complete
    pub fn is_complete(&self) -> bool {
        matches!(self.state, ParseState::Complete)
    }

    /// Take the parsed request, consuming the parser state
    pub fn take_request(&mut self) -> Option<Request> {
        if !self.is_complete() {
            return None;
        }

        let method = self.method.take()?;
        let uri = self.uri.take()?;
        let version = self.version.take()?;
        let headers = std::mem::take(&mut self.headers);

        let mut request = Request::new(method, uri, version);
        request.headers = headers;
        if !self.body.is_empty() {
            request.body = Some(std::mem::take(&mut self.body).freeze());
        }
        Some(request)
    }

    /// Reset the parser for the next request (keep-alive)
    pub fn reset(&mut self) {
        self.state = ParseState::ParsingRequestLine;
        self.buffer.clear();
        self.method = None;
        self.uri = None;
        self.version = None;
        self.headers = HeaderMap::new();
        self.content_length = None;
        self.chunk_size = None;
        self.total_parsed = 0;
        self.body.clear();
    }

    /// Get remaining buffered data
    pub fn remaining(&self) -> &[u8] {
        &self.buffer
    }
}

impl Default for HttpParser {
    fn default() -> Self {
        Self::new()
    }
}

/// Index of the first `\r\n` pair in `buf`, or `None` if there is none.
fn find_crlf(buf: &[u8]) -> Option<usize> {
    let last_start = buf.len().saturating_sub(1);
    (0..last_start).find(|&i| buf[i] == b'\r' && buf[i + 1] == b'\n')
}

/// Parse the HTTP request line: `METHOD SP request-target SP HTTP-version`.
///
/// `line` is the request line with its terminating CRLF already removed.
/// The method and version are matched directly on the byte slice; only the
/// request target is copied into an owned `String`.
///
/// # Errors
///
/// Returns a [`ParseError`] when either separating space is missing, when
/// `Method::from_bytes` does not recognise the method, when the request
/// target is empty or not valid UTF-8, or when the version is not one of
/// `HTTP/1.0`, `HTTP/1.1`, `HTTP/2`, `HTTP/2.0`.
fn parse_request_line(line: &[u8]) -> Result<(Method, String, Version), ParseError> {
    // Find the two spaces that delimit method, URI, and version.
    let Some(sp1) = line.iter().position(|&b| b == b' ') else {
        return Err(ParseError {
            message: "Invalid request line: missing space".to_string(),
        });
    };
    let rest = &line[sp1 + 1..];
    let Some(sp2) = rest.iter().position(|&b| b == b' ') else {
        return Err(ParseError {
            message: "Invalid request line: missing second space".to_string(),
        });
    };

    let method_bytes = &line[..sp1];
    let uri_bytes = &rest[..sp2];
    let version_bytes = &rest[sp2 + 1..];

    // These bytes come straight off the wire and may be arbitrary, so they are
    // rendered lossily for diagnostics. Building a `String` from non-UTF-8
    // bytes with `from_utf8_unchecked` is undefined behaviour.
    let method = Method::from_bytes(method_bytes).ok_or_else(|| ParseError {
        message: format!("Unknown method: {}", String::from_utf8_lossy(method_bytes)),
    })?;

    if uri_bytes.is_empty() {
        return Err(ParseError {
            message: "Invalid request line: empty request target".to_string(),
        });
    }
    // Request targets are ASCII per RFC 9112 §3.2; reject anything that is
    // not even valid UTF-8 instead of storing it unchecked.
    let uri = String::from_utf8(uri_bytes.to_vec()).map_err(|_| ParseError {
        message: "Invalid request line: request target is not valid UTF-8".to_string(),
    })?;

    let version = match version_bytes {
        b"HTTP/1.0" => Version::Http10,
        b"HTTP/1.1" => Version::Http11,
        b"HTTP/2" | b"HTTP/2.0" => Version::Http20,
        _ => {
            return Err(ParseError {
                message: format!(
                    "Unknown version: {}",
                    String::from_utf8_lossy(version_bytes)
                ),
            });
        }
    };

    Ok((method, uri, version))
}

/// Parse a single HTTP header line: `field-name ":" OWS field-value OWS`.
///
/// `line` must not include the terminating CRLF. The field name is returned
/// exactly as received. The value has surrounding ASCII whitespace removed.
///
/// # Errors
///
/// Returns a [`ParseError`] when:
/// - the line has no colon;
/// - the field name is empty or contains anything other than RFC 9110 token
///   characters. This rejects whitespace before the colon (which RFC 9112
///   §5.1 requires servers to reject) and obsolete line folding;
/// - the value contains CR, LF or NUL (RFC 9110 §5.5).
fn parse_header_line(line: &[u8]) -> Result<(String, String), ParseError> {
    // Find the colon separator.
    let colon_pos = line
        .iter()
        .position(|&b| b == b':')
        .ok_or_else(|| ParseError {
            message: "Invalid header line: missing colon".to_string(),
        })?;

    let name_bytes = &line[..colon_pos];
    let raw_value = &line[colon_pos + 1..];

    // tchar from RFC 9110 §5.6.2. The name is deliberately not trimmed:
    // `Host : x` must be rejected, not silently normalised.
    let is_token_char = |b: u8| b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b);
    if name_bytes.is_empty() || !name_bytes.iter().all(|&b| is_token_char(b)) {
        return Err(ParseError {
            message: format!(
                "Invalid header name: {:?}",
                String::from_utf8_lossy(name_bytes)
            ),
        });
    }

    if raw_value.iter().any(|&b| matches!(b, b'\r' | b'\n' | b'\0')) {
        return Err(ParseError {
            message: "Invalid header value: contains CR, LF or NUL".to_string(),
        });
    }
    let value_bytes = trim_ascii(raw_value);

    // The name is pure ASCII after the token check, so this conversion is
    // exact. Values may carry obs-text (bytes 0x80-0xFF), which is NOT
    // necessarily valid UTF-8. Those sequences become U+FFFD instead of
    // producing an invalid `String` (undefined behaviour).
    let name = String::from_utf8_lossy(name_bytes).into_owned();
    let value = String::from_utf8_lossy(value_bytes).into_owned();

    Ok((name, value))
}

/// Trim leading and trailing ASCII whitespace from a byte slice.
///
/// Whitespace is whatever `u8::is_ascii_whitespace` accepts: space, tab,
/// LF, form feed and CR. The result borrows from `bytes`, so nothing is
/// copied, and the bytes need not be valid UTF-8.
#[inline]
fn trim_ascii(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    while let [head @ .., last] = rest {
        if !last.is_ascii_whitespace() {
            break;
        }
        rest = head;
    }
    rest
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_get() {
        let mut parser = HttpParser::new();
        let data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let result = parser.feed(data);
        assert!(matches!(result, ParseResult::Complete(_)));
        let req = parser.take_request().unwrap();
        assert_eq!(req.method, Method::GET);
        assert_eq!(req.path, "/");
        assert_eq!(req.host(), Some("example.com"));
    }

    #[test]
    fn test_parse_post_with_body() {
        let mut parser = HttpParser::new();
        let data = b"POST /api HTTP/1.1\r\nHost: example.com\r\nContent-Length: 5\r\n\r\nhello";
        let result = parser.feed(data);
        assert!(matches!(result, ParseResult::Complete(n) if n == data.len()));
        let req = parser.take_request().unwrap();
        assert_eq!(req.method, Method::POST);
        assert_eq!(req.content_length(), Some(5));
        // The body itself must be captured, not just the header.
        assert_eq!(req.body.as_deref(), Some(&b"hello"[..]));
    }

    #[test]
    fn test_parse_body_split_across_feeds() {
        let mut parser = HttpParser::new();
        assert!(matches!(
            parser.feed(b"POST /api HTTP/1.1\r\nContent-Length: 5\r\n\r\nhe"),
            ParseResult::NeedMore
        ));
        assert!(matches!(parser.feed(b"llo"), ParseResult::Complete(_)));
        let req = parser.take_request().unwrap();
        assert_eq!(req.body.as_deref(), Some(&b"hello"[..]));
    }

    #[test]
    fn test_parse_chunked_body() {
        let mut parser = HttpParser::new();
        let data: &[u8] = b"POST /upload HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n\
5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        assert!(matches!(parser.feed(data), ParseResult::Complete(n) if n == data.len()));
        let req = parser.take_request().unwrap();
        assert_eq!(req.body.as_deref(), Some(&b"hello world"[..]));
    }

    #[test]
    fn test_parse_incremental() {
        let mut parser = HttpParser::new();
        assert!(matches!(parser.feed(b"GET / HTT"), ParseResult::NeedMore));
        assert!(matches!(
            parser.feed(b"P/1.1\r\nHost: "),
            ParseResult::NeedMore
        ));
        assert!(matches!(
            parser.feed(b"example.com\r\n\r\n"),
            ParseResult::Complete(_)
        ));
        let req = parser.take_request().unwrap();
        assert_eq!(req.method, Method::GET);
    }

    #[test]
    fn test_pipelined_bytes_stay_buffered() {
        let mut parser = HttpParser::new();
        assert!(matches!(
            parser.feed(b"GET / HTTP/1.1\r\n\r\nGET"),
            ParseResult::Complete(_)
        ));
        assert_eq!(parser.remaining(), b"GET");
    }

    #[test]
    fn test_malformed_request_line_is_error() {
        let mut parser = HttpParser::new();
        assert!(matches!(parser.feed(b"GARBAGE\r\n\r\n"), ParseResult::Error(_)));
        // The error is sticky until reset.
        assert!(matches!(parser.feed(b""), ParseResult::Error(_)));
        assert!(parser.take_request().is_none());
    }

    #[test]
    fn test_reset_for_keepalive() {
        let mut parser = HttpParser::new();
        assert!(matches!(
            parser.feed(b"GET / HTTP/1.1\r\nHost: a.com\r\n\r\n"),
            ParseResult::Complete(_)
        ));
        assert!(parser.take_request().is_some());

        parser.reset();
        assert!(matches!(
            parser.feed(b"POST /api HTTP/1.1\r\nHost: b.com\r\nContent-Length: 0\r\n\r\n"),
            ParseResult::Complete(_)
        ));
        let req = parser.take_request().unwrap();
        assert_eq!(req.method, Method::POST);
        assert!(req.body.is_none());
    }
}