pavex/request/body/
buffered_body.rs1use bytes::Bytes;
2use http::header::CONTENT_LENGTH;
3use http_body_util::{BodyExt, Limited};
4use pavex_macros::methods;
5use ubyte::ByteUnit;
6
7use crate::{request::RequestHead, request::body::errors::SizeLimitExceeded};
8
9use super::{
10 BodySizeLimit, RawIncomingBody,
11 errors::{ExtractBufferedBodyError, UnexpectedBufferError},
12};
13
14#[derive(Debug, Clone)]
15#[non_exhaustive]
16pub struct BufferedBody {
52 pub bytes: Bytes,
54}
55
56#[methods]
57impl BufferedBody {
58 #[request_scoped(pavex = crate)]
62 pub async fn extract(
63 request_head: &RequestHead,
64 body: RawIncomingBody,
65 body_size_limit: BodySizeLimit,
66 ) -> Result<Self, ExtractBufferedBodyError> {
67 match body_size_limit {
68 BodySizeLimit::Enabled { max_size } => {
69 Self::_extract_with_limit(request_head, body, max_size).await
70 }
71 BodySizeLimit::Disabled => match body.collect().await {
72 Ok(collected) => Ok(Self {
73 bytes: collected.to_bytes(),
74 }),
75 Err(e) => Err(UnexpectedBufferError { source: e.into() }.into()),
76 },
77 }
78 }
79
80 async fn _extract_with_limit<B>(
81 request_head: &RequestHead,
82 body: B,
83 max_size: ByteUnit,
84 ) -> Result<Self, ExtractBufferedBodyError>
85 where
86 B: hyper::body::Body,
87 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
88 {
89 let content_length = request_head
90 .headers
91 .get(CONTENT_LENGTH)
92 .and_then(|value| value.to_str().ok()?.parse::<usize>().ok());
93
94 let limit_error = || SizeLimitExceeded {
96 max_size,
97 content_length,
98 };
99
100 if let Some(len) = content_length
107 && len > max_size
108 {
109 return Err(limit_error().into());
110 }
111
112 let max_n_bytes = max_size.as_u64().try_into().unwrap_or(usize::MAX);
115 let limited_body = Limited::new(body, max_n_bytes);
119 match limited_body.collect().await {
120 Ok(collected) => Ok(Self {
121 bytes: collected.to_bytes(),
122 }),
123 Err(e) => {
124 if e.downcast_ref::<http_body_util::LengthLimitError>()
125 .is_some()
126 {
127 Err(limit_error().into())
128 } else {
129 Err(UnexpectedBufferError { source: e }.into())
130 }
131 }
132 }
133 }
134}
135
136impl From<BufferedBody> for Bytes {
137 fn from(buffered_body: BufferedBody) -> Self {
138 buffered_body.bytes
139 }
140}
141
#[cfg(test)]
mod tests {
    use http::HeaderMap;
    use ubyte::ToByteUnit;

    use crate::request::RequestHead;

    use super::{BufferedBody, Bytes};

    /// A `GET /` HTTP/1.1 request head with no headers.
    fn dummy_request_head() -> RequestHead {
        RequestHead {
            method: http::Method::GET,
            target: "/".parse().unwrap(),
            version: http::Version::HTTP_11,
            headers: HeaderMap::new(),
        }
    }

    #[tokio::test]
    async fn error_if_body_above_size_limit_without_content_length() {
        let raw_body = vec![0; 1000];

        let max_n_bytes = 100.bytes();
        assert!(raw_body.len() > max_n_bytes.as_u64() as usize);

        let body = crate::response::body::raw::Full::new(Bytes::from(raw_body));
        let err = BufferedBody::_extract_with_limit(&dummy_request_head(), body, max_n_bytes)
            .await
            .unwrap_err();
        insta::assert_snapshot!(err, @"The request body is larger than the maximum size limit enforced by this server.");
        insta::assert_debug_snapshot!(err, @r###"
        SizeLimitExceeded(
            SizeLimitExceeded {
                max_size: ByteUnit(
                    100,
                ),
                content_length: None,
            },
        )
        "###);
    }

    #[tokio::test]
    async fn error_if_content_length_header_is_larger_than_limit() {
        let mut request_head = dummy_request_head();

        let max_size = 100.bytes();
        let body = crate::response::body::raw::Full::new(Bytes::from(vec![0; 500]));
        request_head
            .headers
            .insert("Content-Length", "1000".parse().unwrap());

        let err = BufferedBody::_extract_with_limit(&request_head, body, max_size)
            .await
            .unwrap_err();
        insta::assert_snapshot!(err, @"The request body is larger than the maximum size limit enforced by this server.");
        insta::assert_debug_snapshot!(err, @r###"
        SizeLimitExceeded(
            SizeLimitExceeded {
                max_size: ByteUnit(
                    100,
                ),
                content_length: Some(
                    1000,
                ),
            },
        )
        "###);
    }

    #[tokio::test]
    async fn body_exactly_at_size_limit_is_buffered() {
        // Both the `Content-Length` pre-check and the streaming cap must accept a
        // body whose size is *equal* to the limit; only strictly larger bodies fail.
        let raw_body = vec![7u8; 100];
        let max_size = 100.bytes();

        let mut request_head = dummy_request_head();
        request_head
            .headers
            .insert("Content-Length", "100".parse().unwrap());

        let body = crate::response::body::raw::Full::new(Bytes::from(raw_body.clone()));
        let buffered = BufferedBody::_extract_with_limit(&request_head, body, max_size)
            .await
            .unwrap();
        assert_eq!(buffered.bytes, Bytes::from(raw_body));
    }
}