package validator

import (
	"fmt"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"

	validate "git.ikuban.com/server/kubanapis/kuban/api/validate"
)

// Engine is the validation engine. It walks a protobuf message through
// reflection and applies the rules attached to it via the validate
// message/field/oneof option extensions.
type Engine struct {
	celExecutor *CELExecutor
}

// NewEngine creates a validation engine that evaluates CEL rules with the
// given executor.
func NewEngine(celExecutor *CELExecutor) *Engine {
	return &Engine{
		celExecutor: celExecutor,
	}
}

// Validate validates msg.
//
// It runs three stages in order: message-level rules, field rules, and oneof
// rules. A nil msg is accepted and yields nil. If a stage returns a non-nil
// error, that error is returned as-is. Otherwise, if any validation error was
// collected, the collection is returned via ToKratosError.
func (e *Engine) Validate(vctx *ValidationContext, msg proto.Message) error {
	if msg == nil {
		return nil
	}

	errors := &ValidationErrors{}

	// Obtain the reflection view of the message and its descriptor.
	msgReflect := msg.ProtoReflect()
	msgDescriptor := msgReflect.Descriptor()

	// 1. Message-level rules.
	if err := e.validateMessageRules(vctx, msgReflect, msgDescriptor, errors); err != nil {
		return err
	}

	// 2. Field rules.
	if err := e.validateFields(vctx, msgReflect, msgDescriptor, errors); err != nil {
		return err
	}

	// 3. Oneof rules.
	if err := e.validateOneofs(vctx, msgReflect, msgDescriptor, errors); err != nil {
		return err
	}

	// Report the collected errors, if any.
	if errors.HasErrors() {
		return errors.ToKratosError()
	}

	return nil
}

// validateMessageRules applies the message-level rules found in the message
// options.
//
// Rules are skipped when the message has no options, the extension is absent,
// the rules are disabled, or none of the rule groups is active in vctx.
// Failures are appended to errors. In fail-fast mode the function returns as
// soon as the collector holds an error.
func (e *Engine) validateMessageRules(
	vctx *ValidationContext,
	msgReflect protoreflect.Message,
	msgDescriptor protoreflect.MessageDescriptor,
	errors *ValidationErrors,
) error {
	// Fetch the message options.
	opts := msgDescriptor.Options()
	if opts == nil {
		return nil
	}

	// Is the validation extension present at all?
	if !proto.HasExtension(opts, validate.E_Message) {
		return nil
	}

	// Fetch the message validation rules.
	msgRules := proto.GetExtension(opts, validate.E_Message).(*validate.MessageRules)
	if msgRules == nil {
		return nil
	}

	// Rules may be switched off explicitly.
	if msgRules.Disabled {
		return nil
	}

	// Rules only apply when one of their groups is active.
	if !vctx.HasAnyGroup(msgRules.Groups) {
		return nil
	}

	// Run the CEL rules.
	for _, celRule := range msgRules.Cel {
		if err := e.executeCELRule(vctx, celRule, msgReflect.Interface(), "", errors); err != nil {
			if vctx.ShouldFailFast() {
				return err
			}
		}
		// executeCELRule records failures in the collector and returns nil,
		// so fail-fast has to be detected from the collector itself.
		if vctx.ShouldFailFast() && errors.HasErrors() {
			return errors
		}
	}

	return nil
}

// validateFields validates every field declared on the message descriptor.
//
// Iteration stops once vctx reports that the maximum number of errors has been
// reached. An error returned by validateField is propagated only in fail-fast
// mode; otherwise it is dropped and iteration continues.
func (e *Engine) validateFields(
	vctx *ValidationContext,
	msgReflect protoreflect.Message,
	msgDescriptor protoreflect.MessageDescriptor,
	errors *ValidationErrors,
) error {
	fields := msgDescriptor.Fields()

	for i := 0; i < fields.Len(); i++ {
		field := fields.Get(i)
		fieldPath := string(field.Name())

		// Stop once the error budget is exhausted.
		if vctx.MaxErrorsReached(errors.Count()) {
			break
		}

		// Validate the field.
		if err := e.validateField(vctx, msgReflect, field, fieldPath, errors); err != nil {
			if vctx.ShouldFailFast() {
				return err
			}
		}
	}

	return nil
}

