chunkedge_nbt/binary/decode.rs

1use std::borrow::Cow;
2use std::hash::Hash;
3use std::{fmt, mem};
4
5use byteorder::{BigEndian, ReadBytesExt};
6
7use crate::tag::Tag;
8use crate::{Compound, Error, List, Result, Value};
9
10/// Decodes uncompressed NBT binary data from the provided slice.
11///
12/// The string returned in the tuple is the name of the root compound
13/// (typically the empty string).
14pub fn from_binary<'de, S>(slice: &mut &'de [u8]) -> Result<(Compound<S>, Option<S>)>
15where
16    S: FromModifiedUtf8<'de> + Hash + Ord,
17{
18    let mut state = DecodeState { slice, depth: 0 };
19
20    let root_tag = state.read_tag()?;
21
22    if root_tag != Tag::Compound {
23        return Err(Error::new_owned(format!(
24            "expected root tag for compound (got {})",
25            root_tag.name(),
26        )));
27    }
28
29    let root_name = {
30        let mut slice = *state.slice;
31        let mut peek_state = DecodeState {
32            slice: &mut slice,
33            depth: 0,
34        };
35
36        match peek_state.read_string::<S>() {
37            Ok(_) => Some(state.read_string().unwrap()),
38            Err(_) => None,
39        }
40    };
41    let root = state.read_compound()?;
42
43    debug_assert_eq!(state.depth, 0);
44
45    Ok((root, root_name))
46}
47
48/// Decodes uncompressed network NBT binary data from the provided slice,
49/// Network NBT omits the root compound.
50pub fn from_network_binary<'de, S>(slice: &mut &'de [u8]) -> Result<Compound<S>>
51where
52    S: FromModifiedUtf8<'de> + Hash + Ord,
53{
54    let mut state = DecodeState { slice, depth: 0 };
55
56    let compound = state.read_compound()?;
57
58    debug_assert_eq!(state.depth, 0);
59
60    Ok(compound)
61}
62
/// Maximum recursion depth to prevent overflowing the call stack.
///
/// The decoder goes one level deeper for every nested list or compound value
/// (the root compound itself is not counted). Input nested deeper than this
/// is rejected with an error instead of recursing further.
const MAX_DEPTH: usize = 512;
65
/// Decoder state: the remaining input plus recursion bookkeeping.
struct DecodeState<'a, 'de> {
    /// The not-yet-decoded input. Reads advance this slice in place, so the
    /// slice borrowed from the caller reflects how much input was consumed.
    slice: &'a mut &'de [u8],
    /// Current recursion depth (number of nested lists/compounds being read).
    depth: usize,
}
71
72impl<'de> DecodeState<'_, 'de> {
73    #[inline]
74    fn check_depth<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T>) -> Result<T> {
75        if self.depth >= MAX_DEPTH {
76            return Err(Error::new_static("reached maximum recursion depth"));
77        }
78
79        self.depth += 1;
80        let res = f(self);
81        self.depth -= 1;
82        res
83    }
84
85    fn read_tag(&mut self) -> Result<Tag> {
86        match self.slice.read_u8()? {
87            0 => Ok(Tag::End),
88            1 => Ok(Tag::Byte),
89            2 => Ok(Tag::Short),
90            3 => Ok(Tag::Int),
91            4 => Ok(Tag::Long),
92            5 => Ok(Tag::Float),
93            6 => Ok(Tag::Double),
94            7 => Ok(Tag::ByteArray),
95            8 => Ok(Tag::String),
96            9 => Ok(Tag::List),
97            10 => Ok(Tag::Compound),
98            11 => Ok(Tag::IntArray),
99            12 => Ok(Tag::LongArray),
100            byte => Err(Error::new_owned(format!("invalid tag byte of {byte:#x}"))),
101        }
102    }
103
104    fn read_value<S>(&mut self, tag: Tag) -> Result<Value<S>>
105    where
106        S: FromModifiedUtf8<'de> + Hash + Ord,
107    {
108        match tag {
109            Tag::End => unreachable!("illegal TAG_End argument"),
110            Tag::Byte => Ok(self.read_byte()?.into()),
111            Tag::Short => Ok(self.read_short()?.into()),
112            Tag::Int => Ok(self.read_int()?.into()),
113            Tag::Long => Ok(self.read_long()?.into()),
114            Tag::Float => Ok(self.read_float()?.into()),
115            Tag::Double => Ok(self.read_double()?.into()),
116            Tag::ByteArray => Ok(self.read_byte_array()?.into()),
117            Tag::String => Ok(Value::String(self.read_string::<S>()?)),
118            Tag::List => self.check_depth(|st| Ok(st.read_any_list::<S>()?.into())),
119            Tag::Compound => self.check_depth(|st| Ok(st.read_compound::<S>()?.into())),
120            Tag::IntArray => Ok(self.read_int_array()?.into()),
121            Tag::LongArray => Ok(self.read_long_array()?.into()),
122        }
123    }
124
125    fn read_byte(&mut self) -> Result<i8> {
126        Ok(self.slice.read_i8()?)
127    }
128
129    fn read_short(&mut self) -> Result<i16> {
130        Ok(self.slice.read_i16::<BigEndian>()?)
131    }
132
133    fn read_int(&mut self) -> Result<i32> {
134        Ok(self.slice.read_i32::<BigEndian>()?)
135    }
136
137    fn read_long(&mut self) -> Result<i64> {
138        Ok(self.slice.read_i64::<BigEndian>()?)
139    }
140
141    fn read_float(&mut self) -> Result<f32> {
142        Ok(self.slice.read_f32::<BigEndian>()?)
143    }
144
145    fn read_double(&mut self) -> Result<f64> {
146        Ok(self.slice.read_f64::<BigEndian>()?)
147    }
148
149    fn read_byte_array(&mut self) -> Result<Vec<i8>> {
150        let len = self.slice.read_i32::<BigEndian>()?;
151
152        if len.is_negative() {
153            return Err(Error::new_owned(format!(
154                "negative byte array length of {len}"
155            )));
156        }
157
158        if len as usize > self.slice.len() {
159            return Err(Error::new_owned(format!(
160                "byte array length of {len} exceeds remainder of input"
161            )));
162        }
163
164        let (left, right) = self.slice.split_at(len as usize);
165
166        let array = left.iter().map(|b| *b as i8).collect();
167        *self.slice = right;
168
169        Ok(array)
170    }
171
172    fn read_string<S>(&mut self) -> Result<S>
173    where
174        S: FromModifiedUtf8<'de>,
175    {
176        let len = self.slice.read_u16::<BigEndian>()?.into();
177
178        if len > self.slice.len() {
179            return Err(Error::new_owned(format!(
180                "string of length {len} exceeds remainder of input"
181            )));
182        }
183
184        let (left, right) = self.slice.split_at(len);
185
186        match S::from_modified_utf8(left) {
187            Ok(str) => {
188                *self.slice = right;
189                Ok(str)
190            }
191            Err(_) => Err(Error::new_static("could not decode modified UTF-8 data")),
192        }
193    }
194
195    fn read_any_list<S>(&mut self) -> Result<List<S>>
196    where
197        S: FromModifiedUtf8<'de> + Hash + Ord,
198    {
199        match self.read_tag()? {
200            Tag::End => match self.read_int()? {
201                0 => Ok(List::End),
202                len => Err(Error::new_owned(format!(
203                    "TAG_End list with nonzero length of {len}"
204                ))),
205            },
206            Tag::Byte => Ok(self.read_list(Tag::Byte, 1, |st| st.read_byte())?.into()),
207            Tag::Short => Ok(self.read_list(Tag::Short, 2, |st| st.read_short())?.into()),
208            Tag::Int => Ok(self.read_list(Tag::Int, 4, |st| st.read_int())?.into()),
209            Tag::Long => Ok(self.read_list(Tag::Long, 8, |st| st.read_long())?.into()),
210            Tag::Float => Ok(self.read_list(Tag::Float, 4, |st| st.read_float())?.into()),
211            Tag::Double => Ok(self
212                .read_list(Tag::Double, 8, |st| st.read_double())?
213                .into()),
214            Tag::ByteArray => Ok(self
215                .read_list(Tag::ByteArray, 0, |st| st.read_byte_array())?
216                .into()),
217            Tag::String => Ok(List::String(
218                self.read_list(Tag::String, 0, |st| st.read_string::<S>())?,
219            )),
220            Tag::List => self.check_depth(|st| {
221                Ok(st
222                    .read_list(Tag::List, 0, |st| st.read_any_list::<S>())?
223                    .into())
224            }),
225            Tag::Compound => self.check_depth(|st| {
226                Ok(st
227                    .read_list(Tag::Compound, 0, |st| st.read_compound::<S>())?
228                    .into())
229            }),
230            Tag::IntArray => Ok(self
231                .read_list(Tag::IntArray, 0, |st| st.read_int_array())?
232                .into()),
233            Tag::LongArray => Ok(self
234                .read_list(Tag::LongArray, 0, |st| st.read_long_array())?
235                .into()),
236        }
237    }
238
239    /// Assumes the element tag has already been read.
240    ///
241    /// `min_elem_size` is the minimum size of the list element when encoded.
242    #[inline]
243    fn read_list<T, F>(
244        &mut self,
245        elem_type: Tag,
246        elem_size: usize,
247        mut read_elem: F,
248    ) -> Result<Vec<T>>
249    where
250        F: FnMut(&mut Self) -> Result<T>,
251    {
252        let len = self.read_int()?;
253
254        if len.is_negative() {
255            return Err(Error::new_owned(format!(
256                "negative {} list length of {len}",
257                elem_type.name()
258            )));
259        }
260
261        // Ensure we don't reserve more than the maximum amount of memory required given
262        // the size of the remaining input.
263        if len as u64 * elem_size as u64 > self.slice.len() as u64 {
264            return Err(Error::new_owned(format!(
265                "{} list of length {len} exceeds remainder of input",
266                elem_type.name()
267            )));
268        }
269
270        let mut list = Vec::with_capacity(if elem_size == 0 { 0 } else { len as usize });
271
272        for _ in 0..len {
273            list.push(read_elem(self)?);
274        }
275
276        Ok(list)
277    }
278
279    fn read_compound<S>(&mut self) -> Result<Compound<S>>
280    where
281        S: FromModifiedUtf8<'de> + Hash + Ord,
282    {
283        let mut compound = Compound::new();
284
285        loop {
286            let tag = self.read_tag()?;
287            if tag == Tag::End {
288                return Ok(compound);
289            }
290
291            compound.insert(self.read_string::<S>()?, self.read_value::<S>(tag)?);
292        }
293    }
294
295    fn read_int_array(&mut self) -> Result<Vec<i32>> {
296        let len = self.read_int()?;
297
298        if len.is_negative() {
299            return Err(Error::new_owned(format!(
300                "negative int array length of {len}",
301            )));
302        }
303
304        if len as u64 * mem::size_of::<i32>() as u64 > self.slice.len() as u64 {
305            return Err(Error::new_owned(format!(
306                "int array of length {len} exceeds remainder of input"
307            )));
308        }
309
310        let mut array = Vec::with_capacity(len as usize);
311        for _ in 0..len {
312            array.push(self.read_int()?);
313        }
314
315        Ok(array)
316    }
317
318    fn read_long_array(&mut self) -> Result<Vec<i64>> {
319        let len = self.read_int()?;
320
321        if len.is_negative() {
322            return Err(Error::new_owned(format!(
323                "negative long array length of {len}",
324            )));
325        }
326
327        if len as u64 * mem::size_of::<i64>() as u64 > self.slice.len() as u64 {
328            return Err(Error::new_owned(format!(
329                "long array of length {len} exceeds remainder of input"
330            )));
331        }
332
333        let mut array = Vec::with_capacity(len as usize);
334        for _ in 0..len {
335            array.push(self.read_long()?);
336        }
337
338        Ok(array)
339    }
340}
341
/// Error returned by [`FromModifiedUtf8::from_modified_utf8`] when the input
/// bytes are not valid modified UTF-8.
///
/// It carries no further detail.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FromModifiedUtf8Error;

