| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 | package bindingimport (	"encoding/base64"	"errors"	"fmt"	"strconv"	"strings"	"time"	"google.golang.org/genproto/protobuf/field_mask"	"google.golang.org/protobuf/proto"	"google.golang.org/protobuf/reflect/protoreflect"	"google.golang.org/protobuf/reflect/protoregistry"	"google.golang.org/protobuf/types/known/durationpb"	"google.golang.org/protobuf/types/known/timestamppb"	"google.golang.org/protobuf/types/known/wrapperspb")// MapProto sets a value in a nested Protobuf structure.func MapProto(msg proto.Message, values map[string]string) error {	for key, value := range values {		if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), []string{value}); err != nil {			return err		}	}	return nil}func mapProto(msg proto.Message, values map[string][]string) error {	for key, values := range values {		if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), values); err != nil {			return err		}	}	return nil}func populateFieldValues(v protoreflect.Message, fieldPath []string, values []string) error {	if len(fieldPath) < 1 {		return errors.New("no field path")	}	if len(values) < 1 {		return errors.New("no value provided")	}	var fd protoreflect.FieldDescriptor	for i, fieldName := range fieldPath {		fields := v.Descriptor().Fields()		if fd = fields.ByName(protoreflect.Name(fieldName)); fd == nil {			fd = fields.ByJSONName(fieldName)			if fd == nil {				return nil			}		}		if i == len(fieldPath)-1 {			break		}		if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {			return fmt.Errorf("invalid path: %q is not a message", fieldName)		}		v = v.Mutable(fd).Message()	}	if of := fd.ContainingOneof(); of != nil {		if f := v.WhichOneof(of); f != nil {			return fmt.Errorf("field already set for oneof %q", of.FullName().Name())		}	}	switch {	case fd.IsList():		return populateRepeatedField(fd, v.Mutable(fd).List(), values)	case fd.IsMap():		return populateMapField(fd, v.Mutable(fd).Map(), values)	}	if len(values) > 1 {		return fmt.Errorf("too many values for field %q: %s", fd.FullName().Name(), strings.Join(values, ", "))	}	return populateField(fd, v, values[0])}func populateField(fd protoreflect.FieldDescriptor, v protoreflect.Message, value string) error {	val, err := parseField(fd, value)	if err != nil {		return fmt.Errorf("parsing field %q: %w", fd.FullName().Name(), err)	}	v.Set(fd, val)	return nil}func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {	for _, value := range values {		v, err := parseField(fd, value)		if err != nil {			return fmt.Errorf("parsing list %q: %w", fd.FullName().Name(), err)		}		list.Append(v)	}	return nil}func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {	if len(values) != 2 {		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fd.FullName())	}	key, err := parseField(fd.MapKey(), values[0])	if err != nil {		return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err)	}	value, err := parseField(fd.MapValue(), values[1])	if err != nil {		return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err)	}	mp.Set(key.MapKey(), value)	return nil}func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {	switch fd.Kind() {	case protoreflect.BoolKind:		v, err := strconv.ParseBool(value)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfBool(v), nil	case protoreflect.EnumKind:		enum, err := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName())		switch {		case errors.Is(err, protoregistry.NotFound):			return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fd.Enum().FullName())		case err != nil:			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)		}		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))		if v == nil {			i, err := strconv.ParseInt(value, 10, 32)			if err != nil {				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)			}			v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))			if v == nil {				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)			}		}		return protoreflect.ValueOfEnum(v.Number()), nil	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:		v, err := strconv.ParseInt(value, 10, 32)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfInt32(int32(v)), nil	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:		v, err := strconv.ParseInt(value, 10, 64)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfInt64(v), nil	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:		v, err := strconv.ParseUint(value, 10, 32)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfUint32(uint32(v)), nil	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:		v, err := strconv.ParseUint(value, 10, 64)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfUint64(v), nil	case protoreflect.FloatKind:		v, err := strconv.ParseFloat(value, 32)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfFloat32(float32(v)), nil	case protoreflect.DoubleKind:		v, err := strconv.ParseFloat(value, 64)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfFloat64(v), nil	case protoreflect.StringKind:		return protoreflect.ValueOfString(value), nil	case protoreflect.BytesKind:		v, err := base64.StdEncoding.DecodeString(value)		if err != nil {			return protoreflect.Value{}, err		}		return protoreflect.ValueOfBytes(v), nil	case protoreflect.MessageKind, protoreflect.GroupKind:		return parseMessage(fd.Message(), value)	default:		panic(fmt.Sprintf("unknown field kind: %v", fd.Kind()))	}}func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {	var msg proto.Message	switch md.FullName() {	case "google.protobuf.Timestamp":		if value == "null" {			break		}		t, err := time.Parse(time.RFC3339Nano, value)		if err != nil {			return protoreflect.Value{}, err		}		msg = timestamppb.New(t)	case "google.protobuf.Duration":		if value == "null" {			break		}		d, err := time.ParseDuration(value)		if err != nil {			return protoreflect.Value{}, err		}		msg = durationpb.New(d)	case "google.protobuf.DoubleValue":		v, err := strconv.ParseFloat(value, 64)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Double(v)	case "google.protobuf.FloatValue":		v, err := strconv.ParseFloat(value, 32)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Float(float32(v))	case "google.protobuf.Int64Value":		v, err := strconv.ParseInt(value, 10, 64)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Int64(v)	case "google.protobuf.Int32Value":		v, err := strconv.ParseInt(value, 10, 32)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Int32(int32(v))	case "google.protobuf.UInt64Value":		v, err := strconv.ParseUint(value, 10, 64)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.UInt64(v)	case "google.protobuf.UInt32Value":		v, err := strconv.ParseUint(value, 10, 32)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.UInt32(uint32(v))	case "google.protobuf.BoolValue":		v, err := strconv.ParseBool(value)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Bool(v)	case "google.protobuf.StringValue":		msg = wrapperspb.String(value)	case "google.protobuf.BytesValue":		v, err := base64.StdEncoding.DecodeString(value)		if err != nil {			return protoreflect.Value{}, err		}		msg = wrapperspb.Bytes(v)	case "google.protobuf.FieldMask":		fm := &field_mask.FieldMask{}		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)		msg = fm	default:		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(md.FullName()))	}	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil}
 |