diff --git a/g/crypto/gcrc32/gcrc32.go b/g/crypto/gcrc32/gcrc32.go index 8ee9f1a99..60ea5f7ff 100644 --- a/g/crypto/gcrc32/gcrc32.go +++ b/g/crypto/gcrc32/gcrc32.go @@ -4,17 +4,25 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package gcrc32 provides useful API for CRC32 encryption/decryption algorithms. +// Package gcrc32 provides API for CRC32 encryption/decryption algorithm. package gcrc32 import ( "hash/crc32" ) +// Encrypt encrypts bytes using CRC32 algorithm. +func Encrypt(v []byte) uint32 { + return crc32.ChecksumIEEE(v) +} + +// EncryptString encrypts string using CRC32 algorithm. func EncryptString(v string) uint32 { return crc32.ChecksumIEEE([]byte(v)) } +// Alias of Encrypt. +// Deprecated. func EncryptBytes(v []byte) uint32 { - return crc32.ChecksumIEEE(v) + return Encrypt(v) } diff --git a/g/net/gtcp/gtcp_conn.go b/g/net/gtcp/gtcp_conn.go index a608c848c..e77b4554d 100644 --- a/g/net/gtcp/gtcp_conn.go +++ b/g/net/gtcp/gtcp_conn.go @@ -87,7 +87,7 @@ func (c *Conn) Send(data []byte, retry...Retry) error { // 获取数据,指定读取的数据长度(length < 1表示获取所有可读数据),以及重试策略(retry) // 需要注意: // 1、往往在socket通信中需要指定固定的数据结构,并在设定对应的长度字段,并在读取数据时便于区分包大小; -// 2、当length < 1时表示获取缓冲区所有的数据,但是可能会引起包解析问题(可能出现非完整的包情况),因此需要解析端注意解析策略; +// 2、当length < 1时表示获取缓冲区所有的数据,但是可能会引起包解析问题(可能出现断包情况),因此需要解析端注意解析策略; func (c *Conn) Recv(length int, retry...Retry) ([]byte, error) { var err error // 读取错误 var size int // 读取长度 diff --git a/g/net/gtcp/gtcp_conn_pkg.go b/g/net/gtcp/gtcp_conn_pkg.go index cd0a8ee53..2d88d2d1f 100644 --- a/g/net/gtcp/gtcp_conn_pkg.go +++ b/g/net/gtcp/gtcp_conn_pkg.go @@ -10,84 +10,130 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gogf/gf/g/crypto/gcrc32" "time" ) const ( - // 允许最大的简单协议包大小(byte), 15MB - PKG_MAX_SIZE = 0xFFFFFF - // 消息包头大小: "总长度"3字节+"校验码"4字节 - PKG_HEADER_SIZE = 7 + // 默认允许最大的简单协议包大小(byte), 65535 byte + gPKG_MAX_SIZE = 65535 ) -// 根据简单协议发送数据包。 -// 简单协议数据格式:总长度(24bit)|校验码(32bit)|数据(变长)。 -// 注意: -// 1. "总长度"包含自身3字节及"校验码"4字节。 -// 2. 由于"总长度"为3字节,并且使用的BigEndian字节序,因此最后返回的buffer使用了buffer[1:]。 -func (c *Conn) SendPkg(data []byte, retry...Retry) error { - length := uint32(len(data)) - if length > PKG_MAX_SIZE - PKG_HEADER_SIZE { - return errors.New(fmt.Sprintf(`data size %d exceeds max pkg size %d`, length, PKG_MAX_SIZE - PKG_HEADER_SIZE)) +// 数据读取选项 +type Option struct { + MaxSize int // (byte)数据读取的最大包大小,最大不能超过3字节(0xFFFFFF,15MB),默认为65535byte + Secret []byte // (可选)安全通信密钥 + Retry Retry // 失败重试 +} + +// getPkgOption wraps and returns the option. +// If no option given, it returns a new option with default value. +func getPkgOption(option...Option) (*Option, error) { + pkgOption := Option{} + if len(option) > 0 { + pkgOption = option[0] } - buffer := make([]byte, PKG_HEADER_SIZE + 1 + len(data)) - copy(buffer[PKG_HEADER_SIZE + 1 : ], data) - binary.BigEndian.PutUint32(buffer[0 : ], PKG_HEADER_SIZE + length) - binary.BigEndian.PutUint32(buffer[4 : ], Checksum(data)) - //fmt.Println("SendPkg:", buffer[1:]) - return c.Send(buffer[1:], retry...) + if pkgOption.MaxSize == 0 { + pkgOption.MaxSize = gPKG_MAX_SIZE + } else if pkgOption.MaxSize > 0xFFFFFF { + return nil, fmt.Errorf(`package size %d exceeds allowed max size %d`, pkgOption.MaxSize, 0xFFFFFF) + } + return &pkgOption, nil +} + +// 根据简单协议发送数据包。 +// 简单协议数据格式:总长度(24bit)|校验位(32bit,可选)|数据(变长)。 +// 注意: +// 1. "总长度"包含自身3字节及"校验位"4字节(可选)。 +// 2. 当Secret有提供时,"校验位"才会存在,否则该字段为空。 +// 3. "校验位"提供简单的数据完整性及防篡改校验,默认没有开启。 +// 4. 由于"总长度"为3字节,并且使用的BigEndian字节序,因此这里最后返回的buffer使用了buffer[1:]。 +func (c *Conn) SendPkg(data []byte, option...Option) error { + pkgOption, err := getPkgOption(option...) + if err != nil { + return err + } + headerSize := 3 + if len(pkgOption.Secret) > 0 { + headerSize = 7 + } + length := len(data) + if length > pkgOption.MaxSize - headerSize { + return errors.New(fmt.Sprintf(`data size %d exceeds max pkg size %d`, length, gPKG_MAX_SIZE - headerSize)) + } + + buffer := make([]byte, headerSize + 1 + len(data)) + copy(buffer[headerSize + 1 : ], data) + binary.BigEndian.PutUint32(buffer[0 : ], uint32(headerSize + length)) + if len(pkgOption.Secret) > 0 { + binary.BigEndian.PutUint32(buffer[4 : ], gcrc32.Encrypt(append(data, pkgOption.Secret...))) + } + if pkgOption.Retry.Count > 0 { + c.Send(buffer[1:], pkgOption.Retry) + } + return c.Send(buffer[1:]) } // 简单协议: 带超时时间的数据发送 -func (c *Conn) SendPkgWithTimeout(data []byte, timeout time.Duration, retry...Retry) error { +func (c *Conn) SendPkgWithTimeout(data []byte, timeout time.Duration, option...Option) error { c.SetSendDeadline(time.Now().Add(timeout)) defer c.SetSendDeadline(time.Time{}) - return c.SendPkg(data, retry...) + return c.SendPkg(data, option...) } // 简单协议: 发送数据并等待接收返回数据 -func (c *Conn) SendRecvPkg(data []byte, retry...Retry) ([]byte, error) { - if err := c.SendPkg(data, retry...); err == nil { - return c.RecvPkg(retry...) +func (c *Conn) SendRecvPkg(data []byte, option...Option) ([]byte, error) { + if err := c.SendPkg(data, option...); err == nil { + return c.RecvPkg(option...) } else { return nil, err } } // 简单协议: 发送数据并等待接收返回数据(带返回超时等待时间) -func (c *Conn) SendRecvPkgWithTimeout(data []byte, timeout time.Duration, retry...Retry) ([]byte, error) { - if err := c.SendPkg(data, retry...); err == nil { - return c.RecvPkgWithTimeout(timeout, retry...) +func (c *Conn) SendRecvPkgWithTimeout(data []byte, timeout time.Duration, option...Option) ([]byte, error) { + if err := c.SendPkg(data, option...); err == nil { + return c.RecvPkgWithTimeout(timeout, option...) } else { return nil, err } } // 简单协议: 获取一个数据包。 -func (c *Conn) RecvPkg(retry...Retry) (result []byte, err error) { +func (c *Conn) RecvPkg(option...Option) (result []byte, err error) { var temp []byte - var length uint32 + var length int + pkgOption, err := getPkgOption(option...) + if err != nil { + return nil, err + } + headerSize := 3 + if len(pkgOption.Secret) > 0 { + headerSize = 7 + } for { // 先根据对象的缓冲区数据进行计算 for { - if len(c.buffer) >= PKG_HEADER_SIZE { + if len(c.buffer) >= headerSize { // 注意"总长度"为3个字节,不满足4个字节的uint32类型,因此这里"低位"补0 - length = binary.BigEndian.Uint32([]byte{0, c.buffer[0], c.buffer[1], c.buffer[2]}) - // 解析的大小是否符合规范 - if length == 0 || length + PKG_HEADER_SIZE > PKG_MAX_SIZE { - c.buffer = c.buffer[1:] - continue + length = int(binary.BigEndian.Uint32([]byte{0, c.buffer[0], c.buffer[1], c.buffer[2]})) + // 解析的大小是否符合规范,清空从该连接接收到的所有数据包 + if length <= 0 || length + headerSize > pkgOption.MaxSize { + c.buffer = c.buffer[:0] + return nil, fmt.Errorf(`invalid package size %d`, length) } // 不满足包大小,需要继续读取 - if uint32(len(c.buffer)) < length { + if len(c.buffer) < length { break } - // 数据校验 - if binary.BigEndian.Uint32(c.buffer[3 : PKG_HEADER_SIZE]) != Checksum(c.buffer[PKG_HEADER_SIZE : length]) { - c.buffer = c.buffer[1:] - continue + // 数据校验,如果失败,丢弃该数据包 + receivedCrc32 := binary.BigEndian.Uint32(c.buffer[3 : headerSize]) + calculatedCrc32 := gcrc32.Encrypt(c.buffer[headerSize : length]) + if receivedCrc32 != calculatedCrc32 { + c.buffer = c.buffer[length: ] + return nil, fmt.Errorf(`data CRC32 validates failed, received %d, caculated %d`, receivedCrc32, calculatedCrc32) } - result = c.buffer[PKG_HEADER_SIZE : length] + result = c.buffer[headerSize : length] c.buffer = c.buffer[length: ] return } else { @@ -95,7 +141,7 @@ func (c *Conn) RecvPkg(retry...Retry) (result []byte, err error) { } } // 读取系统socket缓冲区的完整数据 - temp, err = c.Recv(-1, retry...) + temp, err = c.Recv(-1, option...) if err != nil { break } @@ -108,7 +154,7 @@ func (c *Conn) RecvPkg(retry...Retry) (result []byte, err error) { } // 简单协议: 带超时时间的消息包获取 -func (c *Conn) RecvPkgWithTimeout(timeout time.Duration, retry...Retry) ([]byte, error) { +func (c *Conn) RecvPkgWithTimeout(timeout time.Duration, option...Option) ([]byte, error) { c.SetRecvDeadline(time.Now().Add(timeout)) defer c.SetRecvDeadline(time.Time{}) return c.RecvPkg(retry...) diff --git a/g/net/gtcp/gtcp_func.go b/g/net/gtcp/gtcp_func.go index e4aacd6ef..cbc82106a 100644 --- a/g/net/gtcp/gtcp_func.go +++ b/g/net/gtcp/gtcp_func.go @@ -12,8 +12,8 @@ import ( ) const ( - gDEFAULT_RETRY_INTERVAL = 100 // (毫秒)默认重试时间间隔 - gDEFAULT_READ_BUFFER_SIZE = 1024 // 默认数据读取缓冲区大小 + gDEFAULT_RETRY_INTERVAL = 100 // (毫秒)默认重试时间间隔 + gDEFAULT_READ_BUFFER_SIZE = 128 // (byte)默认数据读取缓冲区大小 ) type Retry struct { diff --git a/g/net/gtcp/gtcp_server.go b/g/net/gtcp/gtcp_server.go index 7962d9240..c6e3df5be 100644 --- a/g/net/gtcp/gtcp_server.go +++ b/g/net/gtcp/gtcp_server.go @@ -8,48 +8,62 @@ package gtcp import ( - "errors" - "github.com/gogf/gf/g/os/glog" - "net" - "github.com/gogf/gf/g/container/gmap" - "github.com/gogf/gf/g/util/gconv" + "crypto/rand" + "crypto/tls" + "errors" + "github.com/gogf/gf/g/container/gmap" + "github.com/gogf/gf/g/os/glog" + "github.com/gogf/gf/g/util/gconv" + "net" + "time" ) const ( gDEFAULT_SERVER = "default" ) -// tcp server结构体 +// TCP Server. type Server struct { address string handler func (*Conn) + tlsConfig *tls.Config } -// Server表,用以存储和检索名称与Server对象之间的关联关系 +// Map for name to server, for singleton purpose. var serverMapping = gmap.NewStrAnyMap() -// 获取/创建一个空配置的TCP Server -// 单例模式,请保证name的唯一性 +// GetServer returns the TCP server with specified , +// or it returns a new normal TCP server named if it does not exist. +// The parameter is used to specify the TCP server func GetServer(name...interface{}) *Server { serverName := gDEFAULT_SERVER if len(name) > 0 { serverName = gconv.String(name[0]) } - if s := serverMapping.Get(serverName); s != nil { - return s.(*Server) + return serverMapping.GetOrSetFuncLock(serverName, func() interface{} { + return NewServer("", nil) + }).(*Server) +} + +// NewServer creates and returns a new normal TCP server. +// The param is optional, which is used to specify the instance name of the server. +func NewServer(address string, handler func (*Conn), name...string) *Server { + s := &Server{ + address : address, + handler : handler, + } + if len(name) > 0 { + serverMapping.Set(name[0], s) } - s := NewServer("", nil) - serverMapping.Set(serverName, s) return s } -// 创建一个tcp server对象,并且可以选择指定一个单例名字 -func NewServer(address string, handler func (*Conn), names...string) *Server { - s := &Server{address, handler} - if len(names) > 0 { - serverMapping.Set(names[0], s) - } - return s +// NewTlsServer creates and returns a new TCP server with TLS support. +// The param is optional, which is used to specify the instance name of the server. +func NewTLSServer(address, crtFile, keyFile string, handler func (*Conn), name...string) *Server { + s := NewServer(address, handler, name...) + s.SetTLSKeyCrt(crtFile, keyFile) + return s } // 设置参数 - address @@ -62,24 +76,53 @@ func (s *Server) SetHandler (handler func (*Conn)) { s.handler = handler } -// 执行监听 -func (s *Server) Run() error { +// SetTlsKeyCrt sets the certificate and key file for TLS configuration of server. +func (s *Server) SetTLSKeyCrt (crtFile, keyFile string) error { + crt, err := tls.LoadX509KeyPair(crtFile,keyFile) + if err != nil { + return err + } + s.tlsConfig = &tls.Config{} + s.tlsConfig.Certificates = []tls.Certificate{crt} + s.tlsConfig.Time = time.Now + s.tlsConfig.Rand = rand.Reader + return nil +} + +// SetTlsConfig sets the TLS configuration of server. +func (s *Server) SetTLSConfig(tlsConfig *tls.Config) { + s.tlsConfig = tlsConfig +} + +// Run starts running the TCP Server. +func (s *Server) Run() (err error) { if s.handler == nil { - err := errors.New("start running failed: socket handler not defined") + err = errors.New("start running failed: socket handler not defined") glog.Error(err) - return err + return } - addr, err := net.ResolveTCPAddr("tcp", s.address) - if err != nil { - glog.Error(err) - return err + listen := net.Listener(nil) + if s.tlsConfig != nil { + // TLS Server + listen, err = tls.Listen("tcp", s.address, s.tlsConfig) + if err != nil { + glog.Error(err) + return + } + } else { + // Normal Server + addr, err := net.ResolveTCPAddr("tcp", s.address) + if err != nil { + glog.Error(err) + return err + } + listen, err = net.ListenTCP("tcp", addr) + if err != nil { + glog.Error(err) + return err + } } - listen, err := net.ListenTCP("tcp", addr) - if err != nil { - glog.Error(err) - return err - } - for { + for { if conn, err := listen.Accept(); err != nil { glog.Error(err) return err