mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-12-02 03:48:37 +08:00
Add zstd compressor in util (#15779)
Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
parent
dff08dbf47
commit
565524ed3c
138
internal/util/compressor/compressor.go
Normal file
138
internal/util/compressor/compressor.go
Normal file
@ -0,0 +1,138 @@
|
||||
package compressor
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// CompressType identifies a compression algorithm supported by this package.
type CompressType int16

// Supported compression algorithms. Numbering starts at 1, so the zero value
// of CompressType does not name any algorithm.
const (
	Zstd CompressType = iota + 1
)

// DefaultCompressAlgorithm is the algorithm used when no explicit choice is made.
const DefaultCompressAlgorithm CompressType = Zstd
|
||||
|
||||
// Compressor is a reusable streaming compressor bound to an output writer.
type Compressor interface {
	// Compress reads in until it is exhausted and feeds the data to the
	// compressor.
	Compress(in io.Reader) error

	// ResetWriter points the compressor at a new output writer.
	ResetWriter(out io.Writer)

	// Flush() error

	// Close finishes the current compressed stream.
	Close() error
}
|
||||
|
||||
// Decompressor is a streaming decompressor bound to an input reader.
type Decompressor interface {
	// Decompress writes the decompressed form of the current input to out.
	Decompress(out io.Writer) error

	// ResetReader points the decompressor at a new input reader.
	ResetReader(in io.Reader)

	// Close releases the decompressor.
	Close()
}
|
||||
|
||||
var (
|
||||
_ Compressor = (*ZstdCompressor)(nil)
|
||||
_ Decompressor = (*ZstdDecompressor)(nil)
|
||||
)
|
||||
|
||||
type ZstdCompressor struct {
|
||||
encoder *zstd.Encoder
|
||||
}
|
||||
|
||||
func NewZstdCompressor(out io.Writer, opts ...zstd.EOption) (*ZstdCompressor, error) {
|
||||
encoder, err := zstd.NewWriter(out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ZstdCompressor{encoder}, nil
|
||||
}
|
||||
|
||||
// Call Close() to make sure the data is flushed to the underlying writer
|
||||
// after the last Compress() call
|
||||
func (c *ZstdCompressor) Compress(in io.Reader) error {
|
||||
_, err := io.Copy(c.encoder, in)
|
||||
if err != nil {
|
||||
c.encoder.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ZstdCompressor) ResetWriter(out io.Writer) {
|
||||
c.encoder.Reset(out)
|
||||
}
|
||||
|
||||
// Flush() does not seem to work as expected, so it is disabled for now.
// Use Close() instead to make sure pending data is written out.
|
||||
// func (c *ZstdCompressor) Flush() error {
|
||||
// if c.encoder != nil {
|
||||
// return c.encoder.Flush()
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// The compressor is still re-used after calling this
|
||||
func (c *ZstdCompressor) Close() error {
|
||||
return c.encoder.Close()
|
||||
}
|
||||
|
||||
type ZstdDecompressor struct {
|
||||
decoder *zstd.Decoder
|
||||
}
|
||||
|
||||
func NewZstdDecompressor(in io.Reader, opts ...zstd.DOption) (*ZstdDecompressor, error) {
|
||||
decoder, err := zstd.NewReader(in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ZstdDecompressor{decoder}, nil
|
||||
}
|
||||
|
||||
func (dec *ZstdDecompressor) Decompress(out io.Writer) error {
|
||||
_, err := io.Copy(out, dec.decoder)
|
||||
if err != nil {
|
||||
dec.decoder.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dec *ZstdDecompressor) ResetReader(in io.Reader) {
|
||||
dec.decoder.Reset(in)
|
||||
}
|
||||
|
||||
// NOTICE: not like compressor, the decompressor is not usable after calling this
|
||||
func (dec *ZstdDecompressor) Close() {
|
||||
dec.decoder.Close()
|
||||
}
|
||||
|
||||
// Package-level convenience functions
|
||||
func ZstdCompress(in io.Reader, out io.Writer, opts ...zstd.EOption) error {
|
||||
enc, err := NewZstdCompressor(out, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = enc.Compress(in); err != nil {
|
||||
enc.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return enc.Close()
|
||||
}
|
||||
|
||||
func ZstdDecompress(in io.Reader, out io.Writer, opts ...zstd.DOption) error {
|
||||
dec, err := NewZstdDecompressor(in, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dec.Close()
|
||||
|
||||
if err = dec.Decompress(out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
134
internal/util/compressor/compressor_test.go
Normal file
134
internal/util/compressor/compressor_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package compressor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/milvus-io/milvus/internal/util/mock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestZstdCompress(t *testing.T) {
|
||||
data := "hello zstd algorithm!"
|
||||
compressed := new(bytes.Buffer)
|
||||
origin := new(bytes.Buffer)
|
||||
|
||||
enc, err := NewZstdCompressor(compressed)
|
||||
assert.NoError(t, err)
|
||||
testCompress(t, data, enc, compressed, origin)
|
||||
|
||||
// Reuse test
|
||||
compressed.Reset()
|
||||
origin.Reset()
|
||||
|
||||
enc.ResetWriter(compressed)
|
||||
|
||||
testCompress(t, data+": reuse", enc, compressed, origin)
|
||||
}
|
||||
|
||||
func testCompress(t *testing.T, data string, enc Compressor, compressed, origin *bytes.Buffer) {
|
||||
err := enc.Compress(strings.NewReader(data))
|
||||
assert.NoError(t, err)
|
||||
err = enc.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close() method should satisfy idempotence
|
||||
err = enc.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
dec, err := NewZstdDecompressor(compressed)
|
||||
assert.NoError(t, err)
|
||||
err = dec.Decompress(origin)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, data, origin.String())
|
||||
|
||||
// Mock error reader/writer
|
||||
errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF}
|
||||
errWriter := &mock.ErrWriter{Err: io.ErrShortWrite}
|
||||
|
||||
err = enc.Compress(errReader)
|
||||
assert.ErrorIs(t, err, errReader.Err)
|
||||
|
||||
dec.ResetReader(bytes.NewReader(compressed.Bytes()))
|
||||
err = dec.Decompress(errWriter)
|
||||
assert.ErrorIs(t, err, errWriter.Err)
|
||||
|
||||
// Use closed decompressor
|
||||
dec.ResetReader(bytes.NewReader(compressed.Bytes()))
|
||||
dec.Close()
|
||||
err = dec.Decompress(origin)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGlobalMethods(t *testing.T) {
|
||||
data := "hello zstd algorithm!"
|
||||
compressed := new(bytes.Buffer)
|
||||
origin := new(bytes.Buffer)
|
||||
|
||||
err := ZstdCompress(strings.NewReader(data), compressed)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ZstdDecompress(compressed, origin)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, data, origin.String())
|
||||
|
||||
// Mock error reader/writer
|
||||
errReader := &mock.ErrReader{Err: io.ErrUnexpectedEOF}
|
||||
errWriter := &mock.ErrWriter{Err: io.ErrShortWrite}
|
||||
|
||||
compressedBytes := compressed.Bytes()
|
||||
compressed = bytes.NewBuffer(compressedBytes) // The old compressed buffer is closed
|
||||
err = ZstdCompress(errReader, compressed)
|
||||
assert.ErrorIs(t, err, errReader.Err)
|
||||
|
||||
assert.Positive(t, len(compressedBytes))
|
||||
reader := bytes.NewReader(compressedBytes)
|
||||
err = ZstdDecompress(reader, errWriter)
|
||||
assert.ErrorIs(t, err, errWriter.Err)
|
||||
|
||||
// Incorrect option
|
||||
err = ZstdCompress(strings.NewReader(data), compressed, zstd.WithWindowSize(3))
|
||||
assert.Error(t, err)
|
||||
|
||||
err = ZstdDecompress(compressed, origin, zstd.WithDecoderConcurrency(0))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCurrencyGlobalMethods(t *testing.T) {
|
||||
prefix := "Test Currency Global Methods"
|
||||
|
||||
currency := runtime.GOMAXPROCS(0) * 2
|
||||
if currency < 6 {
|
||||
currency = 6
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(currency)
|
||||
for i := 0; i < currency; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
origin := new(bytes.Buffer)
|
||||
|
||||
data := prefix + fmt.Sprintf(": %d-th goroutine", idx)
|
||||
|
||||
err := ZstdCompress(strings.NewReader(data), buf, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(idx)))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ZstdDecompress(buf, origin)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, data, origin.String())
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
17
internal/util/mock/io_fake.go
Normal file
17
internal/util/mock/io_fake.go
Normal file
@ -0,0 +1,17 @@
|
||||
package mock
|
||||
|
||||
// ErrReader is a reader for tests whose Read never produces data and always
// reports the configured error.
type ErrReader struct {
	// Err is returned as-is by every Read call. If it is nil, Read returns
	// (0, nil).
	Err error
}

// Read leaves p untouched, reports zero bytes read and returns r.Err.
func (r *ErrReader) Read(p []byte) (n int, err error) {
	err = r.Err
	return n, err
}
|
||||
|
||||
// ErrWriter is a writer for tests whose Write never accepts data and always
// reports the configured error.
type ErrWriter struct {
	// Err is returned as-is by every Write call. If it is nil, Write returns
	// (0, nil).
	Err error
}

// Write ignores p, reports zero bytes written and returns w.Err.
func (w *ErrWriter) Write(p []byte) (n int, err error) {
	err = w.Err
	return n, err
}
|
Loading…
Reference in New Issue
Block a user