Ferrit Explore
中文·繁體·EN·日本語 Sign in Register
cielxl / veld / src / http / request.rs
//! HTTP request types and a single-shot parser.
//!
//! This module defines [`Method`], [`Version`], and [`Request`] -- the core
//! types for representing a parsed HTTP/1.x request.  The
//! [`Request::parse`] function performs a one-shot parse of a complete
//! (or at least header-complete) byte buffer, returning the request and
//! the number of bytes consumed (so the caller can split off the body).
//!
//! For streaming/incremental parsing see [`super::parser::HttpParser`].

use bytes::Bytes;
use std::fmt;
use std::net::SocketAddr;

use super::headers::HeaderMap;

// ---------------------------------------------------------------------------
// Method
// ---------------------------------------------------------------------------

/// The request methods this module recognises.
///
/// Variants are spelled exactly like their wire tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
    HEAD,
    OPTIONS,
    PATCH,
    TRACE,
    CONNECT,
}

impl Method {
    /// Look up the method whose token is byte-for-byte equal to `src`.
    ///
    /// The comparison is case-sensitive; any other input yields `None`.
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        // Single list of variants; the token text itself lives in `as_str`.
        const ALL: [Method; 9] = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::TRACE,
            Method::CONNECT,
        ];

        ALL.iter()
            .copied()
            .find(|candidate| candidate.as_str().as_bytes() == src)
    }

    /// The wire token for this method (e.g. `"GET"`).
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Method::GET => "GET",
            Method::POST => "POST",
            Method::PUT => "PUT",
            Method::DELETE => "DELETE",
            Method::HEAD => "HEAD",
            Method::OPTIONS => "OPTIONS",
            Method::PATCH => "PATCH",
            Method::TRACE => "TRACE",
            Method::CONNECT => "CONNECT",
        }
    }
}

impl fmt::Display for Method {
    /// Writes the wire token; width and padding flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

// ---------------------------------------------------------------------------
// Version
// ---------------------------------------------------------------------------

/// The HTTP protocol versions this module recognises.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Version {
    Http10,
    Http11,
    Http20,
}

impl Version {
    /// Decode a version token such as `b"HTTP/1.1"`.
    ///
    /// The token must start with the exact bytes `HTTP/`.  Both `HTTP/2`
    /// and `HTTP/2.0` map to [`Version::Http20`]; anything else yields
    /// `None`.
    pub fn from_bytes(src: &[u8]) -> Option<Self> {
        let number = src.strip_prefix(b"HTTP/")?;
        match number {
            b"1.0" => Some(Version::Http10),
            b"1.1" => Some(Version::Http11),
            b"2" | b"2.0" => Some(Version::Http20),
            _ => None,
        }
    }

    /// The canonical text for this version (e.g. `"HTTP/1.1"`).
    ///
    /// [`Version::Http20`] renders as `"HTTP/2"`.
    #[inline]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Version::Http10 => "HTTP/1.0",
            Version::Http11 => "HTTP/1.1",
            Version::Http20 => "HTTP/2",
        }
    }
}

impl fmt::Display for Version {
    /// Writes the canonical text; width and padding flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

// ---------------------------------------------------------------------------
// ParseError
// ---------------------------------------------------------------------------

/// Failure modes of [`Request::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The request line is absent or lacks one of its three fields.
    MalformedRequestLine,
    /// The method field is not one of the recognised tokens.
    InvalidMethod,
    /// The version field is not one of the recognised tokens.
    InvalidVersion,
    /// A header line could not be parsed (no colon, invalid UTF-8, etc.).
    MalformedHeader,
    /// The input ended before the header section was complete; the caller
    /// should read more bytes and parse again.
    Incomplete,
}