// validateField validates a single field of msgReflect.
//
// Order of evaluation:
//  1. no options / no extension / nil rules / inactive groups: skip;
//  2. IGNORE_ALWAYS: skip everything, including required;
//  3. required: reported when the field is not set;
//  4. remaining ignore policy (see shouldIgnoreField);
//  5. CEL rules, then type-specific rules.
//
// The required check runs before the ignore policy on purpose: the ignore
// policy treats an unset field as "ignore" for every value except
// IGNORE_NEVER, so checking it first would make required unreachable.
func (e *Engine) validateField(
	vctx *ValidationContext,
	msgReflect protoreflect.Message,
	field protoreflect.FieldDescriptor,
	fieldPath string,
	errors *ValidationErrors,
) error {
	// Fetch the field options.
	opts := field.Options()
	if opts == nil {
		return nil
	}

	// Is the validation extension present at all?
	if !proto.HasExtension(opts, validate.E_Field) {
		return nil
	}

	// Fetch the field validation rules.
	fieldRules := proto.GetExtension(opts, validate.E_Field).(*validate.FieldRules)
	if fieldRules == nil {
		return nil
	}

	// Rules only apply when one of their groups is active.
	if !vctx.HasAnyGroup(fieldRules.Groups) {
		return nil
	}

	// Read the field value and its presence.
	fieldValue := msgReflect.Get(field)
	hasValue := msgReflect.Has(field)

	// IGNORE_ALWAYS disables every rule on the field.
	if fieldRules.Ignore == validate.IgnoreRule_IGNORE_ALWAYS {
		return nil
	}

	// Required check.
	if fieldRules.Required && !hasValue {
		defaultMsg := fmt.Sprintf("字段 %s 是必填的", fieldPath)
		message := e.getErrorMessage(fieldRules, vctx, "required", defaultMsg)
		errors.Add(NewValidationError(
			fieldPath,
			ErrCodeRequired,
			message,
		))
		if vctx.ShouldFailFast() {
			return errors
		}
	}

	// Apply the remaining ignore policy.
	if e.shouldIgnoreField(fieldRules, hasValue, fieldValue, field) {
		return nil
	}

	// An unset, non-required field has nothing further to validate.
	if !hasValue && !fieldRules.Required {
		return nil
	}

	// Run the CEL rules.
	for _, celRule := range fieldRules.Cel {
		if err := e.executeCELFieldRule(
			vctx,
			celRule,
			fieldValue.Interface(),
			msgReflect.Interface(),
			fieldPath,
			errors,
		); err != nil {
			if vctx.ShouldFailFast() {
				return err
			}
		}
		// executeCELFieldRule records failures in the collector and returns
		// nil, so fail-fast has to be detected from the collector itself.
		if vctx.ShouldFailFast() && errors.HasErrors() {
			return errors
		}
	}

	// Run the type-specific rules.
	return e.validateFieldType(vctx, fieldRules, fieldValue, field, fieldPath, errors)
}

// validateFieldType dispatches to the type-specific validator selected by the
// concrete type of fieldRules.Type. The field value is converted with the
// protoreflect accessor matching the rule type. When no known rule type is
// set, nothing is validated and nil is returned.
func (e *Engine) validateFieldType(
	vctx *ValidationContext,
	fieldRules *validate.FieldRules,
	fieldValue protoreflect.Value,
	field protoreflect.FieldDescriptor,
	fieldPath string,
	errors *ValidationErrors,
) error {
	// Dispatch on the rule type.
	switch rule := fieldRules.Type.(type) {
	case *validate.FieldRules_String_:
		return e.validateString(vctx, fieldRules, rule.String_, fieldValue.String(), fieldPath, errors)

	case *validate.FieldRules_Int32:
		return e.validateInt32(vctx, fieldRules, rule.Int32, int32(fieldValue.Int()), fieldPath, errors)

	case *validate.FieldRules_Int64:
		return e.validateInt64(vctx, fieldRules, rule.Int64, fieldValue.Int(), fieldPath, errors)

	case *validate.FieldRules_Uint32:
		return e.validateUInt32(vctx, fieldRules, rule.Uint32, uint32(fieldValue.Uint()), fieldPath, errors)

	case *validate.FieldRules_Uint64:
		return e.validateUInt64(vctx, fieldRules, rule.Uint64, fieldValue.Uint(), fieldPath, errors)

	// The sint/fixed/sfixed rule messages are converted to the plain
	// int/uint rule types so the same validators can be reused.
	case *validate.FieldRules_Sint32:
		return e.validateInt32(vctx, fieldRules, (*validate.Int32Rules)(rule.Sint32), int32(fieldValue.Int()), fieldPath, errors)

	case *validate.FieldRules_Sint64:
		return e.validateInt64(vctx, fieldRules, (*validate.Int64Rules)(rule.Sint64), fieldValue.Int(), fieldPath, errors)

	case *validate.FieldRules_Fixed32:
		return e.validateUInt32(vctx, fieldRules, (*validate.UInt32Rules)(rule.Fixed32), uint32(fieldValue.Uint()), fieldPath, errors)

	case *validate.FieldRules_Fixed64:
		return e.validateUInt64(vctx, fieldRules, (*validate.UInt64Rules)(rule.Fixed64), fieldValue.Uint(), fieldPath, errors)

	case *validate.FieldRules_Sfixed32:
		return e.validateInt32(vctx, fieldRules, (*validate.Int32Rules)(rule.Sfixed32), int32(fieldValue.Int()), fieldPath, errors)

	case *validate.FieldRules_Sfixed64:
		return e.validateInt64(vctx, fieldRules, (*validate.Int64Rules)(rule.Sfixed64), fieldValue.Int(), fieldPath, errors)

	case *validate.FieldRules_Float:
		return e.validateFloat(vctx, fieldRules, rule.Float, float32(fieldValue.Float()), fieldPath, errors)

	case *validate.FieldRules_Double:
		return e.validateDouble(vctx, fieldRules, rule.Double, fieldValue.Float(), fieldPath, errors)

	case *validate.FieldRules_Bool:
		return e.validateBool(vctx, fieldRules, rule.Bool, fieldValue.Bool(), fieldPath, errors)

	case *validate.FieldRules_Bytes:
		return e.validateBytes(vctx, fieldRules, rule.Bytes, fieldValue.Bytes(), fieldPath, errors)

	case *validate.FieldRules_Enum:
		return e.validateEnum(vctx, fieldRules, rule.Enum, fieldValue.Enum(), field, fieldPath, errors)

	case *validate.FieldRules_Repeated:
		return e.validateRepeated(vctx, fieldRules, rule.Repeated, fieldValue.List(), field, fieldPath, errors)

	case *validate.FieldRules_Map:
		return e.validateMap(vctx, fieldRules, rule.Map, fieldValue.Map(), field, fieldPath, errors)
	}

	return nil
}

