milvus/internal/proxy/database_interceptor.go
jaime 18df2ba6fd
[Cherry-Pick] Support Database (#24769)
Support Database(#23742)
Fix db nonexists error for FlushAll (#24222)
Fix check collection limits fails (#24235)
backward compatibility with empty DB name (#24317)
Fix GetFlushAllState with DB (#24347)
Remove db from global meta cache after drop database (#24474)
Fix db name is empty for describe collection response (#24603)
Add RBAC for Database API (#24653)
Fix miss load the same name collection during recover stage (#24941)

RBAC supports Database validation (#23609)
Fix to list grant with db return empty (#23922)
Optimize PrivilegeAll permission check (#23972)
Add the default db value for the rbac request (#24307)

Signed-off-by: jaime <yun.zhang@zilliz.com>
Co-authored-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: longjiquan <jiquan.long@zilliz.com>
2023-06-25 17:20:43 +08:00

245 lines
6.4 KiB
Go

package proxy
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
// DatabaseInterceptor fill dbname into request based on kv pair <"dbname": "xx"> in header
func DatabaseInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
filledCtx, filledReq := fillDatabase(ctx, req)
return handler(filledCtx, filledReq)
}
}
func fillDatabase(ctx context.Context, req interface{}) (context.Context, interface{}) {
switch r := req.(type) {
case *milvuspb.CreateCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.HasCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ReleaseCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DescribeCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetCollectionStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ShowCollectionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.AlterCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreatePartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropPartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.HasPartitionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadPartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ReleasePartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetPartitionStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ShowPartitionsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetLoadingProgressRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetLoadStateRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreateIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DescribeIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropIndexRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexBuildProgressRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexStateRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.InsertRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DeleteRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.SearchRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.FlushRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.QueryRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.CreateAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.DropAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.AlterAliasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ImportRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.ListImportTasksRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.RenameCollectionRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.TransferReplicaRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetPersistentSegmentInfoRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetQuerySegmentInfoRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.LoadBalanceRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetReplicasRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.OperatePrivilegeRequest:
if r.Entity != nil && r.Entity.DbName == "" {
r.Entity.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.SelectGrantRequest:
if r.Entity != nil && r.Entity.DbName == "" {
r.Entity.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.GetIndexStatisticsRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
case *milvuspb.UpsertRequest:
if r.DbName == "" {
r.DbName = GetCurDBNameFromContextOrDefault(ctx)
}
return ctx, r
default:
return ctx, req
}
}