impl fmt::Display for ParseError {
    /// Writes a short, lower-case description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            ParseError::MalformedRequestLine => "malformed request line",
            ParseError::InvalidMethod => "unrecognised HTTP method",
            ParseError::InvalidVersion => "unrecognised HTTP version",
            ParseError::MalformedHeader => "malformed header line",
            ParseError::Incomplete => "incomplete request",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ParseError {}

// ---------------------------------------------------------------------------
// Request
// ---------------------------------------------------------------------------

/// A parsed HTTP request.
///
/// Holds the method, the raw request-target together with its path/query
/// split, the version, the headers, and an optional body.  Network-level
/// metadata ([`Request::remote_addr`], [`Request::connection_id`]) is
/// populated by the caller after parsing; the constructors and
/// [`Request::parse`] leave them at `None` and `0`.
#[derive(Debug, Clone)]
pub struct Request {
    /// The request method.
    pub method: Method,
    /// The raw, un-decoded request-target (e.g. `"/foo%20bar?x=1"`).
    pub uri: String,
    /// The part of [`uri`](Request::uri) before the first `?`
    /// (e.g. `"/foo%20bar"`).  It is copied verbatim: percent-escapes are
    /// **not** decoded.
    pub path: String,
    /// The query string **without** the leading `?`, if present.
    /// `None` when there is no `?` or nothing follows it.
    pub query: Option<String>,
    /// The protocol version from the request line.
    pub version: Version,
    /// The header fields.
    pub headers: HeaderMap,
    /// The message body, if any.  Uses [`Bytes`] so that slicing and
    /// cloning are cheap reference-counted operations.
    pub body: Option<Bytes>,
    /// Peer socket address, filled in by the connection layer.
    pub remote_addr: Option<SocketAddr>,
    /// Connection identifier, filled in by the connection layer
    /// (documented as monotonically increasing; `0` until assigned).
    pub connection_id: u64,
}

impl Request {
    /// Create a new request from its component parts.
    ///
    /// The `uri` is split into `path` and `query` at the first `?`.
    /// Headers and body start empty.
    pub fn new(method: Method, uri: String, version: Version) -> Self {
        let (path, query) = split_uri(&uri);
        Self {
            method,
            uri,
            path,
            query,
            version,
            headers: HeaderMap::new(),
            body: None,
            remote_addr: None,
            connection_id: 0,
        }
    }

    /// Create a GET request from a URI string (convenience constructor).
    pub fn from_uri(uri: &str) -> Self {
        Self::new(Method::GET, uri.to_string(), Version::Http11)
    }

    // -- convenience accessors -----------------------------------------------

    /// Return the request URI as a string slice.
    pub fn uri(&self) -> &str {
        &self.uri
    }

    /// Return the HTTP method as a string slice.
    pub fn method(&self) -> &str {
        self.method.as_str()
    }

    /// Return the value of the `Host` header, if present.
    pub fn host(&self) -> Option<&str> {
        self.headers.get("host").and_then(|v| v.to_str().ok())
    }

    /// Return the value of the named header, if present.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers.get(name).and_then(|v| v.to_str().ok())
    }

    /// Return `true` if the `Accept-Encoding` header lists `encoding`.
    pub fn accepts_encoding(&self, encoding: &str) -> bool {
        self.headers
            .get("accept-encoding")
            .and_then(|v| v.to_str().ok())
            .map(|ae| ae.contains(encoding))
            .unwrap_or(false)
    }

    /// Return `true` if this connection should be kept alive.
    ///
    /// For HTTP/1.0 the default is close; for HTTP/1.1 the default is
    /// keep-alive.  An explicit `Connection` header overrides both.
    pub fn is_keep_alive(&self) -> bool {
        match self.version {
            Version::Http10 => self
                .headers
                .get("connection")
                .and_then(|v| v.to_str().ok())
                .map(|c| c.eq_ignore_ascii_case("keep-alive"))
                .unwrap_or(false),
            Version::Http11 => self
                .headers
                .get("connection")
                .and_then(|v| v.to_str().ok())
                .map(|c| !c.eq_ignore_ascii_case("close"))
                .unwrap_or(true),
            Version::Http20 => true,
        }
    }

    /// Return the `Content-Length` value, if present and valid.
    pub fn content_length(&self) -> Option<usize> {
        self.headers
            .get("content-length")
            .and_then(|v| v.to_str().ok())
            .and_then(|cl| cl.parse::<usize>().ok())
    }

    // -- parsing -------------------------------------------------------------

