| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- package validator
- import (
- "context"
- "github.com/go-kratos/kratos/v2/middleware"
- "google.golang.org/protobuf/proto"
- "git.ikuban.com/server/kratos-utils/v2/validator"
- )
- // Validator 是验证器中间件选项
- type Option func(*options)
- type options struct {
- validator validator.Validator
- groups []string
- }
- // WithValidator 设置验证器
- func WithValidator(v validator.Validator) Option {
- return func(o *options) {
- o.validator = v
- }
- }
- // WithGroups 设置验证组
- func WithGroups(groups ...string) Option {
- return func(o *options) {
- o.groups = groups
- }
- }
- // Server 是服务端验证中间件
- func Server(opts ...Option) middleware.Middleware {
- o := &options{}
- for _, opt := range opts {
- opt(o)
- }
- // 如果没有指定验证器,创建默认验证器
- if o.validator == nil {
- v, err := validator.New()
- if err != nil {
- panic(err)
- }
- o.validator = v
- }
- return func(handler middleware.Handler) middleware.Handler {
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- // 检查请求是否为 proto.Message
- if msg, ok := req.(proto.Message); ok {
- // 如果设置了验证组,使用验证组验证
- if len(o.groups) > 0 {
- ctx = validator.WithGroups(ctx, o.groups...)
- }
- // 执行验证
- if err := o.validator.Validate(ctx, msg); err != nil {
- return nil, err
- }
- }
- // 继续处理请求
- return handler(ctx, req)
- }
- }
- }
- // Client 是客户端验证中间件
- func Client(opts ...Option) middleware.Middleware {
- o := &options{}
- for _, opt := range opts {
- opt(o)
- }
- // 如果没有指定验证器,创建默认验证器
- if o.validator == nil {
- v, err := validator.New()
- if err != nil {
- panic(err)
- }
- o.validator = v
- }
- return func(handler middleware.Handler) middleware.Handler {
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- // 检查请求是否为 proto.Message
- if msg, ok := req.(proto.Message); ok {
- // 如果设置了验证组,使用验证组验证
- if len(o.groups) > 0 {
- ctx = validator.WithGroups(ctx, o.groups...)
- }
- // 执行验证
- if err := o.validator.Validate(ctx, msg); err != nil {
- return nil, err
- }
- }
- // 继续处理请求
- return handler(ctx, req)
- }
- }
- }
|