impl fmt::Display for FromModifiedUtf8Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("could not decode modified UTF-8 data")
    }
}

impl std::error::Error for FromModifiedUtf8Error {}
352
/// A string type which can be decoded from Java's [modified UTF-8](https://docs.oracle.com/javase/8/docs/api/java/io/DataInput.html#modified-utf-8).
///
/// `'de` is the lifetime of the input buffer. It lets implementations such as
/// `Cow<'de, str>` borrow from the input where possible instead of
/// allocating.
pub trait FromModifiedUtf8<'de>: Sized {
    /// Decodes `modified_utf8` into `Self`.
    ///
    /// # Errors
    ///
    /// Returns [`FromModifiedUtf8Error`] if the bytes cannot be decoded as
    /// modified UTF-8.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error>;
}
359
360impl<'de> FromModifiedUtf8<'de> for Cow<'de, str> {
361    fn from_modified_utf8(
362        modified_utf8: &'de [u8],
363    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
364        cesu8::from_java_cesu8(modified_utf8).map_err(move |_| FromModifiedUtf8Error)
365    }
366}
367
368impl<'de> FromModifiedUtf8<'de> for String {
369    fn from_modified_utf8(
370        modified_utf8: &'de [u8],
371    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
372        match cesu8::from_java_cesu8(modified_utf8) {
373            Ok(str) => Ok(str.into_owned()),
374            Err(_) => Err(FromModifiedUtf8Error),
375        }
376    }
377}
378
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for Cow<'de, java_string::JavaStr> {
    /// Returns the `Cow` produced by `JavaStr::from_modified_utf8` unchanged.
    /// Any decoding error is collapsed into [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        match java_string::JavaStr::from_modified_utf8(modified_utf8) {
            Ok(decoded) => Ok(decoded),
            Err(_) => Err(FromModifiedUtf8Error),
        }
    }
}
387
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for java_string::JavaString {
    /// Decodes with `JavaStr::from_modified_utf8` and converts the result into
    /// an owned `JavaString`. Any decoding error becomes
    /// [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        java_string::JavaStr::from_modified_utf8(modified_utf8)
            .map(Cow::into_owned)
            .map_err(|_| FromModifiedUtf8Error)
    }
}