Add methods to compress/decompress small blocks (#15980)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2022-03-11 14:55:59 +08:00 committed by GitHub
parent 3121619758
commit 7d2934e4c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 7 deletions

View File

@ -6,25 +6,29 @@ import (
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
) )
type CompressType int16 type CompressType string
const ( const (
Zstd CompressType = iota + 1 CompressTypeZstd CompressType = "zstd"
DefaultCompressAlgorithm CompressType = Zstd DefaultCompressAlgorithm CompressType = CompressTypeZstd
) )
type Compressor interface { type Compressor interface {
Compress(in io.Reader) error Compress(in io.Reader) error
CompressBytes(src, dst []byte) []byte
ResetWriter(out io.Writer) ResetWriter(out io.Writer)
// Flush() error // Flush() error
Close() error Close() error
GetType() CompressType
} }
type Decompressor interface { type Decompressor interface {
Decompress(out io.Writer) error Decompress(out io.Writer) error
DecompressBytes(src, dst []byte) ([]byte, error)
ResetReader(in io.Reader) ResetReader(in io.Reader)
Close() Close()
GetType() CompressType
} }
var ( var (
@ -36,6 +40,7 @@ type ZstdCompressor struct {
encoder *zstd.Encoder encoder *zstd.Encoder
} }
// For compressing small blocks, pass nil to the `out` parameter
func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, error) { func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, error) {
encoder, err := zstd.NewWriter(out, opts...) encoder, err := zstd.NewWriter(out, opts...)
if err != nil { if err != nil {
@ -45,6 +50,7 @@ func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, er
return &ZstdCompressor{encoder}, nil return &ZstdCompressor{encoder}, nil
} }
// Use case: compress stream
// Call Close() to make sure the data is flushed to the underlying writer // Call Close() to make sure the data is flushed to the underlying writer
// after the last Compress() call // after the last Compress() call
func (c *ZstdCompressor) Compress(in io.Reader) error { func (c *ZstdCompressor) Compress(in io.Reader) error {
@ -57,6 +63,14 @@ func (c *ZstdCompressor) Compress(in io.Reader) error {
return nil return nil
} }
// Use case: compress small blocks
// This compresses the src bytes and appends it to the dst bytes, then return the result
// This can be called concurrently
func (c *ZstdCompressor) CompressBytes(src []byte, dst []byte) []byte {
return c.encoder.EncodeAll(src, dst)
}
// Reset the writer to reuse the compressor
func (c *ZstdCompressor) ResetWriter(out io.Writer) { func (c *ZstdCompressor) ResetWriter(out io.Writer) {
c.encoder.Reset(out) c.encoder.Reset(out)
} }
@ -76,10 +90,15 @@ func (c *ZstdCompressor) Close() error {
return c.encoder.Close() return c.encoder.Close()
} }
func (c *ZstdCompressor) GetType() CompressType {
return CompressTypeZstd
}
type ZstdDecompressor struct { type ZstdDecompressor struct {
decoder *zstd.Decoder decoder *zstd.Decoder
} }
// For compressing small blocks, pass nil to the `in` parameter
func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, error) { func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, error) {
decoder, err := zstd.NewReader(in, opts...) decoder, err := zstd.NewReader(in, opts...)
if err != nil { if err != nil {
@ -89,6 +108,8 @@ func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor,
return &ZstdDecompressor{decoder}, nil return &ZstdDecompressor{decoder}, nil
} }
// Usa case: decompress stream
// Write the decompressed data into `out`
func (dec *ZstdDecompressor) Decompress(out io.Writer) error { func (dec *ZstdDecompressor) Decompress(out io.Writer) error {
_, err := io.Copy(out, dec.decoder) _, err := io.Copy(out, dec.decoder)
if err != nil { if err != nil {
@ -99,6 +120,14 @@ func (dec *ZstdDecompressor) Decompress(out io.Writer) error {
return nil return nil
} }
// Use case: decompress small blocks
// This decompresses the src bytes and appends it to the dst bytes, then return the result
// This can be called concurrently
func (dec *ZstdDecompressor) DecompressBytes(src []byte, dst []byte) ([]byte, error) {
return dec.decoder.DecodeAll(src, dst)
}
// Reset the reader to reuse the decompressor
func (dec *ZstdDecompressor) ResetReader(in io.Reader) { func (dec *ZstdDecompressor) ResetReader(in io.Reader) {
dec.decoder.Reset(in) dec.decoder.Reset(in)
} }
@ -108,7 +137,15 @@ func (dec *ZstdDecompressor) Close() {
dec.decoder.Close() dec.decoder.Close()
} }
func (dec *ZstdDecompressor) GetType() CompressType {
return CompressTypeZstd
}
// Global methods // Global methods
// Usa case: compress stream, large object only once
// This can be called concurrently
// Try ZstdCompressor for better efficiency if you need compress mutiple streams one by one
func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error { func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error {
enc, err := NewZstdCompressor(out, opts...) enc, err := NewZstdCompressor(out, opts...)
if err != nil { if err != nil {
@ -123,6 +160,9 @@ func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error {
return enc.Close() return enc.Close()
} }
// Use case: decompress stream, large object only once
// This can be called concurrently
// Try ZstdDecompressor for better efficiency if you need decompress mutiple streams one by one
func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error { func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error {
dec, err := NewZstdDecompressor(in, opts...) dec, err := NewZstdDecompressor(in, opts...)
if err != nil { if err != nil {
@ -136,3 +176,20 @@ func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error {
return nil return nil
} }
var (
globalZstdCompressor, _ = zstd.NewWriter(nil)
globalZstdDecompressor, _ = zstd.NewReader(nil)
)
// Use case: compress small blocks
// This can be called concurrently
func ZstdCompressBytes(src, dst []byte) []byte {
return globalZstdCompressor.EncodeAll(src, dst)
}
// Use case: decompress small blocks
// This can be called concurrently
func ZstdDecompressBytes(src, dst []byte) ([]byte, error) {
return globalZstdDecompressor.DecodeAll(src, dst)
}

View File

@ -30,13 +30,24 @@ func TestZstdCompress(t *testing.T) {
enc.ResetWriter(compressed) enc.ResetWriter(compressed)
testCompress(t, data+": reuse", enc, compressed, origin) testCompress(t, data+": reuse", enc, compressed, origin)
// Test type
dec, err := NewZstdDecompressor(nil)
assert.NoError(t, err)
assert.Equal(t, enc.GetType(), CompressTypeZstd)
assert.Equal(t, dec.GetType(), CompressTypeZstd)
} }
func testCompress(t *testing.T, data string, enc Compressor, compressed, origin *bytes.Buffer) { func testCompress(t *testing.T, data string, enc Compressor, compressed, origin *bytes.Buffer) {
compressedBytes := make([]byte, 0)
originBytes := make([]byte, 0)
err := enc.Compress(strings.NewReader(data)) err := enc.Compress(strings.NewReader(data))
assert.NoError(t, err) assert.NoError(t, err)
err = enc.Close() err = enc.Close()
assert.NoError(t, err) assert.NoError(t, err)
compressedBytes = enc.CompressBytes([]byte(data), compressedBytes)
assert.Equal(t, compressed.Bytes(), compressedBytes)
// Close() method should satisfy idempotence // Close() method should satisfy idempotence
err = enc.Close() err = enc.Close()
@ -46,6 +57,9 @@ func testCompress(t *testing.T, data string, enc Compressor, compressed, origin
assert.NoError(t, err) assert.NoError(t, err)
err = dec.Decompress(origin) err = dec.Decompress(origin)
assert.NoError(t, err) assert.NoError(t, err)
originBytes, err = dec.DecompressBytes(compressedBytes, originBytes)
assert.NoError(t, err)
assert.Equal(t, origin.Bytes(), originBytes)
assert.Equal(t, data, origin.String()) assert.Equal(t, data, origin.String())
@ -70,21 +84,30 @@ func testCompress(t *testing.T, data string, enc Compressor, compressed, origin
func TestGlobalMethods(t *testing.T) { func TestGlobalMethods(t *testing.T) {
data := "hello zstd algorithm!" data := "hello zstd algorithm!"
compressed := new(bytes.Buffer) compressed := new(bytes.Buffer)
compressedBytes := make([]byte, 0)
origin := new(bytes.Buffer) origin := new(bytes.Buffer)
originBytes := make([]byte, 0)
err := ZstdCompress(strings.NewReader(data), compressed) err := ZstdCompress(strings.NewReader(data), compressed)
assert.NoError(t, err) assert.NoError(t, err)
compressedBytes = ZstdCompressBytes([]byte(data), compressedBytes)
assert.Equal(t, compressed.Bytes(), compressedBytes)
err = ZstdDecompress(compressed, origin) err = ZstdDecompress(compressed, origin)
assert.NoError(t, err) assert.NoError(t, err)
originBytes, err = ZstdDecompressBytes(compressedBytes, originBytes)
assert.NoError(t, err)
assert.Equal(t, origin.Bytes(), originBytes)
assert.Equal(t, data, origin.String()) assert.Equal(t, data, origin.String())
// Mock error reader/writer // Mock error reader/writer
errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF} errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF}
errWriter := &mock.ErrWriter{Err: io.ErrShortWrite} errWriter := &mock.ErrWriter{Err: io.ErrShortWrite}
compressedBytes := compressed.Bytes() compressedBytes = compressed.Bytes()
compressed = bytes.NewBuffer(compressedBytes) // The old compressed buffer is closed compressed = bytes.NewBuffer(compressedBytes) // The old compressed buffer is closed
err = ZstdCompress(errReader, compressed) err = ZstdCompress(errReader, compressed)
assert.ErrorIs(t, err, errReader.Err) assert.ErrorIs(t, err, errReader.Err)
@ -116,16 +139,23 @@ func TestCurrencyGlobalMethods(t *testing.T) {
go func(idx int) { go func(idx int) {
defer wg.Done() defer wg.Done()
buf := new(bytes.Buffer) compressed := new(bytes.Buffer)
compressedBytes := make([]byte, 0)
origin := new(bytes.Buffer) origin := new(bytes.Buffer)
originBytes := make([]byte, 0)
data := prefix + fmt.Sprintf(": %d-th goroutine", idx) data := prefix + fmt.Sprintf(": %d-th goroutine", idx)
err := ZstdCompress(strings.NewReader(data), buf, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx))) err := ZstdCompress(strings.NewReader(data), compressed, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx)))
assert.NoError(t, err) assert.NoError(t, err)
compressedBytes = ZstdCompressBytes([]byte(data), compressedBytes)
assert.Equal(t, compressed.Bytes(), compressedBytes)
err = ZstdDecompress(buf, origin) err = ZstdDecompress(compressed, origin)
assert.NoError(t, err) assert.NoError(t, err)
originBytes, err = ZstdDecompressBytes(compressedBytes, originBytes)
assert.NoError(t, err)
assert.Equal(t, origin.Bytes(), originBytes)
assert.Equal(t, data, origin.String()) assert.Equal(t, data, origin.String())
}(i) }(i)