chunkedge_nbt/binary/
encode.rs

1use std::borrow::Cow;
2use std::hash::Hash;
3use std::io::Write;
4
5use byteorder::{BigEndian, WriteBytesExt};
6
7use super::modified_utf8;
8use crate::conv::i8_slice_as_u8_slice;
9use crate::tag::Tag;
10use crate::{Compound, Error, List, Result, Value};
11
12/// Encodes uncompressed NBT binary data to the provided writer.
13///
14/// Only compounds are permitted at the top level. This is why the function
15/// accepts a [`Compound`] reference rather than a [`Value`].
16///
17/// Additionally, the root compound can be given a name. Typically the empty
18/// string `""` is used.
19pub fn to_binary<W, S, R>(comp: &Compound<S>, writer: W, root_name: Option<&R>) -> Result<()>
20where
21    W: Write,
22    S: ToModifiedUtf8 + Hash + Ord,
23    R: ToModifiedUtf8 + ?Sized,
24{
25    let mut state = EncodeState { writer };
26
27    state.write_tag(Tag::Compound)?;
28    if let Some(root_name) = root_name {
29        state.write_string(root_name)?;
30    }
31    state.write_compound(comp)?;
32
33    Ok(())
34}
35
36/// Encodes uncompressed network NBT binary data to the provided writer.
37/// Network NBT omits the root compound.
38pub fn to_network_binary<W, S>(val: &Compound<S>, writer: W) -> Result<()>
39where
40    W: Write,
41    S: ToModifiedUtf8 + Hash + Ord,
42{
43    let mut state = EncodeState { writer };
44
45    state.write_compound(val)?;
46
47    Ok(())
48}
49
50/// Returns the number of bytes that will be written when
51/// [`to_binary`] is called with this compound and root name.
52///
53/// If `to_binary` results in `Ok`, the exact number of bytes
54/// reported by this function will have been written. If the result is
55/// `Err`, then the reported count will be greater than or equal to the
56/// number of bytes that have actually been written.
57pub fn written_size<S, R>(comp: &Compound<S>, root_name: &R) -> usize
58where
59    S: ToModifiedUtf8 + Hash + Ord,
60    R: ToModifiedUtf8 + ?Sized,
61{
62    fn value_size<S>(val: &Value<S>) -> usize
63    where
64        S: ToModifiedUtf8 + Hash + Ord,
65    {
66        match val {
67            Value::Byte(_) => 1,
68            Value::Short(_) => 2,
69            Value::Int(_) => 4,
70            Value::Long(_) => 8,
71            Value::Float(_) => 4,
72            Value::Double(_) => 8,
73            Value::ByteArray(v) => 4 + v.len(),
74            Value::String(v) => string_size(v),
75            Value::List(v) => list_size(v),
76            Value::Compound(v) => compound_size(v),
77            Value::IntArray(v) => 4 + v.len() * 4,
78            Value::LongArray(v) => 4 + v.len() * 8,
79        }
80    }
81
82    fn list_size<S>(l: &List<S>) -> usize
83    where
84        S: ToModifiedUtf8 + Hash + Ord,
85    {
86        let elems_size = match l {
87            List::End => 0,
88            List::Byte(v) => v.len(),
89            List::Short(v) => v.len() * 2,
90            List::Int(v) => v.len() * 4,
91            List::Long(v) => v.len() * 8,
92            List::Float(v) => v.len() * 4,
93            List::Double(v) => v.len() * 8,
94            List::ByteArray(v) => v.iter().map(|b| 4 + b.len()).sum(),
95            List::String(v) => v.iter().map(|s| string_size(s)).sum(),
96            List::List(v) => v.iter().map(list_size).sum(),
97            List::Compound(v) => v.iter().map(compound_size).sum(),
98            List::IntArray(v) => v.iter().map(|i| 4 + i.len() * 4).sum(),
99            List::LongArray(v) => v.iter().map(|l| 4 + l.len() * 8).sum(),
100        };
101
102        1 + 4 + elems_size
103    }
104
105    fn string_size<S: ToModifiedUtf8 + ?Sized>(s: &S) -> usize {
106        2 + s.modified_uf8_len()
107    }
108
109    fn compound_size<S>(c: &Compound<S>) -> usize
110    where
111        S: ToModifiedUtf8 + Hash + Ord,
112    {
113        c.iter()
114            .map(|(k, v)| 1 + string_size(k) + value_size(v))
115            .sum::<usize>()
116            + 1
117    }
118
119    1 + string_size(root_name) + compound_size(comp)
120}
121
122struct EncodeState<W> {
123    writer: W,
124}
125
126impl<W: Write> EncodeState<W> {
127    fn write_tag(&mut self, tag: Tag) -> Result<()> {
128        Ok(self.writer.write_u8(tag as u8)?)
129    }
130
131    fn write_value<S>(&mut self, v: &Value<S>) -> Result<()>
132    where
133        S: ToModifiedUtf8 + Hash + Ord,
134    {
135        match v {
136            Value::Byte(v) => self.write_byte(*v),
137            Value::Short(v) => self.write_short(*v),
138            Value::Int(v) => self.write_int(*v),
139            Value::Long(v) => self.write_long(*v),
140            Value::Float(v) => self.write_float(*v),
141            Value::Double(v) => self.write_double(*v),
142            Value::ByteArray(v) => self.write_byte_array(v),
143            Value::String(v) => self.write_string(v),
144            Value::List(v) => self.write_any_list(v),
145            Value::Compound(v) => self.write_compound(v),
146            Value::IntArray(v) => self.write_int_array(v),
147            Value::LongArray(v) => self.write_long_array(v),
148        }
149    }
150
151    fn write_byte(&mut self, byte: i8) -> Result<()> {
152        Ok(self.writer.write_i8(byte)?)
153    }
154
155    fn write_short(&mut self, short: i16) -> Result<()> {
156        Ok(self.writer.write_i16::<BigEndian>(short)?)
157    }
158
159    fn write_int(&mut self, int: i32) -> Result<()> {
160        Ok(self.writer.write_i32::<BigEndian>(int)?)
161    }
162
163    fn write_long(&mut self, long: i64) -> Result<()> {
164        Ok(self.writer.write_i64::<BigEndian>(long)?)
165    }
166
167    fn write_float(&mut self, float: f32) -> Result<()> {
168        Ok(self.writer.write_f32::<BigEndian>(float)?)
169    }
170
171    fn write_double(&mut self, double: f64) -> Result<()> {
172        Ok(self.writer.write_f64::<BigEndian>(double)?)
173    }
174
175    fn write_byte_array(&mut self, bytes: &[i8]) -> Result<()> {
176        match bytes.len().try_into() {
177            Ok(len) => self.write_int(len)?,
178            Err(_) => {
179                return Err(Error::new_owned(format!(
180                    "byte array of length {} exceeds maximum of i32::MAX",
181                    bytes.len(),
182                )))
183            }
184        }
185
186        Ok(self.writer.write_all(i8_slice_as_u8_slice(bytes))?)
187    }
188
189    fn write_string<S: ToModifiedUtf8 + ?Sized>(&mut self, s: &S) -> Result<()> {
190        let len = s.modified_uf8_len();
191
192        match len.try_into() {
193            Ok(n) => self.writer.write_u16::<BigEndian>(n)?,
194            Err(_) => {
195                return Err(Error::new_owned(format!(
196                    "string of length {len} exceeds maximum of u16::MAX"
197                )))
198            }
199        }
200
201        s.to_modified_utf8(len, &mut self.writer)?;
202
203        Ok(())
204    }
205
206    fn write_any_list<S>(&mut self, list: &List<S>) -> Result<()>
207    where
208        S: ToModifiedUtf8 + Hash + Ord,
209    {
210        match list {
211            List::End => {
212                self.write_tag(Tag::End)?;
213                // Length
214                self.writer.write_i32::<BigEndian>(0)?;
215                Ok(())
216            }
217            List::Byte(v) => {
218                self.write_tag(Tag::Byte)?;
219
220                match v.len().try_into() {
221                    Ok(len) => self.write_int(len)?,
222                    Err(_) => {
223                        return Err(Error::new_owned(format!(
224                            "byte list of length {} exceeds maximum of i32::MAX",
225                            v.len(),
226                        )))
227                    }
228                }
229
230                Ok(self.writer.write_all(i8_slice_as_u8_slice(v))?)
231            }
232            List::Short(sl) => self.write_list(sl, Tag::Short, |st, v| st.write_short(*v)),
233            List::Int(il) => self.write_list(il, Tag::Int, |st, v| st.write_int(*v)),
234            List::Long(ll) => self.write_list(ll, Tag::Long, |st, v| st.write_long(*v)),
235            List::Float(fl) => self.write_list(fl, Tag::Float, |st, v| st.write_float(*v)),
236            List::Double(dl) => self.write_list(dl, Tag::Double, |st, v| st.write_double(*v)),
237            List::ByteArray(v) => {
238                self.write_list(v, Tag::ByteArray, |st, v| st.write_byte_array(v))
239            }
240            List::String(v) => self.write_list(v, Tag::String, |st, v| st.write_string(v)),
241            List::List(v) => self.write_list(v, Tag::List, |st, v| st.write_any_list(v)),
242            List::Compound(v) => self.write_list(v, Tag::Compound, |st, v| st.write_compound(v)),
243            List::IntArray(v) => self.write_list(v, Tag::IntArray, |st, v| st.write_int_array(v)),
244            List::LongArray(v) => {
245                self.write_list(v, Tag::LongArray, |st, v| st.write_long_array(v))
246            }
247        }
248    }
249
250    fn write_list<T, F>(&mut self, list: &[T], elem_type: Tag, mut write_elem: F) -> Result<()>
251    where
252        F: FnMut(&mut Self, &T) -> Result<()>,
253    {
254        self.write_tag(elem_type)?;
255
256        match list.len().try_into() {
257            Ok(len) => self.writer.write_i32::<BigEndian>(len)?,
258            Err(_) => {
259                return Err(Error::new_owned(format!(
260                    "{} list of length {} exceeds maximum of i32::MAX",
261                    list.len(),
262                    elem_type.name()
263                )))
264            }
265        }
266
267        for elem in list {
268            write_elem(self, elem)?;
269        }
270
271        Ok(())
272    }
273
274    fn write_compound<S>(&mut self, c: &Compound<S>) -> Result<()>
275    where
276        S: ToModifiedUtf8 + Hash + Ord,
277    {
278        for (k, v) in c {
279            self.write_tag(v.tag())?;
280            self.write_string(k)?;
281            self.write_value(v)?;
282        }
283        self.write_tag(Tag::End)?;
284
285        Ok(())
286    }
287
288    fn write_int_array(&mut self, ia: &[i32]) -> Result<()> {
289        match ia.len().try_into() {
290            Ok(len) => self.write_int(len)?,
291            Err(_) => {
292                return Err(Error::new_owned(format!(
293                    "int array of length {} exceeds maximum of i32::MAX",
294                    ia.len(),
295                )))
296            }
297        }
298
299        for i in ia {
300            self.write_int(*i)?;
301        }
302
303        Ok(())
304    }
305
306    fn write_long_array(&mut self, la: &[i64]) -> Result<()> {
307        match la.len().try_into() {
308            Ok(len) => self.write_int(len)?,
309            Err(_) => {
310                return Err(Error::new_owned(format!(
311                    "long array of length {} exceeds maximum of i32::MAX",
312                    la.len(),
313                )))
314            }
315        }
316
317        for l in la {
318            self.write_long(*l)?;
319        }
320
321        Ok(())
322    }
323}
324
/// A string type which can be encoded into Java's [modified UTF-8](https://docs.oracle.com/javase/8/docs/api/java/io/DataInput.html#modified-utf-8).
pub trait ToModifiedUtf8 {
    /// Returns the number of bytes this string occupies once encoded as
    /// modified UTF-8.
    fn modified_uf8_len(&self) -> usize;

