use super::headers::HeaderMap;
use super::request::{Method, Request, Version};
use bytes::{Buf, BytesMut};
use std::io;
/// Position of the streaming parser within a single HTTP request.
#[derive(Clone, Debug, PartialEq)]
pub enum ParseState {
    /// Waiting for the CRLF-terminated request line.
    ParsingRequestLine,
    /// Reading header lines until the empty line that ends the header section.
    ParsingHeaders,
    /// Reading a fixed-length body; the payload is the number of body bytes awaited.
    ParsingBody(usize),
    /// Reading a body framed with the chunked transfer coding.
    ParsingChunked,
    /// A whole request has been parsed.
    Complete,
    /// Parsing failed; the payload is the error message.
    Error(String),
}
/// Outcome of a single [`HttpParser::feed`] call.
#[derive(Debug)]
pub enum ParseResult {
    /// The buffered input ends mid-request; feed more data.
    NeedMore,
    /// The request is fully parsed; the payload is the number of bytes consumed for it.
    Complete(usize),
    /// The input is malformed.
    Error(ParseError),
}
/// Error produced when the input is not a well-formed HTTP request.
#[derive(Debug)]
pub struct ParseError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("HTTP parse error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for ParseError {}

/// Wraps an I/O failure, keeping only its rendered message.
impl From<io::Error> for ParseError {
    fn from(err: io::Error) -> Self {
        Self {
            message: format!("{err}"),
        }
    }
}
/// Incremental HTTP/1.1 request parser.
///
/// Bytes are handed to it as they arrive. It reports when a whole request
/// (request line, headers and optional body) has been read.
pub struct HttpParser {
    // --- parser bookkeeping ---
    /// Current position in the request grammar.
    state: ParseState,
    /// Bytes received but not yet consumed.
    buffer: BytesMut,
    /// Number of bytes consumed so far for the current request.
    total_parsed: usize,
    /// Size of the chunk being read in a chunked body; `None` while the next
    /// chunk-size line is awaited.
    chunk_size: Option<usize>,

    // --- request pieces, filled in as they are parsed ---
    method: Option<Method>,
    uri: Option<String>,
    version: Option<Version>,
    headers: HeaderMap,
    /// Value of the Content-Length header, when present and parseable.
    content_length: Option<usize>,
    /// Accumulated request body (decoded; chunked framing removed).
    body: BytesMut,
}
impl HttpParser {
pub fn new() -> Self {
Self {
state: ParseState::ParsingRequestLine,
buffer: BytesMut::with_capacity(8192),
method: None,
uri: None,
version: None,
headers: HeaderMap::new(),
content_length: None,
chunk_size: None,
total_parsed: 0,
body: BytesMut::new(),
}
}
/// Feed data into the parser. Returns the parse result.
pub fn feed(&mut self, data: &[u8]) -> ParseResult {
self.buffer.extend_from_slice(data);
loop {
match &self.state {
ParseState::ParsingRequestLine => {
if let Some(pos) = find_crlf(&self.buffer) {
// Use split_to to extract the line without copying.
// split_to returns a BytesMut containing [0..pos].
let line = self.buffer.split_to(pos);
self.buffer.advance(2); // skip \r\n
self.total_parsed += pos + 2;
match parse_request_line(&line) {
Ok((method, uri, version)) => {
self.method = Some(method);
self.uri = Some(uri);
self.version = Some(version);
self.state = ParseState::ParsingHeaders;
}
Err(e) => {
self.state = ParseState::Error(e.message.clone());
return ParseResult::Error(e);
}
}
} else {
return ParseResult::NeedMore;
}
}
ParseState::ParsingHeaders => {
if let Some(pos) = find_crlf(&self.buffer) {
if pos == 0 {
// Empty line = end of headers
self.buffer.advance(2);
self.total_parsed += 2;
self.content_length = self
.headers
.get("content-length")
.and_then(|cl| cl.parse().ok());
let transfer_encoding = self
.headers
.get("transfer-encoding")
.map(|s| s.to_lowercase());
if transfer_encoding.as_deref() == Some("chunked") {
self.state = ParseState::ParsingChunked;
} else if let Some(len) = self.content_length {
if len > 0 {
self.state = ParseState::ParsingBody(len);
} else {
self.state = ParseState::Complete;
return ParseResult::Complete(self.total_parsed);
}
} else {
self.state = ParseState::Complete;
return ParseResult::Complete(self.total_parsed);
}
} else {
// Use split_to to extract the line without copying.
let line = self.buffer.split_to(pos);
self.buffer.advance(2); // skip \r\n
self.total_parsed += pos + 2;
match parse_header_line(&line) {
Ok((name, value)) => {
self.headers.insert(name, value);
}
Err(e) => {
self.state = ParseState::Error(e.message.clone());
return ParseResult::Error(e);
}
}
}
} else {
return ParseResult::NeedMore;
}
}
ParseState::ParsingBody(remaining) => {
if self.buffer.len() >= *remaining {
self.total_parsed += remaining;
let chunk = self.buffer.split_to(*remaining);
self.body.extend_from_slice(&chunk);
self.state = ParseState::Complete;
return ParseResult::Complete(self.total_parsed);
} else {
return ParseResult::NeedMore;
}
}
ParseState::ParsingChunked => {
// Parse chunked transfer encoding
loop {
if let Some(size) = self.chunk_size {
if size == 0 {
// Last chunk
if self.buffer.len() >= 2 && self.buffer[..2] == *b"\r\n" {
let _ = self.buffer.split_to(2);
self.total_parsed += 2;
self.state = ParseState::Complete;
return ParseResult::Complete(self.total_parsed);
}
return ParseResult::NeedMore;
}
if self.buffer.len() >= size + 2 {
let chunk = self.buffer.split_to(size + 2);
// Keep the chunk data, drop the trailing CRLF.
self.body.extend_from_slice(&chunk[..size]);
self.total_parsed += size + 2;
self.chunk_size = None;
} else {
return ParseResult::NeedMore;
}
} else if let Some(pos) = find_crlf(&self.buffer) {
let size_str = std::str::from_utf8(&self.buffer[..pos]).unwrap_or("0");
let size = usize::from_str_radix(size_str.trim(), 16).unwrap_or(0);
let _ = self.buffer.split_to(pos + 2);
self.total_parsed += pos + 2;
self.chunk_size = Some(size);
} else {
return ParseResult::NeedMore;
}
}
}
ParseState::Complete => {
return ParseResult::Complete(self.total_parsed);
}
ParseState::Error(msg) => {
return ParseResult::Error(ParseError {
message: msg.clone(),
});
}
}
}
}
/// Check if parsing is complete
pub fn is_complete(&self) -> bool {
matches!(self.state, ParseState::Complete)
}
/// Take the parsed request, consuming the parser state
pub fn take_request(&mut self) -> Option<Request> {
if !self.is_complete() {
return None;
}
let method = self.method.take()?;
let uri = self.uri.take()?;
let version = self.version.take()?;
let headers = std::mem::take(&mut self.headers);
let mut request = Request::new(method, uri, version);
request.headers = headers;
if !self.body.is_empty() {
request.body = Some(std::mem::take(&mut self.body).freeze());
}
Some(request)
}
/// Reset the parser for the next request (keep-alive)
pub fn reset(&mut self) {
self.state = ParseState::ParsingRequestLine;
self.buffer.clear();
self.method = None;
self.uri = None;
self.version = None;
self.headers = HeaderMap::new();
self.content_length = None;
self.chunk_size = None;
self.total_parsed = 0;
self.body.clear();
}
/// Get remaining buffered data
pub fn remaining(&self) -> &[u8] {
&self.buffer
}
}
impl Default for HttpParser {
fn default() -> Self {
Self::new()
}
}
/// Return the index of the `\r` in the first `\r\n` pair of `buf`, if any.
fn find_crlf(buf: &[u8]) -> Option<usize> {
    let next_bytes = buf.iter().skip(1);
    buf.iter()
        .zip(next_bytes)
        .position(|(&cur, &next)| cur == b'\r' && next == b'\n')
}
/// Parse the HTTP request line `METHOD SP request-target SP HTTP-version`.
///
/// The caller has already stripped the trailing CRLF. The function works on
/// the byte slice directly and copies only the URI into an owned `String`.
/// The bytes come straight off the network, so every byte-to-string
/// conversion is checked; nothing is assumed to be valid UTF-8.
///
/// # Errors
///
/// Returns a [`ParseError`] in any of these cases:
/// - either delimiting space is missing;
/// - `Method::from_bytes` rejects the method;
/// - the request-target is empty or not valid UTF-8;
/// - the version is not `HTTP/1.0`, `HTTP/1.1`, `HTTP/2` or `HTTP/2.0`.
fn parse_request_line(line: &[u8]) -> Result<(Method, String, Version), ParseError> {
    // Find the two spaces that delimit method, URI, and version.
    let Some(sp1) = line.iter().position(|&b| b == b' ') else {
        return Err(ParseError {
            message: "Invalid request line: missing space".to_string(),
        });
    };
    let rest = &line[sp1 + 1..];
    let Some(sp2) = rest.iter().position(|&b| b == b' ') else {
        return Err(ParseError {
            message: "Invalid request line: missing second space".to_string(),
        });
    };
    let method_bytes = &line[..sp1];
    let uri_bytes = &rest[..sp2];
    let version_bytes = &rest[sp2 + 1..];

    // An unrecognised method can contain arbitrary bytes; render it lossily.
    let method = Method::from_bytes(method_bytes).ok_or_else(|| ParseError {
        message: format!("Unknown method: {}", String::from_utf8_lossy(method_bytes)),
    })?;

    if uri_bytes.is_empty() {
        return Err(ParseError {
            message: "Invalid request line: empty URI".to_string(),
        });
    }
    // A request-target is ASCII (RFC 3986). Reject raw non-UTF-8 bytes instead of
    // building an invalid `String`, which would be undefined behaviour.
    let uri = String::from_utf8(uri_bytes.to_vec()).map_err(|_| ParseError {
        message: "Invalid request line: URI is not valid UTF-8".to_string(),
    })?;

    let version = match version_bytes {
        b"HTTP/1.0" => Version::Http10,
        b"HTTP/1.1" => Version::Http11,
        b"HTTP/2" | b"HTTP/2.0" => Version::Http20,
        _ => {
            return Err(ParseError {
                message: format!("Unknown version: {}", String::from_utf8_lossy(version_bytes)),
            });
        }
    };
    Ok((method, uri, version))
}
/// Parse one HTTP header line, `field-name ":" OWS field-value OWS`.
///
/// The caller has already stripped the trailing CRLF.
///
/// The field name must be a non-empty RFC 9110 token immediately followed by
/// the colon. Whitespace before the colon, or at the start of the line
/// (obsolete line folding), is rejected as RFC 9112 §5.1–5.2 requires;
/// accepting it leniently is a known request-smuggling vector.
///
/// The value is trimmed of surrounding ASCII whitespace. Bytes that are not
/// valid UTF-8 (e.g. ISO-8859-1 `obs-text`) are replaced with U+FFFD; they
/// are never assumed to be UTF-8.
///
/// # Errors
///
/// Returns a [`ParseError`] in any of these cases:
/// - the colon is missing;
/// - the name is empty or contains non-token bytes;
/// - the value contains NUL, CR or LF (RFC 9110 §5.5).
fn parse_header_line(line: &[u8]) -> Result<(String, String), ParseError> {
    /// `tchar` from RFC 9110 §5.6.2.
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b)
    }

    let colon_pos = line
        .iter()
        .position(|&b| b == b':')
        .ok_or_else(|| ParseError {
            message: "Invalid header line: missing colon".to_string(),
        })?;

    let name_bytes = &line[..colon_pos];
    if name_bytes.is_empty() || !name_bytes.iter().all(|&b| is_token_byte(b)) {
        return Err(ParseError {
            message: "Invalid header line: malformed field name".to_string(),
        });
    }

    let value_bytes = trim_ascii(&line[colon_pos + 1..]);
    if value_bytes.iter().any(|&b| matches!(b, b'\0' | b'\r' | b'\n')) {
        return Err(ParseError {
            message: "Invalid header line: forbidden character in value".to_string(),
        });
    }

    // Token bytes are ASCII, so the lossy conversion never replaces anything in the name.
    let name = String::from_utf8_lossy(name_bytes).into_owned();
    let value = String::from_utf8_lossy(value_bytes).into_owned();
    Ok((name, value))
}
/// Strip ASCII whitespace (as defined by `u8::is_ascii_whitespace`) from both
/// ends of `bytes`, returning a sub-slice that borrows the input.
#[inline]
fn trim_ascii(bytes: &[u8]) -> &[u8] {
    let mut view = bytes;
    while let [head, tail @ ..] = view {
        if !head.is_ascii_whitespace() {
            break;
        }
        view = tail;
    }
    while let [init @ .., last] = view {
        if !last.is_ascii_whitespace() {
            break;
        }
        view = init;
    }
    view
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed `bytes` in a single call and assert that a full request was parsed.
    fn feed_to_completion(parser: &mut HttpParser, bytes: &[u8]) {
        let outcome = parser.feed(bytes);
        assert!(matches!(outcome, ParseResult::Complete(_)));
    }

    #[test]
    fn test_parse_simple_get() {
        let mut parser = HttpParser::new();
        feed_to_completion(&mut parser, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
        let request = parser.take_request().unwrap();
        assert_eq!(request.method, Method::GET);
        assert_eq!(request.path, "/");
        assert_eq!(request.host(), Some("example.com"));
    }

    #[test]
    fn test_parse_post_with_body() {
        let mut parser = HttpParser::new();
        feed_to_completion(
            &mut parser,
            b"POST /api HTTP/1.1\r\nHost: example.com\r\nContent-Length: 5\r\n\r\nhello",
        );
        let request = parser.take_request().unwrap();
        assert_eq!(request.method, Method::POST);
        assert_eq!(request.content_length(), Some(5));
    }

    #[test]
    fn test_parse_incremental() {
        let mut parser = HttpParser::new();
        let partial_pieces: [&[u8]; 2] = [b"GET / HTT", b"P/1.1\r\nHost: "];
        for piece in partial_pieces {
            assert!(matches!(parser.feed(piece), ParseResult::NeedMore));
        }
        feed_to_completion(&mut parser, b"example.com\r\n\r\n");
        assert_eq!(parser.take_request().unwrap().method, Method::GET);
    }

    #[test]
    fn test_reset_for_keepalive() {
        let mut parser = HttpParser::new();
        let _ = parser.feed(b"GET / HTTP/1.1\r\nHost: a.com\r\n\r\n");
        let _ = parser.take_request();
        parser.reset();
        let _ = parser.feed(b"POST /api HTTP/1.1\r\nHost: b.com\r\nContent-Length: 0\r\n\r\n");
        let second = parser.take_request().unwrap();
        assert_eq!(second.method, Method::POST);
    }
}