    /// Parse an HTTP request from `data`.
    ///
    /// On success returns the parsed [`Request`] and the number of bytes
    /// consumed from `data` (up to and including the blank line that
    /// terminates the header section).  Any bytes in `data` beyond the
    /// returned offset are the start of the message body and are stored
    /// in [`Request::body`].
    ///
    /// The parser is zero-copy for the body (sub-sliced via
    /// [`Bytes::copy_from_slice`]).  Header names and values are
    /// allocated because they are stored in the [`HeaderMap`].
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::Incomplete`] when `data` does not yet
    /// contain a full start-line and header block, allowing the caller
    /// to read more bytes and retry.
    pub fn parse(data: &[u8]) -> Result<(Request, usize), ParseError> {
        // -- request line ----------------------------------------------------
        let line_end = find_crlf(data, 0).ok_or(ParseError::Incomplete)?;
        let (method, uri, version) = parse_request_line(&data[..line_end])?;
        let (path, query) = split_uri(&uri);

        // -- headers ---------------------------------------------------------
        let hdr_start = line_end + 2; // skip CRLF after request line
        let (headers, hdr_consumed) = parse_headers(&data[hdr_start..])?;
        let total = hdr_start + hdr_consumed;

        // -- body (zero-copy sub-slice) --------------------------------------
        let body = if total < data.len() {
            Some(Bytes::copy_from_slice(&data[total..]))
        } else {
            None
        };

        let req = Request {
            method,
            uri,
            path,
            query,
            version,
            headers,
            body,
            remote_addr: None,
            connection_id: 0,
        };
        Ok((req, total))
    }

    // -- serialisation -------------------------------------------------------

    /// Serialize the start-line and headers into a byte vector.
    ///
    /// The output is suitable for forwarding to an upstream server.
    /// If a body is present it is **appended** after the blank line.
    pub fn to_bytes(&self) -> Vec<u8> {
        // Upper-bound estimate: method + SP + uri + SP + version + CRLF
        //   + headers + CRLF + body.
        let hdr_count = self.headers.values_len();
        let mut buf = Vec::with_capacity(128 + hdr_count * 64);

        // Start line
        buf.extend_from_slice(self.method.as_str().as_bytes());
        buf.push(b' ');
        buf.extend_from_slice(self.uri.as_bytes());
        buf.push(b' ');
        buf.extend_from_slice(self.version.as_str().as_bytes());
        buf.extend_from_slice(b"\r\n");

        // Headers
        for (name, values) in self.headers.iter() {
            for value in values {
                buf.extend_from_slice(name.as_original().as_bytes());
                buf.extend_from_slice(b": ");
                buf.extend_from_slice(value.as_bytes());
                buf.extend_from_slice(b"\r\n");
            }
        }

        // Blank line
        buf.extend_from_slice(b"\r\n");

        // Body
        if let Some(ref body) = self.body {
            buf.extend_from_slice(body);
        }

        buf
    }
}

impl fmt::Display for Request {
    /// Render the request as HTTP wire text.
    ///
    /// Emits the start-line, each header value on its own
    /// `Name: value\r\n` line, a blank line, and then the body.  The
    /// layout mirrors [`to_bytes`](Request::to_bytes), except that the
    /// body is converted with lossy UTF-8 decoding, so non-UTF-8 bytes
    /// appear as replacement characters.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Start line: METHOD SP URI SP VERSION CRLF
        f.write_str(self.method.as_str())?;
        f.write_str(" ")?;
        f.write_str(&self.uri)?;
        f.write_str(" ")?;
        f.write_str(self.version.as_str())?;
        f.write_str("\r\n")?;

        // One line per header value.
        for (name, values) in self.headers.iter() {
            for value in values {
                write!(f, "{}", name.as_original())?;
                f.write_str(": ")?;
                write!(f, "{}", value)?;
                f.write_str("\r\n")?;
            }
        }

        // Blank line that ends the header section.
        f.write_str("\r\n")?;

        // Body, decoded lossily.
        match self.body.as_deref() {
            Some(raw) => f.write_str(&String::from_utf8_lossy(raw)),
            None => Ok(()),
        }
    }
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Locate the next `\r\n` in `buf` at or after index `start`.
///
/// Returns the index of the `\r` (relative to the start of `buf`), or
/// `None` if no CRLF is found.  A `start` at or beyond the end of `buf`
/// simply yields `None`; no arithmetic on `start` can overflow.
#[inline]
fn find_crlf(buf: &[u8], start: usize) -> Option<usize> {
    // `get` rejects an out-of-range `start` instead of panicking, and
    // `windows(2)` yields nothing when fewer than two bytes remain.
    buf.get(start..)?
        .windows(2)
        .position(|pair| pair == b"\r\n")
        .map(|offset| start + offset)
}

/// Parse `<METHOD> <URI> <VERSION>` from a line that has already had
/// the trailing CRLF stripped.
///
/// The three fields are separated by single spaces.  Everything after
/// the second space is handed to [`Version::from_bytes`] verbatim, so
/// extra spaces or trailing garbage there are rejected as an invalid
/// version.
///
/// # Errors
///
/// * [`ParseError::MalformedRequestLine`] -- fewer than three fields, an
///   empty request-target, or a request-target that is not valid UTF-8.
/// * [`ParseError::InvalidMethod`] -- the method token is not recognised.
/// * [`ParseError::InvalidVersion`] -- the version token is not recognised.
fn parse_request_line(line: &[u8]) -> Result<(Method, String, Version), ParseError> {
    let mut fields = line.splitn(3, |&b| b == b' ');

    // `splitn` always yields at least one (possibly empty) field.
    let method_token = fields.next().unwrap_or_default();
    // A line without any space is malformed regardless of the method text,
    // so check for the second field before validating the method.
    let target = fields.next().ok_or(ParseError::MalformedRequestLine)?;
    let method = Method::from_bytes(method_token).ok_or(ParseError::InvalidMethod)?;
    let version_token = fields.next().ok_or(ParseError::MalformedRequestLine)?;

    // Two consecutive spaces produce an empty request-target, which is never
    // valid and must not be passed on as a request for "".
    if target.is_empty() {
        return Err(ParseError::MalformedRequestLine);
    }
    let uri = std::str::from_utf8(target).map_err(|_| ParseError::MalformedRequestLine)?;

    let version = Version::from_bytes(version_token).ok_or(ParseError::InvalidVersion)?;

    Ok((method, uri.to_owned(), version))
}

/// Break a request-target into its path and optional query at the first
/// `?`.
///
/// The `?` itself is dropped.  An empty query (a trailing `?`) is
/// reported as `None`, and a target without `?` is returned whole as the
/// path.  Nothing is percent-decoded.
fn split_uri(uri: &str) -> (String, Option<String>) {
    let (path, query) = match uri.split_once('?') {
        Some((before, after)) => (before, Some(after).filter(|q| !q.is_empty())),
        None => (uri, None),
    };
    (path.to_owned(), query.map(str::to_owned))
}

/// Parse header lines until the blank line (`\r\n\r\n`).
///
/// `data` must start at the first header line (just after the request
/// line's CRLF).  Returns the populated [`HeaderMap`] and the total
/// number of bytes consumed **including** the trailing blank line.
///
/// Field names must be non-empty HTTP tokens.  In particular, whitespace
/// between the name and the colon (`Host : x`) and continuation lines
/// that start with a space or tab (obsolete line folding) are rejected,
/// because intermediaries interpret such lines inconsistently, which
/// enables request smuggling.
///
/// # Errors
///
/// * [`ParseError::Incomplete`] -- `data` ends before the blank line.
/// * [`ParseError::MalformedHeader`] -- a line has no colon, has an
///   invalid field name, or has a value that is not valid UTF-8.
fn parse_headers(data: &[u8]) -> Result<(HeaderMap, usize), ParseError> {
    /// `true` if `b` may appear in a field name (an RFC 9110 `tchar`).
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b)
    }

    let mut headers = HeaderMap::new();
    let mut pos = 0usize;

    loop {
        // `pos` never exceeds `data.len()`: it only advances to just past a
        // CRLF that `find_crlf` located inside `data`.
        match &data[pos..] {
            // Out of data, or a lone `\r` whose `\n` has not arrived yet.
            [] | [b'\r'] => return Err(ParseError::Incomplete),
            // Blank line -- end of header section.
            [b'\r', b'\n', ..] => return Ok((headers, pos + 2)),
            _ => {}
        }

        let line_end = find_crlf(data, pos).ok_or(ParseError::Incomplete)?;
        let line = &data[pos..line_end];

        // Find the `:` separator.
        let colon = line
            .iter()
            .position(|&b| b == b':')
            .ok_or(ParseError::MalformedHeader)?;
        let (raw_name, raw_value) = (&line[..colon], &line[colon + 1..]);

        // The name must be a non-empty token; this also rejects leading or
        // trailing whitespace around the name.
        if raw_name.is_empty() || !raw_name.iter().all(|&b| is_token_byte(b)) {
            return Err(ParseError::MalformedHeader);
        }
        let name = std::str::from_utf8(raw_name).map_err(|_| ParseError::MalformedHeader)?;

        // Skip optional SP / HTAB after the colon, then drop trailing
        // whitespace from the value.
        let value = std::str::from_utf8(raw_value)
            .map_err(|_| ParseError::MalformedHeader)?
            .trim_start_matches(|c: char| c == ' ' || c == '\t')
            .trim_end();

        headers.insert(name, value);

        pos = line_end + 2; // advance past CRLF
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

// Unit tests for the method/version tokens, the one-shot parser, the
// convenience accessors, and the two serialisers.
#[cfg(test)]
mod tests {
    use super::*;

    // Every recognised method token parses, and renders back to the same
    // text through both `as_str` and `Display`.
    #[test]
    fn method_from_bytes_roundtrip() {
        let cases: &[(&[u8], &str)] = &[
            (b"GET", "GET"),
            (b"POST", "POST"),
            (b"PUT", "PUT"),
            (b"DELETE", "DELETE"),
            (b"HEAD", "HEAD"),
            (b"OPTIONS", "OPTIONS"),
            (b"PATCH", "PATCH"),
            (b"TRACE", "TRACE"),
            (b"CONNECT", "CONNECT"),
        ];
        for &(bytes, expected) in cases {
            let m = Method::from_bytes(bytes).unwrap();
            assert_eq!(m.as_str(), expected);
            assert_eq!(m.to_string(), expected);
        }
    }

    // An unrecognised token yields `None` rather than an error or panic.
    #[test]
    fn method_from_bytes_unknown() {
        assert!(Method::from_bytes(b"FOOBAR").is_none());
    }

    // Both spellings of HTTP/2 map to `Http20`; HTTP/3.0 is not recognised.
    #[test]
    fn version_from_bytes() {
        assert_eq!(Version::from_bytes(b"HTTP/1.0"), Some(Version::Http10));
        assert_eq!(Version::from_bytes(b"HTTP/1.1"), Some(Version::Http11));
        assert_eq!(Version::from_bytes(b"HTTP/2"), Some(Version::Http20));
        assert_eq!(Version::from_bytes(b"HTTP/2.0"), Some(Version::Http20));
        assert_eq!(Version::from_bytes(b"HTTP/3.0"), None);
    }

    // `Display` uses the canonical text; HTTP/2 renders without a minor
    // version.
    #[test]
    fn version_display() {
        assert_eq!(Version::Http10.to_string(), "HTTP/1.0");
        assert_eq!(Version::Http11.to_string(), "HTTP/1.1");
        assert_eq!(Version::Http20.to_string(), "HTTP/2");
    }

    // A body-less GET consumes the whole buffer and leaves `body` empty.
    #[test]
    fn parse_simple_get() {
        let data = b"GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let (req, consumed) = Request::parse(data).unwrap();
        assert_eq!(consumed, data.len());
        assert_eq!(req.method, Method::GET);
        assert_eq!(req.uri, "/index.html");
        assert_eq!(req.path, "/index.html");
        assert!(req.query.is_none());
        assert_eq!(req.version, Version::Http11);
        assert_eq!(req.host(), Some("example.com"));
        assert!(req.body.is_none());
    }

    // The request-target is split at the first `?`; the `?` is dropped.
    #[test]
    fn parse_get_with_query() {
        let data = b"GET /search?q=rust&lang=en HTTP/1.1\r\nHost: e.com\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert_eq!(req.path, "/search");
        assert_eq!(req.query.as_deref(), Some("q=rust&lang=en"));
    }

    // Bytes after the blank line become the body but are not counted in
    // the consumed offset.
    #[test]
    fn parse_post_with_body() {
        let data = b"POST /api HTTP/1.1\r\nHost: e.com\r\nContent-Length: 5\r\n\r\nhello";
        let (req, consumed) = Request::parse(data).unwrap();
        assert_eq!(req.method, Method::POST);
        assert_eq!(req.content_length(), Some(5));
        let body = req.body.unwrap();
        assert_eq!(&body[..], b"hello");
        // consumed should NOT include the body
        assert_eq!(consumed, data.len() - 5);
    }

    // Several headers are all stored and can be looked up with lower-case
    // names.
    #[test]
    fn parse_multiple_headers() {
        let data =
            b"GET / HTTP/1.1\r\nHost: e.com\r\nAccept: text/html\r\nUser-Agent: test\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert_eq!(req.header("host"), Some("e.com"));
        assert_eq!(req.header("accept"), Some("text/html"));
        assert_eq!(req.header("user-agent"), Some("test"));
    }

    // Input that stops before the terminating blank line is `Incomplete`,
    // not malformed.
    #[test]
    fn parse_incomplete_returns_error() {
        let data = b"GET / HTTP/1.1\r\nHost: e.com\r\n";
        assert!(matches!(Request::parse(data), Err(ParseError::Incomplete)));
    }

    // A request line followed directly by the blank line is a complete
    // request with no headers.
    #[test]
    fn parse_no_headers_complete() {
        let data = b"GET / HTTP/1.1\r\n\r\n";
        let (req, consumed) = Request::parse(data).unwrap();
        assert_eq!(consumed, data.len());
        assert!(req.headers.is_empty());
        assert!(req.body.is_none());
    }

    // An unknown method token is reported as `InvalidMethod`.
    #[test]
    fn parse_invalid_method() {
        let data = b"FOOBAR / HTTP/1.1\r\nHost: e.com\r\n\r\n";
        assert!(matches!(
            Request::parse(data),
            Err(ParseError::InvalidMethod)
        ));
    }

    // An unknown version token is reported as `InvalidVersion`.
    #[test]
    fn parse_invalid_version() {
        let data = b"GET / HTTP/3.0\r\nHost: e.com\r\n\r\n";
        assert!(matches!(
            Request::parse(data),
            Err(ParseError::InvalidVersion)
        ));
    }

    // A request line with no spaces at all is malformed.
    #[test]
    fn parse_malformed_request_line_no_uri() {
        let data = b"GET\r\nHost: e.com\r\n\r\n";
        assert!(matches!(
            Request::parse(data),
            Err(ParseError::MalformedRequestLine)
        ));
    }

    // HTTP/1.0 defaults to close; an explicit keep-alive opts in.
    #[test]
    fn parse_http10_keep_alive() {
        let data = b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert!(req.is_keep_alive());
    }

    // HTTP/1.1 defaults to keep-alive; an explicit close opts out.
    #[test]
    fn parse_http11_close() {
        let data = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert!(!req.is_keep_alive());
    }

    // `Request::new` performs the same path/query split as the parser.
    #[test]
    fn request_new_splits_uri() {
        let req = Request::new(Method::GET, "/foo?bar=baz".to_owned(), Version::Http11);
        assert_eq!(req.path, "/foo");
        assert_eq!(req.query.as_deref(), Some("bar=baz"));
    }

    // `Display` reproduces the start line, each header line, and the body.
    // Header order is not asserted.
    #[test]
    fn request_display_roundtrip() {
        let raw = b"GET /path HTTP/1.1\r\nHost: example.com\r\nAccept: text/html\r\n\r\nbody";
        let (req, _) = Request::parse(raw).unwrap();
        let rendered = req.to_string();
        assert!(rendered.starts_with("GET /path HTTP/1.1\r\n"));
        assert!(rendered.contains("Host: example.com\r\n"));
        assert!(rendered.contains("Accept: text/html\r\n"));
        assert!(rendered.ends_with("\r\nbody"));
    }

    // `to_bytes` on a parsed request reproduces the input exactly, which
    // includes header order and original header-name casing.
    #[test]
    fn request_to_bytes_roundtrip() {
        let raw = b"POST /submit HTTP/1.1\r\nHost: e.com\r\nContent-Length: 4\r\n\r\ndata";
        let (req, _) = Request::parse(raw).unwrap();
        let bytes = req.to_bytes();
        assert_eq!(&bytes[..], &raw[..]);
    }

    // Listed codings are accepted; one that is not listed is rejected.
    #[test]
    fn request_accepts_encoding() {
        let data = b"GET / HTTP/1.1\r\nHost: e.com\r\nAccept-Encoding: gzip, br\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert!(req.accepts_encoding("gzip"));
        assert!(req.accepts_encoding("br"));
        assert!(!req.accepts_encoding("zstd"));
    }

    // A trailing `?` with nothing after it yields no query at all.
    #[test]
    fn parse_empty_query_string() {
        let data = b"GET /foo? HTTP/1.1\r\nHost: e.com\r\n\r\n";
        let (req, _) = Request::parse(data).unwrap();
        assert_eq!(req.path, "/foo");
        assert!(req.query.is_none());
    }
}