    /// Writes the modified UTF-8 encoding of this string to `writer`.
    ///
    /// `encoded_len` is the encoded byte length of this string; the encoder
    /// in this module passes the value it just obtained from
    /// [`modified_uf8_len`](Self::modified_uf8_len). Implementations may use
    /// it as a hint or ignore it.
    fn to_modified_utf8<W>(&self, encoded_len: usize, writer: W) -> std::io::Result<()>
    where
        W: Write;
}
330
331impl ToModifiedUtf8 for str {
332    fn modified_uf8_len(&self) -> usize {
333        modified_utf8::encoded_len(self.as_bytes())
334    }
335
336    fn to_modified_utf8<W: Write>(&self, encoded_len: usize, mut writer: W) -> std::io::Result<()> {
337        // Conversion to modified UTF-8 always increases the size of the string.
338        // If the new len is equal to the original len, we know it doesn't need
339        // to be re-encoded.
340        if self.len() == encoded_len {
341            writer.write_all(self.as_bytes())
342        } else {
343            modified_utf8::write_modified_utf8(writer, self)
344        }
345    }
346}
347
348impl ToModifiedUtf8 for Cow<'_, str> {
349    #[inline]
350    fn modified_uf8_len(&self) -> usize {
351        str::modified_uf8_len(self)
352    }
353
354    fn to_modified_utf8<W: Write>(&self, encoded_len: usize, writer: W) -> std::io::Result<()> {
355        str::to_modified_utf8(self, encoded_len, writer)
356    }
357}
358
359impl ToModifiedUtf8 for String {
360    #[inline]
361    fn modified_uf8_len(&self) -> usize {
362        str::modified_uf8_len(self)
363    }
364
365    fn to_modified_utf8<W: Write>(&self, encoded_len: usize, writer: W) -> std::io::Result<()> {
366        str::to_modified_utf8(self, encoded_len, writer)
367    }
368}
369
#[cfg(feature = "java_string")]
impl ToModifiedUtf8 for java_string::JavaStr {
    /// Encoded length of this string, computed from its underlying bytes.
    fn modified_uf8_len(&self) -> usize {
        let raw = self.as_bytes();
        modified_utf8::encoded_len(raw)
    }

