diff --git a/src/main.rs b/src/main.rs index d74b28b..95ad3d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ // #![feature(if_let_guard)] use std::{ collections::HashMap, path::PathBuf, sync::Arc }; -use anyhow::Result; +use anyhow::{bail, Result}; use tokio::{ fs::File, @@ -23,14 +23,10 @@ type A = Arc; fn parse_args () -> Args { let directory = std::env::args() .position(|e| e == "--directory") - .map(|e| e + 1) - .map(|d| std::env::args().nth(d)) - .flatten() + .and_then(|d| std::env::args().nth(d + 1)) .map(|d| PathBuf::from(d)); - Args { - directory - } + Args { directory } } #[tokio::main] @@ -51,39 +47,55 @@ async fn main() -> Result<()> { async fn process (mut stream: TcpStream, args: A) -> Result<()> { let buf_reader = BufReader::new(&mut stream); - let mut data = buf_reader - .lines(); + let mut data = buf_reader.lines(); - let (_method, path, _ver) = { - let start_line = data.next_line().await?.ok_or(E::InvalidRequest)?; // should be 500; - let mut parts = start_line.split_whitespace().map(ToOwned::to_owned); - let method = parts.next().ok_or(E::InvalidRequest)?; - let path = parts.next().ok_or(E::InvalidRequest)?; - let ver = parts.next().ok_or(E::InvalidRequest)?; + let (method, target, _version) = 'outer : { + 'inner : { + let Ok(Some(start_line)) = data.next_line().await else { break 'inner }; + let mut parts = start_line.split_ascii_whitespace(); + let Some(Ok(method)) = parts.next().map(Method::try_from) else { break 'inner }; + let Some(path) = parts.next() else { break 'inner }; + let Some(version) = parts.next() else { break 'inner }; + break 'outer (method, path.to_owned(), version.to_owned()); + } + // this is either the best or the worst piece of code i've written - (method, path, ver) + let _ = stream.write_all(&Response::_400.build()).await; + let _ = stream.flush(); + bail!(E::InvalidRequest) }; - let headers = Headers::parse(data.into_inner()).await; - let response = match path.as_str() { - "/" => Response::Empty, - "/user-agent" => Response::TextPlain(headers.get("User-Agent").to_owned()), + let mut data = data.into_inner(); + + let headers = Headers::parse(&mut data).await; + + let length = headers.get("Content-Length").parse().unwrap(); + let body = parse_req_body(&mut data, length).await; + + use Method as M; + let response = match (method, target.as_str()) { + (M::GET, "/") => Response::Empty, + (M::GET, "/user-agent") => Response::TextPlain(headers.get("User-Agent").to_owned()), // p if let Some(echo) = p.strip_prefix("/echo/") => Response::TextPlain(echo), // a nicer way to do that, not available in stable yet - p if p.starts_with("/echo/") => Response::TextPlain(p.trim_start_matches("/echo/").to_owned()), - p if p.starts_with("/files/") => 'a : { - let Some(path) = &args.directory else { - break 'a Response::_500; - }; - let path = path.join(p.trim_start_matches("/files/")); - let Ok(mut f) = File::open(path).await else { - break 'a Response::_404; - }; + (M::GET, r) if r.starts_with("/echo/") => Response::TextPlain(r.trim_start_matches("/echo/").to_owned()), + (M::GET, r) if r.starts_with("/files/") => 'file : { + let Some(path) = &args.directory else { break 'file Response::_500; }; + let path = path.join(r.trim_start_matches("/files/")); + let Ok(mut f) = File::open(path).await else { break 'file Response::_404; }; let mut buf = vec![]; let _ = f.read_to_end(&mut buf).await; Response::OctetStream(buf) }, + (M::POST, r) if r.starts_with("/files") => 'file : { + let Some(path) = &args.directory else { break 'file Response::_500; }; + let path = path.join(r.trim_start_matches("/files/")); + let Ok(mut f) = File::create(path).await else { break 'file Response::_500; }; + let Ok(_) = f.write_all(&body).await else { break 'file Response::_500 }; + + Response::_201 + }, _ => Response::_404, }; @@ -93,11 +105,17 @@ async fn process (mut stream: TcpStream, args: A) -> Result<()> { Ok(()) } +pub async fn parse_req_body (reader: &mut BufReader<&mut TcpStream>, length: usize) -> Vec { + let mut v = vec![0; length]; + let _ = reader.read_exact(&mut v).await; + v +} + #[derive(Debug, Clone)] pub struct Headers (HashMap); impl Headers { - pub async fn parse (mut reader: BufReader<&'_ mut TcpStream>) -> Self { + pub async fn parse (reader: &mut BufReader<&'_ mut TcpStream>) -> Self { let mut map = HashMap::new(); let mut buf = String::new(); while let Ok(_) = reader.read_line(&mut buf).await { @@ -119,6 +137,8 @@ impl Headers { #[derive(Debug, Clone)] enum Response { + _201, + _400, _404, _500, Empty, @@ -126,20 +146,41 @@ enum Response { OctetStream (Vec) } +#[derive(Debug, Clone, PartialEq, Eq)] +enum Method { + GET, + POST, + PUT, + DELETE, + UPDATE, +} + +impl TryFrom<&str> for Method { + type Error = anyhow::Error; + fn try_from (s: &str) -> Result { + let mut s = s.to_string(); + s.make_ascii_lowercase(); + Ok(match s.as_str() { + "get" => Self::GET, + "post" => Self::POST, + "put" => Self::PUT, + "delete" => Self::DELETE, + "update" => Self::UPDATE, + _ => bail!(E::UnknownMethod) + }) + } +} + #[allow(non_upper_case_globals)] impl Response { fn build (self) -> Vec { let headers = self.headers().join("\r\n"); - - let code = match self { - Self::_404 => "404 Not Found", - Self::_500 => "500 Internal Server Error", - _ => "200 OK", - }; + let code = self.code(); let mut v: Vec = f!("HTTP/1.1 {code}\r\n{headers}\r\n\r\n").into(); + match self { Self::OctetStream(bytes) => { v.extend_from_slice(&bytes); @@ -153,10 +194,26 @@ impl Response { v } + fn code (&self) -> &'static str { + match self { + Self::_201 => "201 Created", + Self::_400 => "400 Bad Request", + Self::_404 => "404 Not Found", + Self::_500 => "500 Internal Server Error", + _ => "200 OK", + } + } + fn headers (&self) -> Vec { match self { - Self::TextPlain(text) => vec![f!("Content-Type: text/plain"), format!("Content-Length: {}", text.len())], - Self::OctetStream(bytes) => vec![f!("Content-Type: application/octet-stream"), format!("Content-Length: {}", bytes.len())], + Self::TextPlain(text) => vec![ + f!("Content-Type: text/plain"), + format!("Content-Length: {}", text.len()) + ], + Self::OctetStream(bytes) => vec![ + f!("Content-Type: application/octet-stream"), + format!("Content-Length: {}", bytes.len()) + ], _ => d!() } } diff --git a/src/utils.rs b/src/utils.rs index 31ba577..0d59ea6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -9,4 +9,6 @@ macro_rules! f { ($s: expr) => { format!($s) }; } pub enum E { #[error("Invalid request data found during parsing")] InvalidRequest, + #[error("Cannot parse method")] + UnknownMethod, }