Add zstd compressor in util (#15779)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2022-03-08 18:17:58 +08:00 committed by GitHub
parent dff08dbf47
commit 565524ed3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 289 additions and 0 deletions

View File

@ -0,0 +1,138 @@
package compressor
import (
"io"
"github.com/klauspost/compress/zstd"
)
// CompressType identifies a compression algorithm by a small integer code.
type CompressType int16

// Supported compression algorithms. Codes start at 1, so the zero value of
// CompressType does not match any algorithm declared here.
const (
	// Zstd selects the Zstandard algorithm.
	Zstd CompressType = iota + 1
)

// DefaultCompressAlgorithm names the default algorithm, currently Zstd.
const DefaultCompressAlgorithm = Zstd
// Compressor streams data from readers into a compressed output writer.
type Compressor interface {
	// ResetWriter redirects subsequent compressed output to out.
	ResetWriter(out io.Writer)

	// Compress reads from in and compresses what it reads.
	Compress(in io.Reader) error

	// Flush() error

	// Close() error
	io.Closer
}
// Decompressor streams decompressed data from a compressed input reader into
// writers.
type Decompressor interface {
	// ResetReader switches the compressed input to in.
	ResetReader(in io.Reader)

	// Decompress writes the decompressed form of the current input to out.
	Decompress(out io.Writer) error

	// Close releases the decompressor. It reports no error.
	Close()
}
// Compile-time check that *ZstdCompressor satisfies Compressor.
var _ Compressor = (*ZstdCompressor)(nil)

// Compile-time check that *ZstdDecompressor satisfies Decompressor.
var _ Decompressor = (*ZstdDecompressor)(nil)
// ZstdCompressor is the Compressor implementation backed by a zstd encoder.
type ZstdCompressor struct {
// encoder is the underlying zstd encoder that every method delegates to.
encoder *zstd.Encoder
}
// NewZstdCompressor creates a ZstdCompressor whose encoder writes to out.
//
// opts are forwarded unchanged to zstd.NewWriter. If the encoder cannot be
// created, that error is returned as-is together with a nil compressor.
func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, error) {
	enc, err := zstd.NewWriter(out, opts...)
	if err != nil {
		return nil, err
	}

	compressor := &ZstdCompressor{encoder: enc}
	return compressor, nil
}
// Compress copies everything readable from in into the encoder.
//
// Call Close() after the last Compress() call to make sure the data is
// flushed to the underlying writer. If the copy fails, the encoder is closed
// and the copy error is returned.
func (c *ZstdCompressor) Compress(in io.Reader) error {
	if _, err := io.Copy(c.encoder, in); err != nil {
		// Best-effort cleanup: the copy error is what the caller needs to
		// see, so the result of Close is deliberately dropped.
		_ = c.encoder.Close()
		return err
	}
	return nil
}
// ResetWriter points the compressor at a new output writer by resetting the
// underlying encoder, so one compressor can be reused for several outputs.
func (c *ZstdCompressor) ResetWriter(out io.Writer) {
c.encoder.Reset(out)
}
// Flush() does not seem to work as expected, so it is removed for now.
// Use Close() instead.
// func (c *ZstdCompressor) Flush() error {
// if c.encoder != nil {
// return c.encoder.Flush()
// }
// return nil
// }
// Close closes the underlying encoder and returns its error. Call it after
// the last Compress() so that the data is flushed to the underlying writer.
//
// The compressor can still be reused after calling this; the tests do so via
// ResetWriter, and also expect repeated Close() calls to succeed.
func (c *ZstdCompressor) Close() error {
return c.encoder.Close()
}
// ZstdDecompressor is the Decompressor implementation backed by a zstd decoder.
type ZstdDecompressor struct {
// decoder is the underlying zstd decoder that every method delegates to.
decoder *zstd.Decoder
}
// NewZstdDecompressor creates a ZstdDecompressor whose decoder reads
// compressed data from in.
//
// opts are forwarded unchanged to zstd.NewReader. If the decoder cannot be
// created, that error is returned as-is together with a nil decompressor.
func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, error) {
	dec, err := zstd.NewReader(in, opts...)
	if err != nil {
		return nil, err
	}

	decompressor := &ZstdDecompressor{decoder: dec}
	return decompressor, nil
}
// Decompress copies the decompressed form of the current input into out.
//
// If the copy fails, the decoder is closed before the error is returned,
// which leaves the decompressor unusable afterwards.
func (dec *ZstdDecompressor) Decompress(out io.Writer) error {
	if _, err := io.Copy(out, dec.decoder); err != nil {
		dec.decoder.Close()
		return err
	}
	return nil
}
// ResetReader makes the decompressor read compressed data from in by
// resetting the underlying decoder.
//
// NOTE(review): zstd.Decoder.Reset appears to return an error that is dropped
// here, so a failed reset would only show up on the next Decompress call —
// confirm against the library and consider surfacing it.
func (dec *ZstdDecompressor) ResetReader(in io.Reader) {
dec.decoder.Reset(in)
}
// Close closes the underlying decoder.
//
// NOTICE: unlike the compressor, the decompressor is not usable after
// calling this.
func (dec *ZstdDecompressor) Close() {
dec.decoder.Close()
}
// Package-level convenience functions
// ZstdCompress compresses everything read from in and writes the result to
// out, using a one-shot compressor configured with opts.
//
// The compressor is closed on every path. If compression fails, that error is
// returned and the result of Close is dropped; otherwise the result of Close
// is returned.
func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error {
	enc, err := NewZstdCompressor(out, opts...)
	if err != nil {
		return err
	}

	compressErr := enc.Compress(in)
	closeErr := enc.Close()
	if compressErr != nil {
		return compressErr
	}
	return closeErr
}
// ZstdDecompress decompresses everything read from in and writes the result
// to out, using a one-shot decompressor configured with opts.
//
// The decompressor is always closed before returning. Errors from creating
// the decompressor or from decompression are returned unchanged.
func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error {
	dec, err := NewZstdDecompressor(in, opts...)
	if err != nil {
		return err
	}
	defer dec.Close()

	return dec.Decompress(out)
}

