use std::{
fmt,
io::{self, ErrorKind, Seek, SeekFrom, Write},
};
use flate2::{write::ZlibEncoder, Compress, FlushCompress};
use thiserror::Error;
pub use flate2::Compression;
use super::CHUNK_LEN;
/// Chunk writer that drives a raw [`Compress`] state by hand instead of going
/// through `ZlibEncoder`.
///
/// NOTE(review): the only constructor is marked `#[allow(dead_code)]`, so this
/// type appears to be unused in the visible code — confirm before relying on it.
struct SegmentedEncoderRaw<W> {
/// Destination that receives the magic, the length prefix and the compressed bytes.
inner: W,
/// Compressor state shared across `write` calls.
data: Compress,
/// Number of compressed bytes handed to `inner` so far; `flush` uses it to
/// locate and fill in the length prefix.
written: usize,
/// Number of uncompressed bytes accepted so far; compared against `CHUNK_LEN`.
consumed: usize,
/// Staging buffer for input bytes that are fed to the compressor from memory.
buf_in: Vec<u8>,
/// Scratch buffer that receives compressor output before it is written to `inner`.
buf_out: Vec<u8>,
}
impl<W: Write + Seek> SegmentedEncoderRaw<W> {
#[allow(dead_code)]
pub fn new(level: Compression, mut inner: W) -> io::Result<Self> {
inner.write_all(super::MAGIC)?;
inner.write_all(&0u32.to_le_bytes())?;
Ok(Self {
inner,
written: 0,
consumed: 0,
buf_in: Vec::with_capacity(1024),
buf_out: Vec::with_capacity(1024),
data: Compress::new(level, true),
})
}
}
/// `Write` implementation that compresses incoming bytes with the raw
/// compressor and forwards the output to the wrapped writer.
impl<W: Write + Seek> Write for SegmentedEncoderRaw<W> {
/// Compresses as much of `buf` as still fits into the current chunk
/// (`CHUNK_LEN` uncompressed bytes) and writes the produced output to `inner`.
///
/// Returns the number of input bytes accounted for in this call (`d_in`),
/// which is also added to `self.consumed`.
///
/// # Errors
///
/// Propagates compressor errors and I/O errors from the inner writer.
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
// `sum` / `_spillover` are computed but never read afterwards.
let sum = self.consumed + buf.len();
let _spillover = sum > CHUNK_LEN;
// Bytes of `buf` that may still go into the current chunk.
// NOTE(review): `CHUNK_LEN - self.consumed` would underflow if `consumed`
// ever exceeded `CHUNK_LEN` — confirm that invariant holds.
let mut z_avail = (CHUNK_LEN - self.consumed).min(buf.len());
// Input bytes accounted for during this call.
let mut d_in = 0;
loop {
// Snapshot so the bytes consumed by this compress call can be derived below.
let c_in = self.data.total_in();
// Spare capacity left in the staging buffer.
let buf_in_avail = self.buf_in.capacity() - self.buf_in.len();
// Feed the caller's slice directly when nothing is staged; otherwise top
// up the staging buffer and feed that instead.
// NOTE(review): `buf_in` is only appended to when it is already non-empty,
// and nothing visible here makes it non-empty, so the staging branches look
// unreachable — confirm.
let z_input = if self.buf_in.is_empty() {
&buf[..z_avail]
} else {
let take = buf_in_avail.min(z_avail);
self.buf_in.extend_from_slice(&buf[..take]);
d_in += take;
buf = &buf[take..z_avail];
self.buf_in.as_slice()
};
// Ask the compressor to finish only once the bytes accounted for so far
// reach exactly `CHUNK_LEN`; the check runs before this iteration's input
// is counted.
let flush = if self.consumed + d_in == CHUNK_LEN {
FlushCompress::Finish
} else {
FlushCompress::None
};
let status = self.data.compress_vec(z_input, &mut self.buf_out, flush)?;
// Input bytes the compressor actually took in this iteration.
let consumed = (self.data.total_in() - c_in) as usize;
if self.buf_in.is_empty() {
// Direct path: advance the caller's slice.
d_in += consumed;
buf = &buf[consumed..];
} else {
// Staged path: drop the bytes that were consumed from the front.
self.buf_in.splice(..consumed, std::iter::empty());
}
z_avail -= consumed;
// Forward whatever output was produced and reset the scratch buffer.
self.inner.write_all(&self.buf_out)?;
self.written += self.buf_out.len();
self.buf_out.clear();
match status {
// Keep looping until every byte allotted to this chunk was consumed.
flate2::Status::Ok | flate2::Status::BufError => {
if z_avail == 0 {
break;
}
}
flate2::Status::StreamEnd => break,
}
}
self.consumed += d_in;
// NOTE(review): once `consumed == CHUNK_LEN`, `z_avail` starts at 0 and this
// returns `Ok(0)`; no code visible here starts a new chunk — confirm intent.
Ok(d_in)
}
/// Back-patches the length prefix and flushes the inner writer.
///
/// Seeks back `written + 4` bytes from the current position, writes `written`
/// as a little-endian `u32`, returns to the saved position and flushes `inner`.
/// This relies on the four-byte placeholder sitting directly before the
/// compressed bytes, as written by `new`.
fn flush(&mut self) -> std::io::Result<()> {
// Remember where to resume appending after the patch.
let pos = self.inner.stream_position()?;
let diff = self.written as i64;
self.inner.seek(SeekFrom::Current(-diff - 4))?;
// NOTE(review): `as u32` truncates silently if `written` exceeds `u32::MAX`.
self.inner.write_all(&(self.written as u32).to_le_bytes())?;
self.inner.seek(SeekFrom::Start(pos))?;
self.inner.flush()
}
}
/// Errors reported by the segmented encoder.
#[derive(Debug, Error)]
pub enum Error {
/// An underlying I/O operation failed; `#[from]` lets `?` convert `io::Error` into this variant.
Io(#[from] io::Error),
/// `finish` was called while the encoder was in its invalid state.
FinishOnInvalid,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::FinishOnInvalid => write!(f, "Called finish on invalid"),
Self::Io(_) => write!(f, "I/O error"),
}
}
}
impl From<Error> for io::Error {
fn from(e: Error) -> Self {
io::Error::new(ErrorKind::Other, e)
}
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Internal state of a `SegmentedEncoder`.
#[derive(Debug)]
enum EncoderKind<W: Write> {
/// A chunk is open: bytes are being compressed into the wrapped writer.
Ok(ZlibEncoder<W>),
/// No chunk is open; the bare writer is held until the next write starts one.
Initial(W),
/// The state was moved out (see `take`) and has not been replaced.
Invalid,
}
impl<W: Write + Seek> EncoderKind<W> {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::Invalid)
}
fn finish(self) -> Result<W> {
match self {
Self::Ok(mut z) => {
let mut total = z.total_out();
let a_pos = z.get_mut().stream_position()?;
let mut inner = z.finish()?;
let b_pos = inner.stream_position()?;
total += b_pos - a_pos;
patch_total(&mut inner, total as u32)?;
Ok(inner)
}
Self::Initial(w) => Ok(w),
Self::Invalid => Err(Error::FinishOnInvalid),
}
}
}
/// Writer that splits its input into separately compressed, length-prefixed chunks.
pub struct SegmentedEncoder<W: Write + Seek> {
/// Current state: an open chunk, the bare writer between chunks, or invalid.
inner: EncoderKind<W>,
/// Compression level used for every chunk's `ZlibEncoder`.
level: Compression,
}
impl<W: Write + Seek> Drop for SegmentedEncoder<W> {
    /// Finishes the encoder on a best-effort basis.
    ///
    /// The outcome is discarded on purpose: a destructor cannot report errors,
    /// and an encoder that was already finished simply yields an error here.
    fn drop(&mut self) {
        std::mem::drop(self.finish());
    }
}
impl<W: Write + Seek> SegmentedEncoder<W> {
pub fn new(mut inner: W, level: Compression) -> Result<Self> {
inner.write_all(super::MAGIC)?;
Ok(Self {
level,
inner: EncoderKind::Initial(inner),
})
}
}
/// Overwrites a chunk's four-byte length prefix with `total`.
///
/// `inner` must be positioned directly after the chunk's `total` payload
/// bytes, with the prefix immediately preceding that payload. The prefix is
/// written as a little-endian `u32` and the original position is restored
/// before returning.
///
/// # Errors
///
/// Returns an error if seeking or writing on `inner` fails.
fn patch_total<W: Write + Seek>(inner: &mut W, total: u32) -> Result<()> {
    let prefix = total.to_le_bytes();
    let payload_len = i64::from(total);

    // Step back over the payload and the prefix itself.
    let rewind = -(payload_len + 4);
    inner.seek(SeekFrom::Current(rewind))?;
    inner.write_all(&prefix)?;
    // Writing the prefix moved us to the payload start; skip the payload again.
    inner.seek(SeekFrom::Current(payload_len))?;
    Ok(())
}
impl<W: Write + Seek> SegmentedEncoder<W> {
pub fn finish(&mut self) -> Result<W> {
let inner = self.inner.take().finish()?;
Ok(inner)
}
}
impl<W: Write + Seek> Write for SegmentedEncoder<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if let EncoderKind::Ok(z) = &mut self.inner {
let sum = z.total_out() as usize + buf.len();
let spillover = sum > CHUNK_LEN;
let avail = if spillover {
CHUNK_LEN - z.total_out() as usize
} else {
buf.len()
};
z.write_all(&buf[..avail])?;
if sum >= CHUNK_LEN {
let w = self.finish()?;
self.inner = EncoderKind::Initial(w);
}
if spillover {
self.write(&buf[avail..]).map(|l| avail + l)
} else {
Ok(avail)
}
} else if let EncoderKind::Initial(mut w) = self.inner.take() {
w.write_all(&0u32.to_le_bytes())?;
self.inner = EncoderKind::Ok(ZlibEncoder::new(w, self.level));
self.write(buf)
} else {
panic!("Called write on an invalid encoder");
}
}
fn flush(&mut self) -> io::Result<()> {
if let EncoderKind::Ok(z) = &mut self.inner {
z.flush()?;
let total = z.total_out() as u32;
patch_total(z.get_mut(), total)?;
}
Ok(())
}
}