Stop blindly allocating fields based on length

If we allocate and zero the memory for a length-prefixed field up front,
it becomes trivial to DoS the deserializer by sending an enormous
length prefix.

Instead, allocate at most one page of memory, growing the Vec as required
to accommodate the data actually available in the incoming buffer.
This commit is contained in:
Richard Bradfield 2020-07-05 13:17:27 +01:00 committed by Tadeo Kondrak
parent 321b1ba94b
commit d8005d69b1
No known key found for this signature in database
GPG Key ID: D41E092CA43F1D8B

View File

@ -7,10 +7,31 @@ use std::{
str, u16, u32, u64, u8,
};
// Upper bound, in bytes, on the capacity `read_bytes` reserves before any data has been read.
// We could use pagesize to get this across platforms, but 4K is a reasonable value
const PAGESIZE: usize = 4096;
/// Deserializer that takes its input from the wrapped reader `R`.
pub struct Deserializer<R> {
// Input source; the length-prefixed field readers pull their bytes from it.
reader: R,
}
/// Try and return a Vec<u8> of `len` bytes from a Reader
#[inline]
fn read_bytes<R: Read>(reader: R, len: usize) -> Result<Vec<u8>, std::io::Error> {
let capacity = if len > PAGESIZE { PAGESIZE } else { len };
// Allocate at most one page to start with. Growing a Vec is fairly efficient once you get out
// of the region of the first few hundred bytes.
let mut buffer: Vec<u8> = Vec::with_capacity(capacity);
let read = reader.take(len as u64).read_to_end(&mut buffer)?;
if read < len {
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Unexpected EOF reading number of bytes expected in field prefix"
))
} else {
Ok(buffer)
}
}
impl<R> Deserializer<R> {
pub fn new(reader: R) -> Self {
Deserializer { reader }
@ -182,9 +203,7 @@ where
{
let Uint(length) = <Uint as de::Deserialize>::deserialize(&mut *self)?;
let length = length as usize;
let mut buf = Vec::with_capacity(length);
buf.resize(length, 0);
self.reader.read_exact(&mut buf).map_err(Error::Io)?;
let buf = read_bytes(&mut self.reader, length).map_err(Error::Io)?;
let utf8 = str::from_utf8(&buf).map_err(|_| Error::InvalidUtf8)?;
visitor.visit_str(utf8)
}
@ -196,9 +215,7 @@ where
{
let Uint(length) = <Uint as de::Deserialize>::deserialize(&mut *self)?;
let length = length as usize;
let mut buf = Vec::with_capacity(length);
buf.resize(length, 0);
self.reader.read_exact(&mut buf).map_err(Error::Io)?;
let buf = read_bytes(&mut self.reader, length).map_err(Error::Io)?;
let utf8 = String::from_utf8(buf).map_err(|_| Error::InvalidUtf8)?;
visitor.visit_string(utf8)
}
@ -210,9 +227,7 @@ where
{
let Uint(length) = <Uint as de::Deserialize>::deserialize(&mut *self)?;
let length = length as usize;
let mut buf = Vec::with_capacity(length);
buf.resize(length, 0);
self.reader.read_exact(&mut buf).map_err(Error::Io)?;
let buf = read_bytes(&mut self.reader, length).map_err(Error::Io)?;
visitor.visit_bytes(&buf)
}
@ -223,9 +238,7 @@ where
{
let Uint(length) = <Uint as de::Deserialize>::deserialize(&mut *self)?;
let length = length as usize;
let mut buf = Vec::with_capacity(length);
buf.resize(length, 0);
self.reader.read_exact(&mut buf).map_err(Error::Io)?;
let buf = read_bytes(&mut self.reader, length).map_err(Error::Io)?;
visitor.visit_byte_buf(buf)
}