1use std::borrow::Cow;
2use std::hash::Hash;
3use std::{fmt, mem};
4
5use byteorder::{BigEndian, ReadBytesExt};
6
7use crate::tag::Tag;
8use crate::{Compound, Error, List, Result, Value};
9
10pub fn from_binary<'de, S>(slice: &mut &'de [u8]) -> Result<(Compound<S>, Option<S>)>
15where
16 S: FromModifiedUtf8<'de> + Hash + Ord,
17{
18 let mut state = DecodeState { slice, depth: 0 };
19
20 let root_tag = state.read_tag()?;
21
22 if root_tag != Tag::Compound {
23 return Err(Error::new_owned(format!(
24 "expected root tag for compound (got {})",
25 root_tag.name(),
26 )));
27 }
28
29 let root_name = {
30 let mut slice = *state.slice;
31 let mut peek_state = DecodeState {
32 slice: &mut slice,
33 depth: 0,
34 };
35
36 match peek_state.read_string::<S>() {
37 Ok(_) => Some(state.read_string().unwrap()),
38 Err(_) => None,
39 }
40 };
41 let root = state.read_compound()?;
42
43 debug_assert_eq!(state.depth, 0);
44
45 Ok((root, root_name))
46}
47
48pub fn from_network_binary<'de, S>(slice: &mut &'de [u8]) -> Result<Compound<S>>
51where
52 S: FromModifiedUtf8<'de> + Hash + Ord,
53{
54 let mut state = DecodeState { slice, depth: 0 };
55
56 let compound = state.read_compound()?;
57
58 debug_assert_eq!(state.depth, 0);
59
60 Ok(compound)
61}
62
/// Upper bound on list/compound nesting while decoding. `DecodeState::check_depth`
/// returns an error once `depth` has reached this value.
const MAX_DEPTH: usize = 512;
65
/// Cursor and bookkeeping shared by the decoding routines.
struct DecodeState<'a, 'de> {
    /// Remaining undecoded input. Reads shrink this slice in place, so the
    /// caller's slice reflects how much was consumed.
    slice: &'a mut &'de [u8],
    /// Current list/compound nesting depth, compared against `MAX_DEPTH`.
    depth: usize,
}
71
72impl<'de> DecodeState<'_, 'de> {
73 #[inline]
74 fn check_depth<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T>) -> Result<T> {
75 if self.depth >= MAX_DEPTH {
76 return Err(Error::new_static("reached maximum recursion depth"));
77 }
78
79 self.depth += 1;
80 let res = f(self);
81 self.depth -= 1;
82 res
83 }
84
85 fn read_tag(&mut self) -> Result<Tag> {
86 match self.slice.read_u8()? {
87 0 => Ok(Tag::End),
88 1 => Ok(Tag::Byte),
89 2 => Ok(Tag::Short),
90 3 => Ok(Tag::Int),
91 4 => Ok(Tag::Long),
92 5 => Ok(Tag::Float),
93 6 => Ok(Tag::Double),
94 7 => Ok(Tag::ByteArray),
95 8 => Ok(Tag::String),
96 9 => Ok(Tag::List),
97 10 => Ok(Tag::Compound),
98 11 => Ok(Tag::IntArray),
99 12 => Ok(Tag::LongArray),
100 byte => Err(Error::new_owned(format!("invalid tag byte of {byte:#x}"))),
101 }
102 }
103
104 fn read_value<S>(&mut self, tag: Tag) -> Result<Value<S>>
105 where
106 S: FromModifiedUtf8<'de> + Hash + Ord,
107 {
108 match tag {
109 Tag::End => unreachable!("illegal TAG_End argument"),
110 Tag::Byte => Ok(self.read_byte()?.into()),
111 Tag::Short => Ok(self.read_short()?.into()),
112 Tag::Int => Ok(self.read_int()?.into()),
113 Tag::Long => Ok(self.read_long()?.into()),
114 Tag::Float => Ok(self.read_float()?.into()),
115 Tag::Double => Ok(self.read_double()?.into()),
116 Tag::ByteArray => Ok(self.read_byte_array()?.into()),
117 Tag::String => Ok(Value::String(self.read_string::<S>()?)),
118 Tag::List => self.check_depth(|st| Ok(st.read_any_list::<S>()?.into())),
119 Tag::Compound => self.check_depth(|st| Ok(st.read_compound::<S>()?.into())),
120 Tag::IntArray => Ok(self.read_int_array()?.into()),
121 Tag::LongArray => Ok(self.read_long_array()?.into()),
122 }
123 }
124
125 fn read_byte(&mut self) -> Result<i8> {
126 Ok(self.slice.read_i8()?)
127 }
128
129 fn read_short(&mut self) -> Result<i16> {
130 Ok(self.slice.read_i16::<BigEndian>()?)
131 }
132
133 fn read_int(&mut self) -> Result<i32> {
134 Ok(self.slice.read_i32::<BigEndian>()?)
135 }
136
137 fn read_long(&mut self) -> Result<i64> {
138 Ok(self.slice.read_i64::<BigEndian>()?)
139 }
140
141 fn read_float(&mut self) -> Result<f32> {
142 Ok(self.slice.read_f32::<BigEndian>()?)
143 }
144
145 fn read_double(&mut self) -> Result<f64> {
146 Ok(self.slice.read_f64::<BigEndian>()?)
147 }
148
149 fn read_byte_array(&mut self) -> Result<Vec<i8>> {
150 let len = self.slice.read_i32::<BigEndian>()?;
151
152 if len.is_negative() {
153 return Err(Error::new_owned(format!(
154 "negative byte array length of {len}"
155 )));
156 }
157
158 if len as usize > self.slice.len() {
159 return Err(Error::new_owned(format!(
160 "byte array length of {len} exceeds remainder of input"
161 )));
162 }
163
164 let (left, right) = self.slice.split_at(len as usize);
165
166 let array = left.iter().map(|b| *b as i8).collect();
167 *self.slice = right;
168
169 Ok(array)
170 }
171
172 fn read_string<S>(&mut self) -> Result<S>
173 where
174 S: FromModifiedUtf8<'de>,
175 {
176 let len = self.slice.read_u16::<BigEndian>()?.into();
177
178 if len > self.slice.len() {
179 return Err(Error::new_owned(format!(
180 "string of length {len} exceeds remainder of input"
181 )));
182 }
183
184 let (left, right) = self.slice.split_at(len);
185
186 match S::from_modified_utf8(left) {
187 Ok(str) => {
188 *self.slice = right;
189 Ok(str)
190 }
191 Err(_) => Err(Error::new_static("could not decode modified UTF-8 data")),
192 }
193 }
194
195 fn read_any_list<S>(&mut self) -> Result<List<S>>
196 where
197 S: FromModifiedUtf8<'de> + Hash + Ord,
198 {
199 match self.read_tag()? {
200 Tag::End => match self.read_int()? {
201 0 => Ok(List::End),
202 len => Err(Error::new_owned(format!(
203 "TAG_End list with nonzero length of {len}"
204 ))),
205 },
206 Tag::Byte => Ok(self.read_list(Tag::Byte, 1, |st| st.read_byte())?.into()),
207 Tag::Short => Ok(self.read_list(Tag::Short, 2, |st| st.read_short())?.into()),
208 Tag::Int => Ok(self.read_list(Tag::Int, 4, |st| st.read_int())?.into()),
209 Tag::Long => Ok(self.read_list(Tag::Long, 8, |st| st.read_long())?.into()),
210 Tag::Float => Ok(self.read_list(Tag::Float, 4, |st| st.read_float())?.into()),
211 Tag::Double => Ok(self
212 .read_list(Tag::Double, 8, |st| st.read_double())?
213 .into()),
214 Tag::ByteArray => Ok(self
215 .read_list(Tag::ByteArray, 0, |st| st.read_byte_array())?
216 .into()),
217 Tag::String => Ok(List::String(
218 self.read_list(Tag::String, 0, |st| st.read_string::<S>())?,
219 )),
220 Tag::List => self.check_depth(|st| {
221 Ok(st
222 .read_list(Tag::List, 0, |st| st.read_any_list::<S>())?
223 .into())
224 }),
225 Tag::Compound => self.check_depth(|st| {
226 Ok(st
227 .read_list(Tag::Compound, 0, |st| st.read_compound::<S>())?
228 .into())
229 }),
230 Tag::IntArray => Ok(self
231 .read_list(Tag::IntArray, 0, |st| st.read_int_array())?
232 .into()),
233 Tag::LongArray => Ok(self
234 .read_list(Tag::LongArray, 0, |st| st.read_long_array())?
235 .into()),
236 }
237 }
238
239 #[inline]
243 fn read_list<T, F>(
244 &mut self,
245 elem_type: Tag,
246 elem_size: usize,
247 mut read_elem: F,
248 ) -> Result<Vec<T>>
249 where
250 F: FnMut(&mut Self) -> Result<T>,
251 {
252 let len = self.read_int()?;
253
254 if len.is_negative() {
255 return Err(Error::new_owned(format!(
256 "negative {} list length of {len}",
257 elem_type.name()
258 )));
259 }
260
261 if len as u64 * elem_size as u64 > self.slice.len() as u64 {
264 return Err(Error::new_owned(format!(
265 "{} list of length {len} exceeds remainder of input",
266 elem_type.name()
267 )));
268 }
269
270 let mut list = Vec::with_capacity(if elem_size == 0 { 0 } else { len as usize });
271
272 for _ in 0..len {
273 list.push(read_elem(self)?);
274 }
275
276 Ok(list)
277 }
278
279 fn read_compound<S>(&mut self) -> Result<Compound<S>>
280 where
281 S: FromModifiedUtf8<'de> + Hash + Ord,
282 {
283 let mut compound = Compound::new();
284
285 loop {
286 let tag = self.read_tag()?;
287 if tag == Tag::End {
288 return Ok(compound);
289 }
290
291 compound.insert(self.read_string::<S>()?, self.read_value::<S>(tag)?);
292 }
293 }
294
295 fn read_int_array(&mut self) -> Result<Vec<i32>> {
296 let len = self.read_int()?;
297
298 if len.is_negative() {
299 return Err(Error::new_owned(format!(
300 "negative int array length of {len}",
301 )));
302 }
303
304 if len as u64 * mem::size_of::<i32>() as u64 > self.slice.len() as u64 {
305 return Err(Error::new_owned(format!(
306 "int array of length {len} exceeds remainder of input"
307 )));
308 }
309
310 let mut array = Vec::with_capacity(len as usize);
311 for _ in 0..len {
312 array.push(self.read_int()?);
313 }
314
315 Ok(array)
316 }
317
318 fn read_long_array(&mut self) -> Result<Vec<i64>> {
319 let len = self.read_int()?;
320
321 if len.is_negative() {
322 return Err(Error::new_owned(format!(
323 "negative long array length of {len}",
324 )));
325 }
326
327 if len as u64 * mem::size_of::<i64>() as u64 > self.slice.len() as u64 {
328 return Err(Error::new_owned(format!(
329 "long array of length {len} exceeds remainder of input"
330 )));
331 }
332
333 let mut array = Vec::with_capacity(len as usize);
334 for _ in 0..len {
335 array.push(self.read_long()?);
336 }
337
338 Ok(array)
339 }
340}
341
/// Error returned when bytes cannot be decoded as modified UTF-8.
///
/// Carries no detail about where or why decoding failed.
#[derive(Debug, Clone, Copy)]
pub struct FromModifiedUtf8Error;

