dcsunny 4 жил өмнө
parent
commit
f79996e219

+ 18 - 0
http/binding/bind.go

@@ -0,0 +1,18 @@
+package binding
+
+import (
+	"net/http"
+
+	"google.golang.org/protobuf/proto"
+)
+
+// BindForm bind form parameters to target.
+func BindForm(req *http.Request, target interface{}) error {
+	if err := req.ParseForm(); err != nil {
+		return err
+	}
+	if msg, ok := target.(proto.Message); ok {
+		return mapProto(msg, req.Form)
+	}
+	return mapForm(target, req.Form)
+}

+ 385 - 0
http/binding/form.go

@@ -0,0 +1,385 @@
+package binding
+
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"time"
+)
+
+var (
+	errUnknownType = errors.New("unknown type")
+	emptyField     = reflect.StructField{}
+)
+
+func mapForm(ptr interface{}, form map[string][]string) error {
+	return mapFormByTag(ptr, form, "json")
+}
+
+func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error {
+	ptrVal := reflect.ValueOf(ptr)
+	var pointed interface{}
+	if ptrVal.Kind() == reflect.Ptr {
+		ptrVal = ptrVal.Elem()
+		pointed = ptrVal.Interface()
+	}
+	if ptrVal.Kind() == reflect.Map &&
+		ptrVal.Type().Key().Kind() == reflect.String {
+		if pointed != nil {
+			ptr = pointed
+		}
+		return setFormMap(ptr, form)
+	}
+	return mappingByPtr(ptr, formSource(form), tag)
+}
+
+// setter tries to set value on a walking by fields of a struct
+type setter interface {
+	TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error)
+}
+
+type formSource map[string][]string
+
+var _ setter = formSource(nil)
+
+// TrySet tries to set a value by request's form source (like map[string][]string)
+func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSetted bool, err error) {
+	return setByForm(value, field, form, tagValue, opt)
+}
+
+func mappingByPtr(ptr interface{}, setter setter, tag string) error {
+	_, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag)
+	return err
+}
+
+func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
+	if field.Tag.Get(tag) == "-" { // just ignoring this field
+		return false, nil
+	}
+
+	var vKind = value.Kind()
+
+	if vKind == reflect.Ptr {
+		var isNew bool
+		vPtr := value
+		if value.IsNil() {
+			isNew = true
+			vPtr = reflect.New(value.Type().Elem())
+		}
+		isSetted, err := mapping(vPtr.Elem(), field, setter, tag)
+		if err != nil {
+			return false, err
+		}
+		if isNew && isSetted {
+			value.Set(vPtr)
+		}
+		return isSetted, nil
+	}
+
+	if vKind != reflect.Struct || !field.Anonymous {
+		ok, err := tryToSetValue(value, field, setter, tag)
+		if err != nil {
+			return false, err
+		}
+		if ok {
+			return true, nil
+		}
+	}
+
+	if vKind == reflect.Struct {
+		tValue := value.Type()
+
+		var isSetted bool
+		for i := 0; i < value.NumField(); i++ {
+			sf := tValue.Field(i)
+			if sf.PkgPath != "" && !sf.Anonymous { // unexported
+				continue
+			}
+			ok, err := mapping(value.Field(i), tValue.Field(i), setter, tag)
+			if err != nil {
+				return false, err
+			}
+			isSetted = isSetted || ok
+		}
+		return isSetted, nil
+	}
+	return false, nil
+}
+
+type setOptions struct {
+	isDefaultExists bool
+	defaultValue    string
+}
+
+func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
+	var tagValue string
+	var setOpt setOptions
+
+	tagValue = field.Tag.Get(tag)
+	tagValue, opts := head(tagValue, ",")
+
+	if tagValue == "" { // default value is FieldName
+		tagValue = field.Name
+	}
+	if tagValue == "" { // when field is "emptyField" variable
+		return false, nil
+	}
+
+	var opt string
+	for len(opts) > 0 {
+		opt, opts = head(opts, ",")
+
+		if k, v := head(opt, "="); k == "default" {
+			setOpt.isDefaultExists = true
+			setOpt.defaultValue = v
+		}
+	}
+
+	return setter.TrySet(value, field, tagValue, setOpt)
+}
+
+func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSetted bool, err error) {
+	vs, ok := form[tagValue]
+	if !ok && !opt.isDefaultExists {
+		return false, nil
+	}
+
+	switch value.Kind() {
+	case reflect.Slice:
+		if !ok {
+			vs = []string{opt.defaultValue}
+		}
+		return true, setSlice(vs, value, field)
+	case reflect.Array:
+		if !ok {
+			vs = []string{opt.defaultValue}
+		}
+		if len(vs) != value.Len() {
+			return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String())
+		}
+		return true, setArray(vs, value, field)
+	default:
+		var val string
+		if !ok {
+			val = opt.defaultValue
+		}
+
+		if len(vs) > 0 {
+			val = vs[0]
+		}
+		return true, setWithProperType(val, value, field)
+	}
+}
+
+func setWithProperType(val string, value reflect.Value, field reflect.StructField) error {
+	switch value.Kind() {
+	case reflect.Int:
+		return setIntField(val, 0, value)
+	case reflect.Int8:
+		return setIntField(val, 8, value)
+	case reflect.Int16:
+		return setIntField(val, 16, value)
+	case reflect.Int32:
+		return setIntField(val, 32, value)
+	case reflect.Int64:
+		switch value.Interface().(type) {
+		case time.Duration:
+			return setTimeDuration(val, value, field)
+		}
+		return setIntField(val, 64, value)
+	case reflect.Uint:
+		return setUintField(val, 0, value)
+	case reflect.Uint8:
+		return setUintField(val, 8, value)
+	case reflect.Uint16:
+		return setUintField(val, 16, value)
+	case reflect.Uint32:
+		return setUintField(val, 32, value)
+	case reflect.Uint64:
+		return setUintField(val, 64, value)
+	case reflect.Bool:
+		return setBoolField(val, value)
+	case reflect.Float32:
+		return setFloatField(val, 32, value)
+	case reflect.Float64:
+		return setFloatField(val, 64, value)
+	case reflect.String:
+		value.SetString(val)
+	case reflect.Struct:
+		switch value.Interface().(type) {
+		case time.Time:
+			return setTimeField(val, field, value)
+		}
+		return json.Unmarshal([]byte(val), value.Addr().Interface())
+	case reflect.Map:
+		return json.Unmarshal([]byte(val), value.Addr().Interface())
+	default:
+		return errUnknownType
+	}
+	return nil
+}
+
+func setIntField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0"
+	}
+	intVal, err := strconv.ParseInt(val, 10, bitSize)
+	if err == nil {
+		field.SetInt(intVal)
+	}
+	return err
+}
+
+func setUintField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0"
+	}
+	uintVal, err := strconv.ParseUint(val, 10, bitSize)
+	if err == nil {
+		field.SetUint(uintVal)
+	}
+	return err
+}
+
+func setBoolField(val string, field reflect.Value) error {
+	if val == "" {
+		val = "false"
+	}
+	boolVal, err := strconv.ParseBool(val)
+	if err == nil {
+		field.SetBool(boolVal)
+	}
+	return err
+}
+
+func setFloatField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0.0"
+	}
+	floatVal, err := strconv.ParseFloat(val, bitSize)
+	if err == nil {
+		field.SetFloat(floatVal)
+	}
+	return err
+}
+
+func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
+	timeFormat := structField.Tag.Get("time_format")
+	if timeFormat == "" {
+		timeFormat = time.RFC3339
+	}
+
+	switch tf := strings.ToLower(timeFormat); tf {
+	case "unix", "unixnano":
+		tv, err := strconv.ParseInt(val, 10, 64)
+		if err != nil {
+			return err
+		}
+
+		d := time.Duration(1)
+		if tf == "unixnano" {
+			d = time.Second
+		}
+
+		t := time.Unix(tv/int64(d), tv%int64(d))
+		value.Set(reflect.ValueOf(t))
+		return nil
+
+	}
+
+	if val == "" {
+		value.Set(reflect.ValueOf(time.Time{}))
+		return nil
+	}
+
+	l := time.Local
+	if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
+		l = time.UTC
+	}
+
+	if locTag := structField.Tag.Get("time_location"); locTag != "" {
+		loc, err := time.LoadLocation(locTag)
+		if err != nil {
+			return err
+		}
+		l = loc
+	}
+
+	t, err := time.ParseInLocation(timeFormat, val, l)
+	if err != nil {
+		return err
+	}
+
+	value.Set(reflect.ValueOf(t))
+	return nil
+}
+
+func setArray(vals []string, value reflect.Value, field reflect.StructField) error {
+	for i, s := range vals {
+		err := setWithProperType(s, value.Index(i), field)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func setSlice(vals []string, value reflect.Value, field reflect.StructField) error {
+	slice := reflect.MakeSlice(value.Type(), len(vals), len(vals))
+	err := setArray(vals, slice, field)
+	if err != nil {
+		return err
+	}
+	value.Set(slice)
+	return nil
+}
+
+func setTimeDuration(val string, value reflect.Value, field reflect.StructField) error {
+	d, err := time.ParseDuration(val)
+	if err != nil {
+		return err
+	}
+	value.Set(reflect.ValueOf(d))
+	return nil
+}
+
+func head(str, sep string) (head string, tail string) {
+	idx := strings.Index(str, sep)
+	if idx < 0 {
+		return str, ""
+	}
+	return str[:idx], str[idx+len(sep):]
+}
+
+func setFormMap(ptr interface{}, form map[string][]string) error {
+	el := reflect.TypeOf(ptr).Elem()
+
+	if el.Kind() == reflect.Slice {
+		ptrMap, ok := ptr.(map[string][]string)
+		if !ok {
+			return errors.New("cannot convert to map slices of strings")
+		}
+		for k, v := range form {
+			ptrMap[k] = v
+		}
+
+		return nil
+	}
+
+	ptrMap, ok := ptr.(map[string]string)
+	if !ok {
+		return errors.New("cannot convert to map of strings")
+	}
+	for k, v := range form {
+		ptrMap[k] = v[len(v)-1] // pick last
+	}
+
+	return nil
+}

+ 1 - 1
http/binding/proto.go

@@ -135,7 +135,7 @@ func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Val
 		}
 		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
 		if v == nil {
-			i, err := strconv.Atoi(value)
+			i, err := strconv.ParseInt(value, 10, 32)
 			if err != nil {
 				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
 			}

+ 2 - 3
http/handle.go

@@ -7,12 +7,11 @@ import (
 	"net/http"
 	"strings"
 
-	"google.golang.org/protobuf/types/known/emptypb"
-
 	"github.com/go-kratos/kratos/v2/encoding"
 	"github.com/go-kratos/kratos/v2/errors"
-	"github.com/go-kratos/kratos/v2/transport/http/binding"
+	"google.golang.org/protobuf/types/known/emptypb"
 
+	"git.ikuban.com/server/kratos-utils/http/binding"
 	"git.ikuban.com/server/kratos-utils/http/encoding/json"
 	_ "github.com/go-kratos/kratos/v2/encoding/proto"
 )