cel_executor.go 5.5 KB


  1. package validator
  2. import (
  3. "fmt"
  4. "sync"
  5. "github.com/google/cel-go/cel"
  6. "github.com/google/cel-go/common/types"
  7. "github.com/google/cel-go/common/types/ref"
  8. "google.golang.org/protobuf/proto"
  9. )
  10. // CELExecutor CEL 表达式执行器
  11. type CELExecutor struct {
  12. mu sync.RWMutex
  13. cache map[string]cel.Program // 表达式缓存 (cel.Program 是接口,不需要指针)
  14. env *cel.Env // CEL 环境
  15. }
  16. // NewCELExecutor 创建 CEL 执行器
  17. func NewCELExecutor() (*CELExecutor, error) {
  18. // 创建 CEL 环境 (v0.26+ 不需要显式注册基础类型)
  19. env, err := cel.NewEnv()
  20. if err != nil {
  21. return nil, fmt.Errorf("failed to create CEL environment: %w", err)
  22. }
  23. return &CELExecutor{
  24. cache: make(map[string]cel.Program),
  25. env: env,
  26. }, nil
  27. }
  28. // Evaluate 执行 CEL 表达式
  29. func (e *CELExecutor) Evaluate(expression string, data map[string]interface{}) (bool, error) {
  30. // 获取或编译程序
  31. prog, err := e.getProgram(expression)
  32. if err != nil {
  33. return false, fmt.Errorf("failed to compile CEL expression: %w", err)
  34. }
  35. // 执行表达式
  36. result, _, err := prog.Eval(data)
  37. if err != nil {
  38. return false, fmt.Errorf("failed to evaluate CEL expression: %w", err)
  39. }
  40. // 检查结果类型
  41. boolResult, ok := result.Value().(bool)
  42. if !ok {
  43. return false, fmt.Errorf("CEL expression must return boolean, got %T", result.Value())
  44. }
  45. return boolResult, nil
  46. }
  47. // EvaluateWithProto 使用 Proto 消息执行 CEL 表达式
  48. func (e *CELExecutor) EvaluateWithProto(expression string, msg proto.Message) (bool, error) {
  49. // 获取或编译程序
  50. prog, err := e.getProgram(expression)
  51. if err != nil {
  52. return false, fmt.Errorf("failed to compile CEL expression: %w", err)
  53. }
  54. // 将 Proto 消息转换为 CEL 变量
  55. data := map[string]interface{}{
  56. "this": msg,
  57. }
  58. // 执行表达式
  59. result, _, err := prog.Eval(data)
  60. if err != nil {
  61. return false, fmt.Errorf("failed to evaluate CEL expression: %w", err)
  62. }
  63. // 检查结果类型
  64. boolResult, ok := result.Value().(bool)
  65. if !ok {
  66. return false, fmt.Errorf("CEL expression must return boolean, got %T", result.Value())
  67. }
  68. return boolResult, nil
  69. }
  70. // EvaluateFieldRule 执行字段级 CEL 规则
  71. func (e *CELExecutor) EvaluateFieldRule(expression string, fieldValue interface{}, parent proto.Message) (bool, error) {
  72. // 获取或编译程序
  73. prog, err := e.getProgram(expression)
  74. if err != nil {
  75. return false, fmt.Errorf("failed to compile CEL expression: %w", err)
  76. }
  77. // 准备数据
  78. data := map[string]interface{}{
  79. "this": fieldValue,
  80. }
  81. // 如果有父消息,也添加进去(用于跨字段验证)
  82. if parent != nil {
  83. data["parent"] = parent
  84. }
  85. // 执行表达式
  86. result, _, err := prog.Eval(data)
  87. if err != nil {
  88. return false, fmt.Errorf("failed to evaluate CEL expression: %w", err)
  89. }
  90. // 检查结果类型
  91. boolResult, ok := result.Value().(bool)
  92. if !ok {
  93. return false, fmt.Errorf("CEL expression must return boolean, got %T", result.Value())
  94. }
  95. return boolResult, nil
  96. }
  97. // EvaluateString 执行返回字符串的 CEL 表达式(用于动态错误消息)
  98. func (e *CELExecutor) EvaluateString(expression string, data map[string]interface{}) (string, error) {
  99. // 获取或编译程序
  100. prog, err := e.getProgram(expression)
  101. if err != nil {
  102. return "", fmt.Errorf("failed to compile CEL expression: %w", err)
  103. }
  104. // 执行表达式
  105. result, _, err := prog.Eval(data)
  106. if err != nil {
  107. return "", fmt.Errorf("failed to evaluate CEL expression: %w", err)
  108. }
  109. // 检查结果类型
  110. strResult, ok := result.Value().(string)
  111. if !ok {
  112. return "", fmt.Errorf("CEL expression must return string, got %T", result.Value())
  113. }
  114. return strResult, nil
  115. }
  116. // getProgram 获取或编译程序(带缓存)
  117. func (e *CELExecutor) getProgram(expression string) (cel.Program, error) {
  118. // 先尝试从缓存读取
  119. e.mu.RLock()
  120. if prog, ok := e.cache[expression]; ok {
  121. e.mu.RUnlock()
  122. return prog, nil
  123. }
  124. e.mu.RUnlock()
  125. // 编译表达式
  126. ast, issues := e.env.Compile(expression)
  127. if issues != nil && issues.Err() != nil {
  128. return nil, fmt.Errorf("CEL compilation error: %w", issues.Err())
  129. }
  130. // 创建程序
  131. prog, err := e.env.Program(ast)
  132. if err != nil {
  133. return nil, fmt.Errorf("failed to create CEL program: %w", err)
  134. }
  135. // 存入缓存
  136. e.mu.Lock()
  137. e.cache[expression] = prog
  138. e.mu.Unlock()
  139. return prog, nil
  140. }
  141. // ClearCache 清空表达式缓存
  142. func (e *CELExecutor) ClearCache() {
  143. e.mu.Lock()
  144. e.cache = make(map[string]cel.Program)
  145. e.mu.Unlock()
  146. }
  147. // CacheSize 获取缓存大小
  148. func (e *CELExecutor) CacheSize() int {
  149. e.mu.RLock()
  150. defer e.mu.RUnlock()
  151. return len(e.cache)
  152. }
  153. // ValidateCELExpression 验证 CEL 表达式语法(不执行)
  154. func (e *CELExecutor) ValidateCELExpression(expression string) error {
  155. _, issues := e.env.Compile(expression)
  156. if issues != nil && issues.Err() != nil {
  157. return fmt.Errorf("invalid CEL expression: %w", issues.Err())
  158. }
  159. return nil
  160. }
  // Helper functions
  162. // celBool 将 ref.Val 转换为 bool
  163. func celBool(val ref.Val) (bool, bool) {
  164. if val.Type() == types.BoolType {
  165. return val.Value().(bool), true
  166. }
  167. return false, false
  168. }
  169. // celString 将 ref.Val 转换为 string
  170. func celString(val ref.Val) (string, bool) {
  171. if val.Type() == types.StringType {
  172. return val.Value().(string), true
  173. }
  174. return "", false
  175. }
  176. // celInt 将 ref.Val 转换为 int64
  177. func celInt(val ref.Val) (int64, bool) {
  178. if val.Type() == types.IntType {
  179. return val.Value().(int64), true
  180. }
  181. return 0, false
  182. }
  183. // celDouble 将 ref.Val 转换为 float64
  184. func celDouble(val ref.Val) (float64, bool) {
  185. if val.Type() == types.DoubleType {
  186. return val.Value().(float64), true
  187. }
  188. return 0, false
  189. }