package validator

import (
	"fmt"
	"sort"
	"strings"
	"sync"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
	"google.golang.org/protobuf/proto"
)

// Names of the variables the executor binds itself when evaluating
// message-level and field-level rules.
const (
	celVarThis   = "this"
	celVarParent = "parent"
)

// CELExecutor compiles, caches and evaluates CEL expressions.
type CELExecutor struct {
	mu    sync.RWMutex
	cache map[programKey]cel.Program // compiled programs (cel.Program is an interface, so no pointer is needed)
	env   *cel.Env                   // base CEL environment; never reassigned after construction
}

// programKey identifies a compiled program. The same expression compiled with
// different declarations produces different programs, so the declarations are
// part of the key.
type programKey struct {
	expression string
	variables  string // sorted extra variable names, NUL-separated
	protoTypes string // sorted proto message full names, NUL-separated
}

// NewCELExecutor creates a CELExecutor whose base environment declares the
// "this" and "parent" variables as dynamically typed.
func NewCELExecutor() (*CELExecutor, error) {
	// env.Compile type-checks expressions, so every identifier an expression
	// references must be declared; otherwise compilation fails with
	// "undeclared reference". DynType accepts any runtime value (scalars,
	// lists, maps, proto messages).
	env, err := cel.NewEnv(
		cel.Variable(celVarThis, cel.DynType),
		cel.Variable(celVarParent, cel.DynType),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create CEL environment: %w", err)
	}

	return &CELExecutor{
		cache: make(map[programKey]cel.Program),
		env:   env,
	}, nil
}

// Evaluate runs a boolean CEL expression against data.
//
// Every key of data is available to the expression as a variable, and the
// types of any proto.Message values are registered so their fields can be
// selected. It returns an error if the expression fails to compile or to
// evaluate, or if its result is not a bool.
func (e *CELExecutor) Evaluate(expression string, data map[string]interface{}) (bool, error) {
	prog, err := e.programFor(expression, data)
	if err != nil {
		return false, fmt.Errorf("failed to compile CEL expression: %w", err)
	}
	return evalCELBool(prog, data)
}

// EvaluateWithProto runs a boolean CEL expression with msg bound to "this".
//
// It returns an error if the expression fails to compile or to evaluate, or if
// its result is not a bool.
func (e *CELExecutor) EvaluateWithProto(expression string, msg proto.Message) (bool, error) {
	data := map[string]interface{}{
		celVarThis: msg,
	}

	prog, err := e.programFor(expression, data)
	if err != nil {
		return false, fmt.Errorf("failed to compile CEL expression: %w", err)
	}
	return evalCELBool(prog, data)
}

// EvaluateFieldRule runs a boolean field-level CEL rule with fieldValue bound to
// "this" and, when non-nil, parent bound to "parent" for cross-field checks.
//
// It returns an error if the expression fails to compile or to evaluate, or if
// its result is not a bool.
func (e *CELExecutor) EvaluateFieldRule(expression string, fieldValue interface{}, parent proto.Message) (bool, error) {
	data := map[string]interface{}{
		celVarThis: fieldValue,
	}
	// A nil parent is left unbound, so an expression that reads "parent"
	// reports an evaluation error instead of seeing a nil value.
	if parent != nil {
		data[celVarParent] = parent
	}

	prog, err := e.programFor(expression, data)
	if err != nil {
		return false, fmt.Errorf("failed to compile CEL expression: %w", err)
	}
	return evalCELBool(prog, data)
}

// EvaluateString runs a CEL expression that must produce a string (used for
// dynamic error messages) against data.
//
// It returns an error if the expression fails to compile or to evaluate, or if
// its result is not a string.
func (e *CELExecutor) EvaluateString(expression string, data map[string]interface{}) (string, error) {
	prog, err := e.programFor(expression, data)
	if err != nil {
		return "", fmt.Errorf("failed to compile CEL expression: %w", err)
	}

	result, _, err := prog.Eval(data)
	if err != nil {
		return "", fmt.Errorf("failed to evaluate CEL expression: %w", err)
	}

	strResult, ok := result.Value().(string)
	if !ok {
		return "", fmt.Errorf("CEL expression must return string, got %T", result.Value())
	}
	return strResult, nil
}

// evalCELBool evaluates prog against data and requires a bool result.
func evalCELBool(prog cel.Program, data map[string]interface{}) (bool, error) {
	result, _, err := prog.Eval(data)
	if err != nil {
		return false, fmt.Errorf("failed to evaluate CEL expression: %w", err)
	}

	boolResult, ok := result.Value().(bool)
	if !ok {
		return false, fmt.Errorf("CEL expression must return boolean, got %T", result.Value())
	}
	return boolResult, nil
}

