chunkedge_binary/
bit_set.rs1use std::fmt;
2use std::io::Write;
3
4use anyhow::ensure;
5
6use crate::{Decode, Encode, VarInt};
7
/// A bit set holding exactly `BIT_COUNT` bits, stored in `BYTE_COUNT` bytes.
///
/// Bit `i` lives in byte `i / 8` at bit position `i % 8`. `BYTE_COUNT` must equal
/// `BIT_COUNT.div_ceil(8)`; the accessors and the `Encode`/`Decode` impls assert this at
/// runtime. `Debug` is implemented by hand and renders the same binary string as `Display`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct FixedBitSet<const BIT_COUNT: usize, const BYTE_COUNT: usize>(pub [u8; BYTE_COUNT]);
16
17impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> FixedBitSet<BIT_COUNT, BYTE_COUNT> {
18 pub fn bit(&self, idx: usize) -> bool {
22 check_counts(BIT_COUNT, BYTE_COUNT);
23 debug_assert!(
24 idx < BIT_COUNT,
25 "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
26 );
27
28 self.0
29 .get(idx / 8)
30 .is_some_and(|byte| (byte >> (idx % 8)) & 1 == 1)
31 }
32
33 pub fn set(&mut self, idx: usize) {
37 check_counts(BIT_COUNT, BYTE_COUNT);
38 debug_assert!(
39 idx < BIT_COUNT,
40 "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
41 );
42
43 if idx < BIT_COUNT {
44 let byte = &mut self.0[idx / 8];
45 *byte |= 1 << (idx % 8);
46 }
47 }
48
49 pub fn clear(&mut self, idx: usize) {
53 check_counts(BIT_COUNT, BYTE_COUNT);
54 debug_assert!(
55 idx < BIT_COUNT,
56 "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
57 );
58
59 if idx < BIT_COUNT {
60 let byte = &mut self.0[idx / 8];
61 *byte &= !(1 << (idx % 8));
62 }
63 }
64}
65
66impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Encode
67 for FixedBitSet<BIT_COUNT, BYTE_COUNT>
68{
69 fn encode(&self, w: impl Write) -> anyhow::Result<()> {
70 check_counts(BIT_COUNT, BYTE_COUNT);
71 self.0.encode(w)
72 }
73}
74
75impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Decode<'_>
76 for FixedBitSet<BIT_COUNT, BYTE_COUNT>
77{
78 fn decode(r: &mut &'_ [u8]) -> anyhow::Result<Self> {
79 check_counts(BIT_COUNT, BYTE_COUNT);
80 Ok(Self(Decode::decode(r)?))
81 }
82}
83
/// Asserts that `bytes` is exactly the number of bytes needed to hold `bits` bits.
///
/// `FixedBitSet` carries its bit count and byte count as two independent const parameters.
/// Its accessors and its `Encode`/`Decode` impls call this to reject mismatched pairs.
///
/// # Panics
///
/// Panics if `bytes != bits.div_ceil(8)`.
const fn check_counts(bits: usize, bytes: usize) {
    // A descriptive message: the default one would only mention the local parameter names,
    // not the `BIT_COUNT`/`BYTE_COUNT` const generics the user actually chose.
    assert!(
        bits.div_ceil(8) == bytes,
        "FixedBitSet BYTE_COUNT must equal BIT_COUNT.div_ceil(8)"
    );
}
87
88impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Debug
89 for FixedBitSet<BIT_COUNT, BYTE_COUNT>
90{
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 fmt::Display::fmt(self, f)
93 }
94}
95
96impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Display
97 for FixedBitSet<BIT_COUNT, BYTE_COUNT>
98{
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(f, "0b")?;
101
102 for i in (0..BIT_COUNT).rev() {
103 if self.bit(i) {
104 write!(f, "1")?;
105 } else {
106 write!(f, "0")?;
107 }
108 }
109
110 Ok(())
111 }
112}
113
114macro_rules! impl_default {
116 ($($N:literal)*) => {
117 $(
118 impl<const BIT_COUNT: usize> Default for FixedBitSet<BIT_COUNT, $N> {
119 fn default() -> Self {
120 Self(Default::default())
121 }
122 }
123 )*
124 }
125}
126
127impl_default!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16);
128
/// A growable bit set backed by a vector of 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`. Words are stored as `i64`; the
/// `Encode` impl writes a `VarInt` word count followed by the words via `i64::encode_slice`.
/// Equality compares the word vectors directly. Two sets holding the same bits but a different
/// number of trailing zero words therefore compare unequal.
#[derive(Default, Clone, Eq, PartialEq)]
pub struct VariableBitSet(pub Vec<i64>);
135
136impl VariableBitSet {
137 pub const fn new() -> Self {
139 Self(Vec::new())
140 }
141
142 pub fn bit(&self, idx: usize) -> bool {
144 let word = idx / 64;
145 let bit = idx % 64;
146
147 self.0
148 .get(word)
149 .is_some_and(|word| ((*word as u64) >> bit) & 1 == 1)
150 }
151
152 pub fn set(&mut self, idx: usize) {
154 let word = idx / 64;
155 let bit = idx % 64;
156
157 self.0.resize(self.0.len().max(word + 1), 0);
158 self.0[word] = (self.0[word] as u64 | (1 << bit)) as i64;
159 }
160
161 pub fn clear(&mut self, idx: usize) {
166 let word = idx / 64;
167 let bit = idx % 64;
168
169 if let Some(word_value) = self.0.get_mut(word) {
170 *word_value = (*word_value as u64 & !(1 << bit)) as i64;
171
172 while self.0.last() == Some(&0) {
173 self.0.pop();
174 }
175 }
176 }
177}
178
179impl Encode for VariableBitSet {
180 fn encode(&self, mut w: impl Write) -> anyhow::Result<()> {
181 ensure!(
182 i32::try_from(self.0.len()).is_ok(),
183 "length of bit set exceeds i32::MAX (got {})",
184 self.0.len()
185 );
186
187 VarInt(self.0.len() as i32).encode(&mut w)?;
188 i64::encode_slice(&self.0, w)
189 }
190}
191
192impl Decode<'_> for VariableBitSet {
193 fn decode(r: &mut &[u8]) -> anyhow::Result<Self> {
194 Ok(Self(Vec::<i64>::decode(r)?))
195 }
196}
197
198impl fmt::Debug for VariableBitSet {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 fmt::Display::fmt(self, f)
201 }
202}
203
204impl fmt::Display for VariableBitSet {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 if self.0.is_empty() {
207 return write!(f, "0b0");
208 }
209
210 write!(f, "0b")?;
211
212 for i in (0..self.0.len() * 64).rev() {
213 write!(f, "{}", u8::from(self.bit(i)))?;
214 }
215
216 Ok(())
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_bit_set_ops() {
        let mut bits = FixedBitSet::<20, 3>::default();

        assert!(!bits.bit(5));
        bits.set(5);
        assert!(bits.bit(5));
        assert_eq!(bits.0, [0b00100000, 0, 0]);

        bits.clear(5);
        assert!(!bits.bit(5));
        assert_eq!(bits.0, [0, 0, 0]);
    }

    /// Out-of-range indices trip the `debug_assert!`s in debug builds and are ignored in
    /// release builds.
    #[test]
    #[cfg_attr(debug_assertions, should_panic)]
    fn fixed_bit_set_out_of_range_is_ignored() {
        let mut bits = FixedBitSet::<20, 3>::default();

        assert!(!bits.bit(20));
        bits.set(20);
        bits.clear(20);
        assert_eq!(bits.0, [0, 0, 0]);
    }

    #[test]
    fn display_fixed_bit_set() {
        let mut bits = FixedBitSet::<20, 3>::default();
        bits.set(5);

        assert_eq!(format!("{bits}"), "0b00000000000000100000");
    }

    /// `Debug` is hand-written to reuse the `Display` rendering.
    #[test]
    fn debug_fixed_bit_set_matches_display() {
        let mut bits = FixedBitSet::<10, 2>::default();
        bits.set(9);

        assert_eq!(format!("{bits:?}"), "0b1000000000");
        assert_eq!(format!("{bits:?}"), format!("{bits}"));
    }

    #[test]
    fn variable_bit_set_ops() {
        let mut bits = VariableBitSet::default();

        assert!(!bits.bit(70));
        bits.set(70);
        assert!(bits.bit(70));
        assert_eq!(bits.0, vec![0, 0b0100_0000]);

        bits.clear(70);
        assert!(!bits.bit(70));
        assert!(bits.0.is_empty());
    }

    /// Clearing the highest set bit drops every trailing zero word, but keeps the ones below.
    #[test]
    fn variable_bit_set_clear_trims_only_trailing_zero_words() {
        let mut bits = VariableBitSet::new();
        bits.set(0);
        bits.set(130);
        assert_eq!(bits.0, vec![1, 0, 1 << 2]);

        bits.clear(130);
        assert_eq!(bits.0, vec![1]);
    }

    /// An empty set is special-cased to `0b0`; otherwise each word prints as 64 digits.
    #[test]
    fn display_variable_bit_set() {
        let mut bits = VariableBitSet::new();
        assert_eq!(format!("{bits}"), "0b0");

        bits.set(1);
        assert_eq!(format!("{bits}"), format!("0b{}10", "0".repeat(62)));
        assert_eq!(format!("{bits:?}"), format!("{bits}"));
    }
}