middleware.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package validator
  2. import (
  3. "context"
  4. "github.com/go-kratos/kratos/v2/middleware"
  5. "google.golang.org/protobuf/proto"
  6. "git.ikuban.com/server/kratos-utils/v2/validator"
  7. )
  8. // Validator 是验证器中间件选项
  9. type Option func(*options)
  10. type options struct {
  11. validator validator.Validator
  12. groups []string
  13. }
  14. // WithValidator 设置验证器
  15. func WithValidator(v validator.Validator) Option {
  16. return func(o *options) {
  17. o.validator = v
  18. }
  19. }
  20. // WithGroups 设置验证组
  21. func WithGroups(groups ...string) Option {
  22. return func(o *options) {
  23. o.groups = groups
  24. }
  25. }
  26. // Server 是服务端验证中间件
  27. func Server(opts ...Option) middleware.Middleware {
  28. o := &options{}
  29. for _, opt := range opts {
  30. opt(o)
  31. }
  32. // 如果没有指定验证器,创建默认验证器
  33. if o.validator == nil {
  34. v, err := validator.New()
  35. if err != nil {
  36. panic(err)
  37. }
  38. o.validator = v
  39. }
  40. return func(handler middleware.Handler) middleware.Handler {
  41. return func(ctx context.Context, req interface{}) (interface{}, error) {
  42. // 检查请求是否为 proto.Message
  43. if msg, ok := req.(proto.Message); ok {
  44. // 如果设置了验证组,使用验证组验证
  45. if len(o.groups) > 0 {
  46. ctx = validator.WithGroups(ctx, o.groups...)
  47. }
  48. // 执行验证
  49. if err := o.validator.Validate(ctx, msg); err != nil {
  50. return nil, err
  51. }
  52. }
  53. // 继续处理请求
  54. return handler(ctx, req)
  55. }
  56. }
  57. }
  58. // Client 是客户端验证中间件
  59. func Client(opts ...Option) middleware.Middleware {
  60. o := &options{}
  61. for _, opt := range opts {
  62. opt(o)
  63. }
  64. // 如果没有指定验证器,创建默认验证器
  65. if o.validator == nil {
  66. v, err := validator.New()
  67. if err != nil {
  68. panic(err)
  69. }
  70. o.validator = v
  71. }
  72. return func(handler middleware.Handler) middleware.Handler {
  73. return func(ctx context.Context, req interface{}) (interface{}, error) {
  74. // 检查请求是否为 proto.Message
  75. if msg, ok := req.(proto.Message); ok {
  76. // 如果设置了验证组,使用验证组验证
  77. if len(o.groups) > 0 {
  78. ctx = validator.WithGroups(ctx, o.groups...)
  79. }
  80. // 执行验证
  81. if err := o.validator.Validate(ctx, msg); err != nil {
  82. return nil, err
  83. }
  84. }
  85. // 继续处理请求
  86. return handler(ctx, req)
  87. }
  88. }
  89. }