|
@@ -0,0 +1,71 @@
|
|
|
|
|
+package middleware
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "context"
|
|
|
|
|
+ "strconv"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+
|
|
|
|
|
+ "git.ikuban.com/server/kratos-utils/common"
|
|
|
|
|
+
|
|
|
|
|
+ "github.com/go-kratos/kratos/v2/errors"
|
|
|
|
|
+ "github.com/go-kratos/kratos/v2/middleware"
|
|
|
|
|
+ "github.com/go-kratos/kratos/v2/transport/http"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+const (
|
|
|
|
|
+ Bearer = "Bearer "
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+var (
|
|
|
|
|
+ nowAuthURI = make(map[string]bool)
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+func Auth(handler middleware.Handler) middleware.Handler {
|
|
|
|
|
+ return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
|
|
|
|
+ if info, ok := http.FromServerContext(ctx); ok {
|
|
|
|
|
+ uri := info.Request.RequestURI
|
|
|
|
|
+ if _, ok1 := nowAuthURI[uri]; ok1 {
|
|
|
|
|
+ return handler(ctx, req)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ token := info.Request.Header.Get("Authorization")
|
|
|
|
|
+ if token == "" {
|
|
|
|
|
+ return nil, unauthorized()
|
|
|
|
|
+ }
|
|
|
|
|
+ if strings.Contains(token, Bearer) {
|
|
|
|
|
+ token = strings.Replace(token, Bearer, "", 1)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ return nil, unauthorized()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ uid, _, claimMap, err := common.DefaultJWT.Parse(token)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, unauthorized()
|
|
|
|
|
+ }
|
|
|
|
|
+ var userID int64
|
|
|
|
|
+ if uid != "" {
|
|
|
|
|
+ userID, _ = strconv.ParseInt(uid, 10, 64)
|
|
|
|
|
+ }
|
|
|
|
|
+ if userID <= 0 {
|
|
|
|
|
+ return nil, unauthorized()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ ctx = context.WithValue(ctx, "user_id", userID)
|
|
|
|
|
+ ctx = context.WithValue(ctx, "jwt_claims", claimMap)
|
|
|
|
|
+ ctx = context.WithValue(ctx, "auth_token", token)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return handler(ctx, req)
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func unauthorized() error {
|
|
|
|
|
+ return &errors.StatusError{
|
|
|
|
|
+ Code: 401,
|
|
|
|
|
+ Message: "Unauthorized",
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func AddNotAuthURI(r string) {
|
|
|
|
|
+ nowAuthURI[r] = true
|
|
|
|
|
+}
|