milvus/pkg/util/logutil/grpc_interceptor.go
madogar 919df4cd02
enhance: changes to propagate traceid from client (#32264)
https://github.com/milvus-io/milvus/issues/32321

Issue Description:
Tracing is an important means of identifying bottleneck points in a
system and is crucial for debugging production issues. Milvus (or any DB)
is generally the most downstream system for a user call -- a user call
can originate from a UI and pass through multiple components, in a
microservices architecture, before reaching Milvus.
So, when a user experiences a glitch, one would debug the call trace
via logs using a common trace id. As of now, Milvus generates a new
trace id for every call, and this request is to make sure the client can
pass the trace id that will be used for all the logs across the Milvus
sub-components, so that one can fetch logs for a user call across all
the components -- including Milvus.

Signed-off-by: Shreesha Srinath Madogaran <smadogaran@salesforce.com>
Co-authored-by: Shreesha Srinath Madogaran <smadogaran@salesforce.com>
2024-04-17 01:13:20 +08:00

81 lines
2.4 KiB
Go

package logutil
import (
"context"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/pkg/log"
)
// gRPC metadata keys read from incoming requests and re-appended to the
// outgoing context by the trace/log-level interceptors in this file.
const (
	// clientRequestIDKey carries a caller-supplied request id; when present it
	// is attached to the context logger as its trace id.
	clientRequestIDKey = "client_request_id"
	// logLevelRPCMetaKey carries a log level name, parsed with
	// zapcore.Level.UnmarshalText.
	logLevelRPCMetaKey = "log_level"
)
// UnaryTraceLoggerInterceptor is a unary server interceptor (its signature
// matches grpc.UnaryServerInterceptor) that invokes handler with a context
// derived by withLevelAndTrace, i.e. one whose logger honours the requested
// log level and carries a trace id.
func UnaryTraceLoggerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	return handler(withLevelAndTrace(ctx), req)
}
// StreamTraceLoggerInterceptor is the streaming counterpart of
// UnaryTraceLoggerInterceptor. It wraps ss so that handler observes, through
// the wrapped stream's Context(), a context derived by withLevelAndTrace from
// the original stream context.
func StreamTraceLoggerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	wrapped := grpc_middleware.WrapServerStream(ss)
	wrapped.WrappedContext = withLevelAndTrace(ss.Context())
	return handler(srv, wrapped)
}
func withLevelAndTrace(ctx context.Context) context.Context {
newctx := ctx
var traceID trace.TraceID
if md, ok := metadata.FromIncomingContext(ctx); ok {
levels := md.Get(logLevelRPCMetaKey)
// get log level
if len(levels) >= 1 {
level := zapcore.DebugLevel
if err := level.UnmarshalText([]byte(levels[0])); err != nil {
newctx = ctx
} else {
switch level {
case zapcore.DebugLevel:
newctx = log.WithDebugLevel(ctx)
case zapcore.InfoLevel:
newctx = log.WithInfoLevel(ctx)
case zapcore.WarnLevel:
newctx = log.WithWarnLevel(ctx)
case zapcore.ErrorLevel:
newctx = log.WithErrorLevel(ctx)
case zapcore.FatalLevel:
newctx = log.WithFatalLevel(ctx)
default:
newctx = ctx
}
}
// inject log level to outgoing meta
newctx = metadata.AppendToOutgoingContext(newctx, logLevelRPCMetaKey, level.String())
}
// client request id
requestID := md.Get(clientRequestIDKey)
if len(requestID) >= 1 {
// inject traceid in order to pass client request id
newctx = metadata.AppendToOutgoingContext(newctx, clientRequestIDKey, requestID[0])
// inject traceid from client for info/debug/warn/error logs
newctx = log.WithTraceID(newctx, requestID[0])
}
}
if !traceID.IsValid() {
traceID = trace.SpanContextFromContext(newctx).TraceID()
}
if traceID.IsValid() {
newctx = log.WithTraceID(newctx, traceID.String())
}
return newctx
}