package validator import ( "context" "github.com/go-kratos/kratos/v2/middleware" "google.golang.org/protobuf/proto" "git.ikuban.com/server/kratos-utils/v2/validator" ) // Validator 是验证器中间件选项 type Option func(*options) type options struct { validator validator.Validator groups []string } // WithValidator 设置验证器 func WithValidator(v validator.Validator) Option { return func(o *options) { o.validator = v } } // WithGroups 设置验证组 func WithGroups(groups ...string) Option { return func(o *options) { o.groups = groups } } // Server 是服务端验证中间件 func Server(opts ...Option) middleware.Middleware { o := &options{} for _, opt := range opts { opt(o) } // 如果没有指定验证器,创建默认验证器 if o.validator == nil { v, err := validator.New() if err != nil { panic(err) } o.validator = v } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { // 检查请求是否为 proto.Message if msg, ok := req.(proto.Message); ok { // 如果设置了验证组,使用验证组验证 if len(o.groups) > 0 { ctx = validator.WithGroups(ctx, o.groups...) } // 执行验证 if err := o.validator.Validate(ctx, msg); err != nil { return nil, err } } // 继续处理请求 return handler(ctx, req) } } } // Client 是客户端验证中间件 func Client(opts ...Option) middleware.Middleware { o := &options{} for _, opt := range opts { opt(o) } // 如果没有指定验证器,创建默认验证器 if o.validator == nil { v, err := validator.New() if err != nil { panic(err) } o.validator = v } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { // 检查请求是否为 proto.Message if msg, ok := req.(proto.Message); ok { // 如果设置了验证组,使用验证组验证 if len(o.groups) > 0 { ctx = validator.WithGroups(ctx, o.groups...) } // 执行验证 if err := o.validator.Validate(ctx, msg); err != nil { return nil, err } } // 继续处理请求 return handler(ctx, req) } } }