impl std::error::Error for FromModifiedUtf8Error {}

impl fmt::Display for FromModifiedUtf8Error {
    /// Writes the fixed error message; formatter width and fill are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "could not decode modified UTF-8 data")
    }
}
352
/// Conversion from modified UTF-8 bytes into a string type the decoder can produce.
///
/// The `'de` lifetime ties the input bytes to the produced value, which lets an
/// implementation borrow from the input rather than copy it.
pub trait FromModifiedUtf8<'de>: Sized {
    /// Decodes `modified_utf8`, returning [`FromModifiedUtf8Error`] when the bytes
    /// cannot be decoded.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error>;
}
359
360impl<'de> FromModifiedUtf8<'de> for Cow<'de, str> {
361 fn from_modified_utf8(
362 modified_utf8: &'de [u8],
363 ) -> std::result::Result<Self, FromModifiedUtf8Error> {
364 cesu8::from_java_cesu8(modified_utf8).map_err(move |_| FromModifiedUtf8Error)
365 }
366}
367
368impl<'de> FromModifiedUtf8<'de> for String {
369 fn from_modified_utf8(
370 modified_utf8: &'de [u8],
371 ) -> std::result::Result<Self, FromModifiedUtf8Error> {
372 match cesu8::from_java_cesu8(modified_utf8) {
373 Ok(str) => Ok(str.into_owned()),
374 Err(_) => Err(FromModifiedUtf8Error),
375 }
376 }
377}
378
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for Cow<'de, java_string::JavaStr> {
    /// Decodes through `JavaStr::from_modified_utf8`, keeping whatever `Cow` it
    /// yields. Any decoding failure is reported as [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        match java_string::JavaStr::from_modified_utf8(modified_utf8) {
            Ok(decoded) => Ok(decoded),
            Err(_) => Err(FromModifiedUtf8Error),
        }
    }
}
387
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for java_string::JavaString {
    /// Decodes through `JavaStr::from_modified_utf8` and converts the result into
    /// an owned `JavaString`. Any decoding failure is reported as
    /// [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        java_string::JavaStr::from_modified_utf8(modified_utf8)
            .map(|decoded| decoded.into_owned())
            .map_err(|_| FromModifiedUtf8Error)
    }
}