Browse Source

新增context的param解析

dcsunny 4 years ago
parent
commit
58fc46cefc
6 changed files with 85 additions and 1 deletions
  1. 1 1
      http/encoding/jsonpb/json.go
  2. 24 0
      http/handle.go
  3. 7 0
      http/middleware/rpc_value.go
  4. 30 0
      http/param/auth.go
  5. 14 0
      http/param/base.go
  6. 9 0
      http/param/page.go

+ 1 - 1
http/encoding/jsonpb/json.go

@@ -13,7 +13,7 @@ import (
 )
 
 // Name is the name registered for the json codec.
-const Name = "jsonpb"
+const Name = "json"
 
 var jsonpbMarshaler *jsonpb.Marshaler
 

+ 24 - 0
http/handle.go

@@ -1,9 +1,14 @@
 package http
 
 import (
+	"io/ioutil"
 	"net/http"
 	"strings"
 
+	"google.golang.org/protobuf/types/known/emptypb"
+
+	"github.com/go-kratos/kratos/v2/transport/http/binding"
+
 	"github.com/go-kratos/kratos/v2/encoding"
 	"github.com/go-kratos/kratos/v2/errors"
 
@@ -11,6 +16,25 @@ import (
 	_ "github.com/go-kratos/kratos/v2/transport"
 )
 
+// decodeRequest decodes the request body to object.
+func DecodeRequest(req *http.Request, v interface{}) error {
+	method := strings.ToUpper(req.Method)
+	if method == "POST" || method == "PUT" || method == "DELETE" {
+		if _, ok := v.(*emptypb.Empty); ok {
+			return binding.BindForm(req, v)
+		}
+		subtype := contentSubtype(req.Header.Get(ContentTypeHeader))
+		if codec := encoding.GetCodec(subtype); codec != nil {
+			data, err := ioutil.ReadAll(req.Body)
+			if err != nil {
+				return err
+			}
+			return codec.Unmarshal(data, v)
+		}
+	}
+	return binding.BindForm(req, v)
+}
+
 func ErrHandle(w http.ResponseWriter, r *http.Request, err error) {
 	se, ok := errors.FromError(err)
 	if !ok {

+ 7 - 0
http/middleware/rpc_value.go

@@ -3,6 +3,7 @@ package middleware
 import (
 	"context"
 	"encoding/json"
+	"net/url"
 	"strconv"
 
 	"google.golang.org/grpc/metadata"
@@ -40,6 +41,12 @@ func GrpcValue(handler middleware.Handler) middleware.Handler {
 				_pageSize, _ := strconv.ParseInt(pageSize[0], 10, 64)
 				ctx = context.WithValue(ctx, "page_size", _pageSize)
 			}
+
+			query := md.Get("query")
+			if len(query) > 0 {
+				_query, _ := url.ParseQuery(query[0])
+				ctx = context.WithValue(ctx, "query", _query)
+			}
 		}
 		return handler(ctx, req)
 	}

+ 30 - 0
http/param/auth.go

@@ -0,0 +1,30 @@
+package param
+
+import (
+	"context"
+	"encoding/json"
+)
+
+func GetUserID(c context.Context) int64 {
+	userID := c.Value("user_id")
+	if _, ok := userID.(int64); ok {
+		return userID.(int64)
+	}
+	return 0
+}
+
+func GetAuthToken(c context.Context) string {
+	token := c.Value("token")
+	if _, ok := token.(string); ok {
+		return token.(string)
+	}
+	return ""
+}
+
+func GetJwtClaims(c context.Context) json.RawMessage {
+	claim := c.Value("claim")
+	if _, ok := claim.(json.RawMessage); ok {
+		return claim.(json.RawMessage)
+	}
+	return []byte{}
+}

+ 14 - 0
http/param/base.go

@@ -0,0 +1,14 @@
+package param
+
+import (
+	"context"
+	"net/url"
+)
+
+func GetQuery(c context.Context) url.Values {
+	query := c.Value("query")
+	if _, ok := query.(url.Values); ok {
+		return query.(url.Values)
+	}
+	return url.Values{}
+}

+ 9 - 0
http/param/page.go

@@ -0,0 +1,9 @@
+package param
+
+import "context"
+
+func GetPage(c context.Context) (int64, int64) {
+	page := c.Value("page")
+	pageSize := c.Value("page_size")
+	return page.(int64), pageSize.(int64)
+}