    /// Writes the encoding produced by `JavaStr`'s own conversion method;
    /// the length hint is not needed and is ignored.
    fn to_modified_utf8<W: Write>(
        &self,
        _encoded_len: usize,
        mut writer: W,
    ) -> std::io::Result<()> {
        let encoded = self.to_modified_utf8();
        writer.write_all(&encoded)
    }
}
384
/// Delegates to the `JavaStr` implementation for both borrowed and owned data.
#[cfg(feature = "java_string")]
impl ToModifiedUtf8 for Cow<'_, java_string::JavaStr> {
    #[inline]
    fn modified_uf8_len(&self) -> usize {
        <java_string::JavaStr as ToModifiedUtf8>::modified_uf8_len(self)
    }

    fn to_modified_utf8<W: Write>(&self, encoded_len: usize, writer: W) -> std::io::Result<()> {
        let inner: &java_string::JavaStr = self;
        ToModifiedUtf8::to_modified_utf8(inner, encoded_len, writer)
    }
}
396
/// Delegates to the `JavaStr` implementation on the string's contents.
#[cfg(feature = "java_string")]
impl ToModifiedUtf8 for java_string::JavaString {
    #[inline]
    fn modified_uf8_len(&self) -> usize {
        <java_string::JavaStr as ToModifiedUtf8>::modified_uf8_len(self)
    }

    fn to_modified_utf8<W: Write>(&self, encoded_len: usize, writer: W) -> std::io::Result<()> {
        let inner: &java_string::JavaStr = self;
        ToModifiedUtf8::to_modified_utf8(inner, encoded_len, writer)
    }
}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty compound in network form encodes to the End tag alone.
    #[test]
    fn test_network_binary_empty_compound() {
        let empty = Compound::<String>::new();
        let mut out: Vec<u8> = Vec::new();

        to_network_binary(&empty, &mut out).unwrap();

        assert_eq!(out, vec![0x0_u8]);
    }
}