// shouldIgnoreField reports whether validation of the field should be skipped
// according to rules.Ignore:
//   - IGNORE_ALWAYS: always skip;
//   - IGNORE_IF_ZERO: skip when the field is unset or holds a zero value;
//   - IGNORE_NEVER: never skip;
//   - anything else (unspecified): skip when the field is unset.
func (e *Engine) shouldIgnoreField(
	rules *validate.FieldRules,
	hasValue bool,
	value protoreflect.Value,
	field protoreflect.FieldDescriptor,
) bool {
	switch rules.Ignore {
	case validate.IgnoreRule_IGNORE_ALWAYS:
		return true
	case validate.IgnoreRule_IGNORE_IF_ZERO:
		return !hasValue || isZeroValue(value, field)
	case validate.IgnoreRule_IGNORE_NEVER:
		return false
	default:
		// IGNORE_UNSPECIFIED: do not validate when the field is unset.
		return !hasValue
	}
}

// isZeroValue reports whether value is the zero value for field.
//
// List and map fields are zero when empty. They are checked first because
// field.Kind() describes the element (or map entry) kind for such fields,
// and the scalar accessors on protoreflect.Value panic when the value holds
// a List or Map. Kinds not listed below (for example enums) are never
// reported as zero.
func isZeroValue(value protoreflect.Value, field protoreflect.FieldDescriptor) bool {
	switch {
	case field.IsList():
		return value.List().Len() == 0
	case field.IsMap():
		return value.Map().Len() == 0
	}

	switch field.Kind() {
	case protoreflect.BoolKind:
		return !value.Bool()
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
		protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		return value.Int() == 0
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
		protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		return value.Uint() == 0
	case protoreflect.FloatKind, protoreflect.DoubleKind:
		return value.Float() == 0.0
	case protoreflect.StringKind:
		return value.String() == ""
	case protoreflect.BytesKind:
		return len(value.Bytes()) == 0
	case protoreflect.MessageKind:
		return !value.Message().IsValid()
	default:
		return false
	}
}

// validateOneofs validates every oneof declared on the message descriptor.
//
// Only the required rule is evaluated: a required oneof must have one of its
// member fields set. Violations are appended to errors; in fail-fast mode the
// collector is returned right after the first violation.
func (e *Engine) validateOneofs(
	vctx *ValidationContext,
	msgReflect protoreflect.Message,
	msgDescriptor protoreflect.MessageDescriptor,
	errors *ValidationErrors,
) error {
	oneofs := msgDescriptor.Oneofs()

	for i := 0; i < oneofs.Len(); i++ {
		oneof := oneofs.Get(i)

		// Fetch the oneof options.
		opts := oneof.Options()
		if opts == nil {
			continue
		}

		// Is the validation extension present at all?
		if !proto.HasExtension(opts, validate.E_Oneof) {
			continue
		}

		// Fetch the oneof validation rules.
		oneofRules := proto.GetExtension(opts, validate.E_Oneof).(*validate.OneofRules)
		if oneofRules == nil {
			continue
		}

		// Required check: look for any populated member field.
		if oneofRules.Required {
			hasField := false
			fields := oneof.Fields()
			for j := 0; j < fields.Len(); j++ {
				if msgReflect.Has(fields.Get(j)) {
					hasField = true
					break
				}
			}

			if !hasField {
				errors.Add(NewValidationError(
					string(oneof.Name()),
					ErrCodeRequired,
					fmt.Sprintf("oneof %s 必须选择一个字段", oneof.Name()),
				))
				if vctx.ShouldFailFast() {
					return errors
				}
			}
		}
	}

	return nil
}