View File

@ -0,0 +1,134 @@
package compressor
import (
"bytes"
"fmt"
"io"
"runtime"
"strings"
"sync"
"testing"
"github.com/klauspost/compress/zstd"
"github.com/milvus-io/milvus/internal/util/mock"
"github.com/stretchr/testify/assert"
)
// TestZstdCompress runs the shared compressor checks twice with the same
// compressor: once freshly created, and once after resetting it onto an
// emptied output buffer.
func TestZstdCompress(t *testing.T) {
	const payload = "hello zstd algorithm!"

	var (
		compressed = new(bytes.Buffer)
		origin     = new(bytes.Buffer)
	)

	enc, err := NewZstdCompressor(compressed)
	assert.NoError(t, err)

	testCompress(t, payload, enc, compressed, origin)

	// Reuse test: clear both buffers and re-attach the compressor.
	compressed.Reset()
	origin.Reset()
	enc.ResetWriter(compressed)

	testCompress(t, payload+": reuse", enc, compressed, origin)
}
// testCompress checks one compress/decompress round trip through enc and then
// probes the error paths.
//
// compressed must be the buffer enc currently writes to; origin receives the
// decompressed output and is compared against data.
func testCompress(t *testing.T, data string, enc Compressor, compressed, origin *bytes.Buffer) {
	// Compress the payload and finish the stream.
	assert.NoError(t, enc.Compress(strings.NewReader(data)))
	assert.NoError(t, enc.Close())

	// Close() method should satisfy idempotence
	assert.NoError(t, enc.Close())

	// Decompress what was written and compare it with the input.
	dec, err := NewZstdDecompressor(compressed)
	assert.NoError(t, err)
	assert.NoError(t, dec.Decompress(origin))
	assert.Equal(t, data, origin.String())

	// Mock error reader/writer
	failingReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF}
	failingWriter := &mock.ErrWriter{Err: io.ErrShortWrite}

	// A read failure must come back from Compress.
	compressErr := enc.Compress(failingReader)
	assert.ErrorIs(t, compressErr, failingReader.Err)

	// A write failure must come back from Decompress.
	dec.ResetReader(bytes.NewReader(compressed.Bytes()))
	decompressErr := dec.Decompress(failingWriter)
	assert.ErrorIs(t, decompressErr, failingWriter.Err)

	// Use closed decompressor
	dec.ResetReader(bytes.NewReader(compressed.Bytes()))
	dec.Close()
	assert.Error(t, dec.Decompress(origin))
}
// TestGlobalMethods exercises the package-level ZstdCompress and
// ZstdDecompress helpers: a round trip, failures reported by the reader or
// the writer, and calls made with invalid codec options.
func TestGlobalMethods(t *testing.T) {
data := "hello zstd algorithm!"
compressed := new(bytes.Buffer)
origin := new(bytes.Buffer)
// Round trip: the decompressed output must equal the input.
err := ZstdCompress(strings.NewReader(data), compressed)
assert.NoError(t, err)
err = ZstdDecompress(compressed, origin)
assert.NoError(t, err)
assert.Equal(t, data, origin.String())
// Mock error reader/writer
errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF}
errWriter := &mock.ErrWriter{Err: io.ErrShortWrite}
// Capture the bytes the buffer still reports after the round trip; they are
// reused as compressed input further down.
compressedBytes := compressed.Bytes()
// NOTE(review): the original note said the old buffer is "closed";
// bytes.Buffer has no close, so this presumably means it was already handed
// to the decompressor — confirm. A new buffer is built over the same bytes.
compressed = bytes.NewBuffer(compressedBytes)
// A read failure must come back from ZstdCompress.
err = ZstdCompress(errReader, compressed)
assert.ErrorIs(t, err, errReader.Err)
// The captured bytes must be non-empty, otherwise the decompression check
// below would have nothing to decode.
assert.Positive(t, len(compressedBytes))
reader := bytes.NewReader(compressedBytes)
// A write failure must come back from ZstdDecompress.
err = ZstdDecompress(reader, errWriter)
assert.ErrorIs(t, err, errWriter.Err)
// Incorrect option: both calls are expected to fail.
err = ZstdCompress(strings.NewReader(data), compressed, zstd.WithWindowSize(3))
assert.Error(t, err)
err = ZstdDecompress(compressed, origin, zstd.WithDecoderConcurrency(0))
assert.Error(t, err)
}
func TestCurrencyGlobalMethods(t *testing.T) {
prefix := "Test Currency Global Methods"
currency := runtime.GOMAXPROCS(0) * 2
if currency < 6 {
currency = 6
}
wg := sync.WaitGroup{}
wg.Add(currency)
for i := 0; i < currency; i++ {
go func(idx int) {
defer wg.Done()
buf := new(bytes.Buffer)
origin := new(bytes.Buffer)
data := prefix + fmt.Sprintf(": %d-th goroutine", idx)
err := ZstdCompress(strings.NewReader(data), buf, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx)))
assert.NoError(t, err)
err = ZstdDecompress(buf, origin)
assert.NoError(t, err)
assert.Equal(t, data, origin.String())
}(i)
}
wg.Wait()
}

View File

@ -0,0 +1,17 @@
package mock
// ErrReader is a reader stub for tests: every Read reports zero bytes read
// together with the configured Err.
type ErrReader struct {
	// Err is the error returned by every call to Read.
	Err error
}

// Read never touches p; it returns 0 and r.Err.
func (r *ErrReader) Read(p []byte) (n int, err error) {
	n, err = 0, r.Err
	return n, err
}
// ErrWriter is a writer stub for tests: every Write reports zero bytes
// written together with the configured Err.
type ErrWriter struct {
	// Err is the error returned by every call to Write.
	Err error
}

// Write discards p; it returns 0 and w.Err.
func (w *ErrWriter) Write(p []byte) (n int, err error) {
	n, err = 0, w.Err
	return n, err
}