proto.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package binding
  2. import (
  3. "encoding/base64"
  4. "errors"
  5. "fmt"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "google.golang.org/genproto/protobuf/field_mask"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/types/known/durationpb"
  14. "google.golang.org/protobuf/types/known/timestamppb"
  15. "google.golang.org/protobuf/types/known/wrapperspb"
  16. )
  17. // MapProto sets a value in a nested Protobuf structure.
  18. func MapProto(msg proto.Message, values map[string]string) error {
  19. for key, value := range values {
  20. if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), []string{value}); err != nil {
  21. return err
  22. }
  23. }
  24. return nil
  25. }
  26. func mapProto(msg proto.Message, values map[string][]string) error {
  27. for key, values := range values {
  28. if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), values); err != nil {
  29. return err
  30. }
  31. }
  32. return nil
  33. }
  34. func populateFieldValues(v protoreflect.Message, fieldPath []string, values []string) error {
  35. if len(fieldPath) < 1 {
  36. return errors.New("no field path")
  37. }
  38. if len(values) < 1 {
  39. return errors.New("no value provided")
  40. }
  41. var fd protoreflect.FieldDescriptor
  42. for i, fieldName := range fieldPath {
  43. fields := v.Descriptor().Fields()
  44. if fd = fields.ByName(protoreflect.Name(fieldName)); fd == nil {
  45. fd = fields.ByJSONName(fieldName)
  46. if fd == nil {
  47. return nil
  48. }
  49. }
  50. if i == len(fieldPath)-1 {
  51. break
  52. }
  53. if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {
  54. return fmt.Errorf("invalid path: %q is not a message", fieldName)
  55. }
  56. v = v.Mutable(fd).Message()
  57. }
  58. if of := fd.ContainingOneof(); of != nil {
  59. if f := v.WhichOneof(of); f != nil {
  60. return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
  61. }
  62. }
  63. switch {
  64. case fd.IsList():
  65. return populateRepeatedField(fd, v.Mutable(fd).List(), values)
  66. case fd.IsMap():
  67. return populateMapField(fd, v.Mutable(fd).Map(), values)
  68. }
  69. if len(values) > 1 {
  70. return fmt.Errorf("too many values for field %q: %s", fd.FullName().Name(), strings.Join(values, ", "))
  71. }
  72. return populateField(fd, v, values[0])
  73. }
  74. func populateField(fd protoreflect.FieldDescriptor, v protoreflect.Message, value string) error {
  75. val, err := parseField(fd, value)
  76. if err != nil {
  77. return fmt.Errorf("parsing field %q: %w", fd.FullName().Name(), err)
  78. }
  79. v.Set(fd, val)
  80. return nil
  81. }
  82. func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
  83. for _, value := range values {
  84. v, err := parseField(fd, value)
  85. if err != nil {
  86. return fmt.Errorf("parsing list %q: %w", fd.FullName().Name(), err)
  87. }
  88. list.Append(v)
  89. }
  90. return nil
  91. }
  92. func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
  93. if len(values) != 2 {
  94. return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fd.FullName())
  95. }
  96. key, err := parseField(fd.MapKey(), values[0])
  97. if err != nil {
  98. return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err)
  99. }
  100. value, err := parseField(fd.MapValue(), values[1])
  101. if err != nil {
  102. return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err)
  103. }
  104. mp.Set(key.MapKey(), value)
  105. return nil
  106. }
  107. func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
  108. switch fd.Kind() {
  109. case protoreflect.BoolKind:
  110. v, err := strconv.ParseBool(value)
  111. if err != nil {
  112. return protoreflect.Value{}, err
  113. }
  114. return protoreflect.ValueOfBool(v), nil
  115. case protoreflect.EnumKind:
  116. enum, err := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName())
  117. switch {
  118. case errors.Is(err, protoregistry.NotFound):
  119. return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fd.Enum().FullName())
  120. case err != nil:
  121. return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
  122. }
  123. v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
  124. if v == nil {
  125. i, err := strconv.ParseInt(value, 10, 32)
  126. if err != nil {
  127. return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
  128. }
  129. v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
  130. if v == nil {
  131. return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
  132. }
  133. }
  134. return protoreflect.ValueOfEnum(v.Number()), nil
  135. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  136. v, err := strconv.ParseInt(value, 10, 32)
  137. if err != nil {
  138. return protoreflect.Value{}, err
  139. }
  140. return protoreflect.ValueOfInt32(int32(v)), nil
  141. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  142. v, err := strconv.ParseInt(value, 10, 64)
  143. if err != nil {
  144. return protoreflect.Value{}, err
  145. }
  146. return protoreflect.ValueOfInt64(v), nil
  147. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  148. v, err := strconv.ParseUint(value, 10, 32)
  149. if err != nil {
  150. return protoreflect.Value{}, err
  151. }
  152. return protoreflect.ValueOfUint32(uint32(v)), nil
  153. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  154. v, err := strconv.ParseUint(value, 10, 64)
  155. if err != nil {
  156. return protoreflect.Value{}, err
  157. }
  158. return protoreflect.ValueOfUint64(v), nil
  159. case protoreflect.FloatKind:
  160. v, err := strconv.ParseFloat(value, 32)
  161. if err != nil {
  162. return protoreflect.Value{}, err
  163. }
  164. return protoreflect.ValueOfFloat32(float32(v)), nil
  165. case protoreflect.DoubleKind:
  166. v, err := strconv.ParseFloat(value, 64)
  167. if err != nil {
  168. return protoreflect.Value{}, err
  169. }
  170. return protoreflect.ValueOfFloat64(v), nil
  171. case protoreflect.StringKind:
  172. return protoreflect.ValueOfString(value), nil
  173. case protoreflect.BytesKind:
  174. v, err := base64.StdEncoding.DecodeString(value)
  175. if err != nil {
  176. return protoreflect.Value{}, err
  177. }
  178. return protoreflect.ValueOfBytes(v), nil
  179. case protoreflect.MessageKind, protoreflect.GroupKind:
  180. return parseMessage(fd.Message(), value)
  181. default:
  182. panic(fmt.Sprintf("unknown field kind: %v", fd.Kind()))
  183. }
  184. }
  185. func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
  186. var msg proto.Message
  187. switch md.FullName() {
  188. case "google.protobuf.Timestamp":
  189. if value == "null" {
  190. break
  191. }
  192. t, err := time.Parse(time.RFC3339Nano, value)
  193. if err != nil {
  194. return protoreflect.Value{}, err
  195. }
  196. msg = timestamppb.New(t)
  197. case "google.protobuf.Duration":
  198. if value == "null" {
  199. break
  200. }
  201. d, err := time.ParseDuration(value)
  202. if err != nil {
  203. return protoreflect.Value{}, err
  204. }
  205. msg = durationpb.New(d)
  206. case "google.protobuf.DoubleValue":
  207. v, err := strconv.ParseFloat(value, 64)
  208. if err != nil {
  209. return protoreflect.Value{}, err
  210. }
  211. msg = wrapperspb.Double(v)
  212. case "google.protobuf.FloatValue":
  213. v, err := strconv.ParseFloat(value, 32)
  214. if err != nil {
  215. return protoreflect.Value{}, err
  216. }
  217. msg = wrapperspb.Float(float32(v))
  218. case "google.protobuf.Int64Value":
  219. v, err := strconv.ParseInt(value, 10, 64)
  220. if err != nil {
  221. return protoreflect.Value{}, err
  222. }
  223. msg = wrapperspb.Int64(v)
  224. case "google.protobuf.Int32Value":
  225. v, err := strconv.ParseInt(value, 10, 32)
  226. if err != nil {
  227. return protoreflect.Value{}, err
  228. }
  229. msg = wrapperspb.Int32(int32(v))
  230. case "google.protobuf.UInt64Value":
  231. v, err := strconv.ParseUint(value, 10, 64)
  232. if err != nil {
  233. return protoreflect.Value{}, err
  234. }
  235. msg = wrapperspb.UInt64(v)
  236. case "google.protobuf.UInt32Value":
  237. v, err := strconv.ParseUint(value, 10, 32)
  238. if err != nil {
  239. return protoreflect.Value{}, err
  240. }
  241. msg = wrapperspb.UInt32(uint32(v))
  242. case "google.protobuf.BoolValue":
  243. v, err := strconv.ParseBool(value)
  244. if err != nil {
  245. return protoreflect.Value{}, err
  246. }
  247. msg = wrapperspb.Bool(v)
  248. case "google.protobuf.StringValue":
  249. msg = wrapperspb.String(value)
  250. case "google.protobuf.BytesValue":
  251. v, err := base64.StdEncoding.DecodeString(value)
  252. if err != nil {
  253. return protoreflect.Value{}, err
  254. }
  255. msg = wrapperspb.Bytes(v)
  256. case "google.protobuf.FieldMask":
  257. fm := &field_mask.FieldMask{}
  258. fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
  259. msg = fm
  260. default:
  261. return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(md.FullName()))
  262. }
  263. return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
  264. }