// programFor returns a compiled program for expression that can be evaluated
// against data.
//
// Keys of data other than "this" and "parent" (already declared by the base
// environment) are declared as DynType variables. The message type of each
// proto.Message value is registered with cel.Types so CEL can convert the
// message when evaluating; a program compiled without that registration cannot
// handle messages of that type, which is why the declarations are part of the
// cache key.
func (e *CELExecutor) programFor(expression string, data map[string]interface{}) (cel.Program, error) {
	var (
		varNames  []string
		typeNames []string
		opts      []cel.EnvOption
	)
	seenTypes := make(map[string]bool)

	for name, value := range data {
		if name != celVarThis && name != celVarParent {
			varNames = append(varNames, name)
			opts = append(opts, cel.Variable(name, cel.DynType))
		}

		msg, ok := value.(proto.Message)
		if !ok {
			continue
		}
		typeName := string(msg.ProtoReflect().Descriptor().FullName())
		if !seenTypes[typeName] {
			seenTypes[typeName] = true
			typeNames = append(typeNames, typeName)
			opts = append(opts, cel.Types(msg))
		}
	}

	if len(opts) == 0 {
		return e.getProgram(expression)
	}

	// Map iteration order is random; sort so equal declaration sets share a key.
	sort.Strings(varNames)
	sort.Strings(typeNames)
	key := programKey{
		expression: expression,
		variables:  strings.Join(varNames, "\x00"),
		protoTypes: strings.Join(typeNames, "\x00"),
	}
	return e.cachedProgram(key, opts)
}

// getProgram returns the program for expression compiled against the base
// environment, compiling and caching it on first use.
func (e *CELExecutor) getProgram(expression string) (cel.Program, error) {
	return e.cachedProgram(programKey{expression: expression}, nil)
}

// cachedProgram returns the program cached under key. On a miss it compiles
// key.expression against the base environment extended with opts (or the base
// environment itself when opts is empty) and caches the result.
func (e *CELExecutor) cachedProgram(key programKey, opts []cel.EnvOption) (cel.Program, error) {
	e.mu.RLock()
	prog, ok := e.cache[key]
	e.mu.RUnlock()
	if ok {
		return prog, nil
	}

	env := e.env
	if len(opts) > 0 {
		extended, err := e.env.Extend(opts...)
		if err != nil {
			return nil, fmt.Errorf("failed to extend CEL environment: %w", err)
		}
		env = extended
	}

	ast, issues := env.Compile(key.expression)
	if issues != nil && issues.Err() != nil {
		return nil, fmt.Errorf("CEL compilation error: %w", issues.Err())
	}

	prog, err := env.Program(ast)
	if err != nil {
		return nil, fmt.Errorf("failed to create CEL program: %w", err)
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	// Compilation runs outside the lock, so another goroutine may have stored
	// a program for the same key meanwhile; keep that one so callers share it.
	if existing, found := e.cache[key]; found {
		return existing, nil
	}
	e.cache[key] = prog
	return prog, nil
}

// ClearCache discards every cached program; later evaluations recompile on
// demand.
func (e *CELExecutor) ClearCache() {
	e.mu.Lock()
	e.cache = make(map[programKey]cel.Program)
	e.mu.Unlock()
}

// CacheSize returns the number of compiled programs currently cached.
func (e *CELExecutor) CacheSize() int {
	e.mu.RLock()
	defer e.mu.RUnlock()
	return len(e.cache)
}
ValidateCELExpression 验证 CEL 表达式语法(不执行) func (e *CELExecutor) ValidateCELExpression(expression string) error { _, issues := e.env.Compile(expression) if issues != nil && issues.Err() != nil { return fmt.Errorf("invalid CEL expression: %w", issues.Err()) } return nil } // Helper 函数 // celBool 将 ref.Val 转换为 bool func celBool(val ref.Val) (bool, bool) { if val.Type() == types.BoolType { return val.Value().(bool), true } return false, false } // celString 将 ref.Val 转换为 string func celString(val ref.Val) (string, bool) { if val.Type() == types.StringType { return val.Value().(string), true } return "", false } // celInt 将 ref.Val 转换为 int64 func celInt(val ref.Val) (int64, bool) { if val.Type() == types.IntType { return val.Value().(int64), true } return 0, false } // celDouble 将 ref.Val 转换为 float64 func celDouble(val ref.Val) (float64, bool) { if val.Type() == types.DoubleType { return val.Value().(float64), true } return 0, false }