| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490 |
- package validator
- import (
- "fmt"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- validate "git.ikuban.com/server/kubanapis/kuban/api/validate"
- )
- // Engine 验证引擎
- type Engine struct {
- celExecutor *CELExecutor
- }
- // NewEngine 创建验证引擎
- func NewEngine(celExecutor *CELExecutor) *Engine {
- return &Engine{
- celExecutor: celExecutor,
- }
- }
- // Validate 验证消息
- func (e *Engine) Validate(vctx *ValidationContext, msg proto.Message) error {
- if msg == nil {
- return nil
- }
- errors := &ValidationErrors{}
- // 获取消息的反射对象
- msgReflect := msg.ProtoReflect()
- msgDescriptor := msgReflect.Descriptor()
- // 1. 验证消息级别规则
- if err := e.validateMessageRules(vctx, msgReflect, msgDescriptor, errors); err != nil {
- return err
- }
- // 2. 验证字段
- if err := e.validateFields(vctx, msgReflect, msgDescriptor, errors); err != nil {
- return err
- }
- // 3. 验证 oneof
- if err := e.validateOneofs(vctx, msgReflect, msgDescriptor, errors); err != nil {
- return err
- }
- // 返回错误(如果有)
- if errors.HasErrors() {
- return errors.ToKratosError()
- }
- return nil
- }
- // validateMessageRules 验证消息级别规则
- func (e *Engine) validateMessageRules(
- vctx *ValidationContext,
- msgReflect protoreflect.Message,
- msgDescriptor protoreflect.MessageDescriptor,
- errors *ValidationErrors,
- ) error {
- // 获取消息选项
- opts := msgDescriptor.Options()
- if opts == nil {
- return nil
- }
- // 检查是否有验证规则扩展
- if !proto.HasExtension(opts, validate.E_Message) {
- return nil
- }
- // 获取消息验证规则
- msgRules := proto.GetExtension(opts, validate.E_Message).(*validate.MessageRules)
- if msgRules == nil {
- return nil
- }
- // 检查是否被禁用
- if msgRules.Disabled {
- return nil
- }
- // 检查验证组
- if !vctx.HasAnyGroup(msgRules.Groups) {
- return nil
- }
- // 执行 CEL 规则
- for _, celRule := range msgRules.Cel {
- if err := e.executeCELRule(vctx, celRule, msgReflect.Interface(), "", errors); err != nil {
- if vctx.ShouldFailFast() {
- return err
- }
- }
- }
- return nil
- }
- // validateFields 验证所有字段
- func (e *Engine) validateFields(
- vctx *ValidationContext,
- msgReflect protoreflect.Message,
- msgDescriptor protoreflect.MessageDescriptor,
- errors *ValidationErrors,
- ) error {
- fields := msgDescriptor.Fields()
- for i := 0; i < fields.Len(); i++ {
- field := fields.Get(i)
- fieldPath := string(field.Name())
- // 检查是否达到最大错误数
- if vctx.MaxErrorsReached(errors.Count()) {
- break
- }
- // 验证字段
- if err := e.validateField(vctx, msgReflect, field, fieldPath, errors); err != nil {
- if vctx.ShouldFailFast() {
- return err
- }
- }
- }
- return nil
- }
- // validateField 验证单个字段
- func (e *Engine) validateField(
- vctx *ValidationContext,
- msgReflect protoreflect.Message,
- field protoreflect.FieldDescriptor,
- fieldPath string,
- errors *ValidationErrors,
- ) error {
- // 获取字段选项
- opts := field.Options()
- if opts == nil {
- return nil
- }
- // 检查是否有验证规则
- if !proto.HasExtension(opts, validate.E_Field) {
- return nil
- }
- // 获取字段验证规则
- fieldRules := proto.GetExtension(opts, validate.E_Field).(*validate.FieldRules)
- if fieldRules == nil {
- return nil
- }
- // 检查验证组
- if !vctx.HasAnyGroup(fieldRules.Groups) {
- return nil
- }
- // 获取字段值
- fieldValue := msgReflect.Get(field)
- hasValue := msgReflect.Has(field)
- // 检查忽略策略
- if e.shouldIgnoreField(fieldRules, hasValue, fieldValue, field) {
- return nil
- }
- // 验证 required
- if fieldRules.Required {
- if !hasValue {
- defaultMsg := fmt.Sprintf("字段 %s 是必填的", fieldPath)
- message := e.getErrorMessage(fieldRules, vctx, "required", defaultMsg)
- errors.Add(NewValidationError(
- fieldPath,
- ErrCodeRequired,
- message,
- ))
- if vctx.ShouldFailFast() {
- return errors
- }
- }
- }
- // 如果字段未设置且不是 required,跳过后续验证
- if !hasValue && !fieldRules.Required {
- return nil
- }
- // 执行 CEL 规则
- for _, celRule := range fieldRules.Cel {
- if err := e.executeCELFieldRule(
- vctx,
- celRule,
- fieldValue.Interface(),
- msgReflect.Interface(),
- fieldPath,
- errors,
- ); err != nil {
- if vctx.ShouldFailFast() {
- return err
- }
- }
- }
- // 根据字段类型执行类型特定验证
- return e.validateFieldType(vctx, fieldRules, fieldValue, field, fieldPath, errors)
- }
- // validateFieldType 根据字段类型执行验证
- func (e *Engine) validateFieldType(
- vctx *ValidationContext,
- fieldRules *validate.FieldRules,
- fieldValue protoreflect.Value,
- field protoreflect.FieldDescriptor,
- fieldPath string,
- errors *ValidationErrors,
- ) error {
- // 根据规则类型分发
- switch rule := fieldRules.Type.(type) {
- case *validate.FieldRules_String_:
- return e.validateString(vctx, fieldRules, rule.String_, fieldValue.String(), fieldPath, errors)
- case *validate.FieldRules_Int32:
- return e.validateInt32(vctx, fieldRules, rule.Int32, int32(fieldValue.Int()), fieldPath, errors)
- case *validate.FieldRules_Int64:
- return e.validateInt64(vctx, fieldRules, rule.Int64, fieldValue.Int(), fieldPath, errors)
- case *validate.FieldRules_Uint32:
- return e.validateUInt32(vctx, fieldRules, rule.Uint32, uint32(fieldValue.Uint()), fieldPath, errors)
- case *validate.FieldRules_Uint64:
- return e.validateUInt64(vctx, fieldRules, rule.Uint64, fieldValue.Uint(), fieldPath, errors)
- case *validate.FieldRules_Sint32:
- return e.validateInt32(vctx, fieldRules, (*validate.Int32Rules)(rule.Sint32), int32(fieldValue.Int()), fieldPath, errors)
- case *validate.FieldRules_Sint64:
- return e.validateInt64(vctx, fieldRules, (*validate.Int64Rules)(rule.Sint64), fieldValue.Int(), fieldPath, errors)
- case *validate.FieldRules_Fixed32:
- return e.validateUInt32(vctx, fieldRules, (*validate.UInt32Rules)(rule.Fixed32), uint32(fieldValue.Uint()), fieldPath, errors)
- case *validate.FieldRules_Fixed64:
- return e.validateUInt64(vctx, fieldRules, (*validate.UInt64Rules)(rule.Fixed64), fieldValue.Uint(), fieldPath, errors)
- case *validate.FieldRules_Sfixed32:
- return e.validateInt32(vctx, fieldRules, (*validate.Int32Rules)(rule.Sfixed32), int32(fieldValue.Int()), fieldPath, errors)
- case *validate.FieldRules_Sfixed64:
- return e.validateInt64(vctx, fieldRules, (*validate.Int64Rules)(rule.Sfixed64), fieldValue.Int(), fieldPath, errors)
- case *validate.FieldRules_Float:
- return e.validateFloat(vctx, fieldRules, rule.Float, float32(fieldValue.Float()), fieldPath, errors)
- case *validate.FieldRules_Double:
- return e.validateDouble(vctx, fieldRules, rule.Double, fieldValue.Float(), fieldPath, errors)
- case *validate.FieldRules_Bool:
- return e.validateBool(vctx, fieldRules, rule.Bool, fieldValue.Bool(), fieldPath, errors)
- case *validate.FieldRules_Bytes:
- return e.validateBytes(vctx, fieldRules, rule.Bytes, fieldValue.Bytes(), fieldPath, errors)
- case *validate.FieldRules_Enum:
- return e.validateEnum(vctx, fieldRules, rule.Enum, fieldValue.Enum(), field, fieldPath, errors)
- case *validate.FieldRules_Repeated:
- return e.validateRepeated(vctx, fieldRules, rule.Repeated, fieldValue.List(), field, fieldPath, errors)
- case *validate.FieldRules_Map:
- return e.validateMap(vctx, fieldRules, rule.Map, fieldValue.Map(), field, fieldPath, errors)
- }
- return nil
- }
- // shouldIgnoreField 检查是否应该忽略字段验证
- func (e *Engine) shouldIgnoreField(
- rules *validate.FieldRules,
- hasValue bool,
- value protoreflect.Value,
- field protoreflect.FieldDescriptor,
- ) bool {
- switch rules.Ignore {
- case validate.IgnoreRule_IGNORE_ALWAYS:
- return true
- case validate.IgnoreRule_IGNORE_IF_ZERO:
- return !hasValue || isZeroValue(value, field)
- case validate.IgnoreRule_IGNORE_NEVER:
- return false
- default:
- // IGNORE_UNSPECIFIED: 未设置时不验证
- return !hasValue
- }
- }
- // isZeroValue 检查是否为零值
- func isZeroValue(value protoreflect.Value, field protoreflect.FieldDescriptor) bool {
- switch field.Kind() {
- case protoreflect.BoolKind:
- return !value.Bool()
- case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
- protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
- return value.Int() == 0
- case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
- protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
- return value.Uint() == 0
- case protoreflect.FloatKind, protoreflect.DoubleKind:
- return value.Float() == 0.0
- case protoreflect.StringKind:
- return value.String() == ""
- case protoreflect.BytesKind:
- return len(value.Bytes()) == 0
- case protoreflect.MessageKind:
- return !value.Message().IsValid()
- default:
- return false
- }
- }
- // validateOneofs 验证所有 oneof
- func (e *Engine) validateOneofs(
- vctx *ValidationContext,
- msgReflect protoreflect.Message,
- msgDescriptor protoreflect.MessageDescriptor,
- errors *ValidationErrors,
- ) error {
- oneofs := msgDescriptor.Oneofs()
- for i := 0; i < oneofs.Len(); i++ {
- oneof := oneofs.Get(i)
- // 获取 oneof 选项
- opts := oneof.Options()
- if opts == nil {
- continue
- }
- // 检查是否有验证规则
- if !proto.HasExtension(opts, validate.E_Oneof) {
- continue
- }
- // 获取 oneof 验证规则
- oneofRules := proto.GetExtension(opts, validate.E_Oneof).(*validate.OneofRules)
- if oneofRules == nil {
- continue
- }
- // 验证 required
- if oneofRules.Required {
- hasField := false
- fields := oneof.Fields()
- for j := 0; j < fields.Len(); j++ {
- if msgReflect.Has(fields.Get(j)) {
- hasField = true
- break
- }
- }
- if !hasField {
- errors.Add(NewValidationError(
- string(oneof.Name()),
- ErrCodeRequired,
- fmt.Sprintf("oneof %s 必须选择一个字段", oneof.Name()),
- ))
- if vctx.ShouldFailFast() {
- return errors
- }
- }
- }
- }
- return nil
- }
- // executeCELRule 执行 CEL 规则
- func (e *Engine) executeCELRule(
- vctx *ValidationContext,
- rule *validate.Rule,
- msg proto.Message,
- fieldPath string,
- errors *ValidationErrors,
- ) error {
- // 执行表达式
- result, err := e.celExecutor.EvaluateWithProto(rule.Expression, msg)
- if err != nil {
- // CEL 执行错误
- errors.Add(&ValidationError{
- FieldPath: fieldPath,
- RuleID: rule.Id,
- Message: fmt.Sprintf("CEL 表达式执行失败: %v", err),
- })
- return nil
- }
- // 如果验证失败
- if !result {
- message := rule.Message
- if message == "" && vctx.Language == "en-US" && rule.MessageEn != "" {
- message = rule.MessageEn
- }
- if message == "" {
- message = fmt.Sprintf("验证规则 %s 失败", rule.Id)
- }
- errors.Add(&ValidationError{
- FieldPath: fieldPath,
- RuleID: rule.Id,
- Message: message,
- MessageEN: rule.MessageEn,
- })
- }
- return nil
- }
- // executeCELFieldRule 执行字段级 CEL 规则
- func (e *Engine) executeCELFieldRule(
- vctx *ValidationContext,
- rule *validate.Rule,
- fieldValue interface{},
- parent proto.Message,
- fieldPath string,
- errors *ValidationErrors,
- ) error {
- // 执行表达式
- result, err := e.celExecutor.EvaluateFieldRule(rule.Expression, fieldValue, parent)
- if err != nil {
- // CEL 执行错误
- errors.Add(&ValidationError{
- FieldPath: fieldPath,
- RuleID: rule.Id,
- Message: fmt.Sprintf("CEL 表达式执行失败: %v", err),
- })
- return nil
- }
- // 如果验证失败
- if !result {
- message := rule.Message
- if message == "" && vctx.Language == "en-US" && rule.MessageEn != "" {
- message = rule.MessageEn
- }
- if message == "" {
- message = fmt.Sprintf("验证规则 %s 失败", rule.Id)
- }
- errors.Add(&ValidationError{
- FieldPath: fieldPath,
- RuleID: rule.Id,
- Message: message,
- MessageEN: rule.MessageEn,
- Value: fieldValue,
- })
- }
- return nil
- }
- // getErrorMessage 获取错误消息,按优先级返回
- // 优先级: 特定消息(required_message) > 通用消息(error_message) > 默认消息
- func (e *Engine) getErrorMessage(
- fieldRules *validate.FieldRules,
- vctx *ValidationContext,
- ruleType string,
- defaultMessage string,
- ) string {
- // 1. 检查是否有特定规则的自定义消息 (如 required_message)
- if ruleType == "required" {
- if vctx.Language == "en-US" && fieldRules.RequiredMessageEn != "" {
- return fieldRules.RequiredMessageEn
- }
- if fieldRules.RequiredMessage != "" {
- return fieldRules.RequiredMessage
- }
- }
- // 2. 检查通用自定义消息 (error_message)
- if vctx.Language == "en-US" && fieldRules.ErrorMessageEn != "" {
- return fieldRules.ErrorMessageEn
- }
- if fieldRules.ErrorMessage != "" {
- return fieldRules.ErrorMessage
- }
- // 3. 返回默认消息
- return defaultMessage
- }
- // ValidateField 验证单个字段(外部接口)
- func (e *Engine) ValidateField(vctx *ValidationContext, fieldPath string, value interface{}, rules interface{}) error {
- errors := &ValidationErrors{}
- // TODO: 实现单字段验证逻辑
- if errors.HasErrors() {
- return errors.ToKratosError()
- }
- return nil
- }
|