// executeCELRule evaluates a message-level CEL rule against msg.
//
// Both evaluation failures and rule violations are appended to errors; the
// function itself always returns nil. For a violation, the message is
// rule.Message, falling back to rule.MessageEn (only when the language is
// "en-US") and finally to a generic message built from the rule id.
func (e *Engine) executeCELRule(
	vctx *ValidationContext,
	rule *validate.Rule,
	msg proto.Message,
	fieldPath string,
	errors *ValidationErrors,
) error {
	// Evaluate the expression.
	result, err := e.celExecutor.EvaluateWithProto(rule.Expression, msg)
	if err != nil {
		// The expression could not be evaluated.
		errors.Add(&ValidationError{
			FieldPath: fieldPath,
			RuleID:    rule.Id,
			Message:   fmt.Sprintf("CEL 表达式执行失败: %v", err),
		})
		return nil
	}

	// The rule evaluated to false.
	if !result {
		message := rule.Message
		if message == "" && vctx.Language == "en-US" && rule.MessageEn != "" {
			message = rule.MessageEn
		}
		if message == "" {
			message = fmt.Sprintf("验证规则 %s 失败", rule.Id)
		}

		errors.Add(&ValidationError{
			FieldPath: fieldPath,
			RuleID:    rule.Id,
			Message:   message,
			MessageEN: rule.MessageEn,
		})
	}

	return nil
}

// executeCELFieldRule evaluates a field-level CEL rule against fieldValue,
// with parent passed to the executor alongside it.
//
// Both evaluation failures and rule violations are appended to errors; the
// function itself always returns nil. Message selection matches
// executeCELRule, and a violation additionally records the field value.
func (e *Engine) executeCELFieldRule(
	vctx *ValidationContext,
	rule *validate.Rule,
	fieldValue interface{},
	parent proto.Message,
	fieldPath string,
	errors *ValidationErrors,
) error {
	// Evaluate the expression.
	result, err := e.celExecutor.EvaluateFieldRule(rule.Expression, fieldValue, parent)
	if err != nil {
		// The expression could not be evaluated.
		errors.Add(&ValidationError{
			FieldPath: fieldPath,
			RuleID:    rule.Id,
			Message:   fmt.Sprintf("CEL 表达式执行失败: %v", err),
		})
		return nil
	}

	// The rule evaluated to false.
	if !result {
		message := rule.Message
		if message == "" && vctx.Language == "en-US" && rule.MessageEn != "" {
			message = rule.MessageEn
		}
		if message == "" {
			message = fmt.Sprintf("验证规则 %s 失败", rule.Id)
		}

		errors.Add(&ValidationError{
			FieldPath: fieldPath,
			RuleID:    rule.Id,
			Message:   message,
			MessageEN: rule.MessageEn,
			Value:     fieldValue,
		})
	}

	return nil
}

// getErrorMessage picks the error message for a failed rule.
//
// Priority: rule-specific message (required_message, only for ruleType
// "required") > generic message (error_message) > defaultMessage. At each
// level the English variant is preferred when the language is "en-US" and the
// variant is non-empty.
func (e *Engine) getErrorMessage(
	fieldRules *validate.FieldRules,
	vctx *ValidationContext,
	ruleType string,
	defaultMessage string,
) string {
	// 1. Rule-specific custom message (e.g. required_message).
	if ruleType == "required" {
		if vctx.Language == "en-US" && fieldRules.RequiredMessageEn != "" {
			return fieldRules.RequiredMessageEn
		}
		if fieldRules.RequiredMessage != "" {
			return fieldRules.RequiredMessage
		}
	}

	// 2. Generic custom message (error_message).
	if vctx.Language == "en-US" && fieldRules.ErrorMessageEn != "" {
		return fieldRules.ErrorMessageEn
	}
	if fieldRules.ErrorMessage != "" {
		return fieldRules.ErrorMessage
	}

	// 3. Default message.
	return defaultMessage
}

// ValidateField validates a single field value (external entry point).
//
// It is currently a stub: no rule is evaluated and the arguments are unused,
// so it always returns nil.
func (e *Engine) ValidateField(vctx *ValidationContext, fieldPath string, value interface{}, rules interface{}) error {
	errors := &ValidationErrors{}

	// TODO: implement single-field validation.

	if errors.HasErrors() {
		return errors.ToKratosError()
	}

	return nil
}