package http

import (
	"bytes"
	json2 "encoding/json"
	"io"
	"net/http"
	"strconv"
	"strings"

	"git.ikuban.com/server/kratos-utils/v2/codes"
	"git.ikuban.com/server/kratos-utils/v2/transport/http/reply"
	"github.com/go-kratos/kratos/v2/errors"

	"github.com/go-kratos/kratos/v2/encoding"
	"google.golang.org/protobuf/types/known/emptypb"

	"git.ikuban.com/server/kratos-utils/v2/transport/http/binding"
	"git.ikuban.com/server/kratos-utils/v2/transport/http/encoding/json"
	// Blank import for its init side effect only (presumably registers the
	// proto codec with the kratos encoding registry).
	_ "github.com/go-kratos/kratos/v2/encoding/proto"
)

// DecodeRequest decodes the request into v.
//
// For POST, PUT and DELETE requests the Content-Type header selects the decoder:
//   - multipart/form-data is handled by parseForm;
//   - when v is *emptypb.Empty, binding.BindForm is used instead of a body codec;
//   - otherwise the codec registered for the Content-Type subtype (e.g. "json"
//     for "application/json") unmarshals the whole body.
//
// Every other method, and any body whose subtype has no registered codec,
// falls back to binding.BindForm. A nil v is a no-op.
func DecodeRequest(req *http.Request, v interface{}) error {
	if v == nil {
		return nil
	}
	method := strings.ToUpper(req.Method)
	if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete {
		ct := req.Header.Get(ContentTypeHeader)
		if strings.HasPrefix(ct, "multipart/form-data") {
			return parseForm(req, v)
		}
		if _, ok := v.(*emptypb.Empty); ok {
			return binding.BindForm(req, v)
		}
		if codec := encoding.GetCodec(contentSubtype(ct)); codec != nil {
			data, err := io.ReadAll(req.Body)
			if err != nil {
				return err
			}
			return codec.Unmarshal(data, v)
		}
	}
	return binding.BindForm(req, v)
}

// parseForm parses a multipart/form-data request and decodes it into v by
// round-tripping the fields through JSON. The 32 << 20 argument (32 MiB) is
// the maxMemory passed to ParseMultipartForm.
//
// For each file field only the first file is used: its content is stored
// under the field name (as []byte, so it is base64-encoded in the
// intermediate JSON) and its original filename under "<field>Filename".
// For each value field only the first value is kept. Value fields are applied
// after file fields, so a value with the same key overwrites a file entry.
func parseForm(req *http.Request, v interface{}) error {
	if err := req.ParseMultipartForm(32 << 20); err != nil {
		return err
	}
	form := req.MultipartForm
	if form == nil {
		return nil
	}

	value := make(map[string]interface{})
	for field, headers := range form.File {
		if len(headers) == 0 {
			continue
		}
		f, err := headers[0].Open()
		if err != nil {
			return err
		}
		var buf bytes.Buffer
		_, err = buf.ReadFrom(f)
		// Close before checking the read error so the handle is released on
		// every path (a defer here would keep all files open until return).
		// The close error of a read-only file carries no useful information.
		_ = f.Close()
		if err != nil {
			return err
		}
		value[field] = buf.Bytes()
		value[field+"Filename"] = headers[0].Filename
	}
	for field, vals := range form.Value {
		if len(vals) == 0 {
			continue
		}
		value[field] = vals[0]
	}

	j, err := json2.Marshal(value)
	if err != nil {
		return err
	}
	return json2.Unmarshal(j, v)
}

// EncodeResponse encodes the object to the HTTP response.
func EncodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error { if v == nil { return nil } codec := codecForRequest(r) data, err := codec.Marshal(v) if err != nil { return err } w.Header().Set(ContentTypeHeader, contentType(codec.Name())) _, _ = w.Write(data) return nil } func ErrHandle(w http.ResponseWriter, r *http.Request, err error) { if err == nil { return } st := errors.FromError(err) if st == nil { st = codes.Error(10500, err.Error()).(*errors.Error) } message := st.Message metadata := st.GetMetadata() code := st.Code if _, ok := metadata["code"]; ok { _code, _ := strconv.Atoi(metadata["code"]) code = int32(_code) } if code == 1000 { return } w.Header().Set(ContentTypeHeader, "application/json; charset=utf-8") if code == 0 { w.WriteHeader(200) } else if code >= 301 && code <= 307 { w.WriteHeader(int(code)) http.Redirect(w, r, message, int(code)) return } else if code == 401 { w.WriteHeader(401) code = 10401 } else { if code < 10000 { code = 10000 + code } if code < 10100 && message == "" { message = "系统错误" } w.WriteHeader(400) } _reply := reply.SuccessReply{ Code: code, Message: message, Data: nil, } if st.Reason != "" { var _data interface{} json2.Unmarshal([]byte(st.Reason), &_data) _reply.Data = _data } data, err := json2.Marshal(_reply) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Write(data) } const baseContentType = "application" var ( acceptHeader = http.CanonicalHeaderKey("Accept") ContentTypeHeader = http.CanonicalHeaderKey("Content-Type") ) func contentType(subtype string) string { return strings.Join([]string{baseContentType, subtype}, "/") } // codecForRequest get encoding.Codec via http.Request func codecForRequest(r *http.Request) encoding.Codec { var codec encoding.Codec for _, accept := range r.Header[acceptHeader] { codeName := contentSubtype(accept) if codeName == "json" { codec = encoding.GetCodec(json.Name) break } if codec = encoding.GetCodec(codeName); codec != nil { break } } if codec == 
nil { codec = encoding.GetCodec(json.Name) } return codec } func contentSubtype(contentType string) string { if contentType == baseContentType { return "" } if !strings.HasPrefix(contentType, baseContentType) { return "" } switch contentType[len(baseContentType)] { case '/', ';': if i := strings.Index(contentType, ";"); i != -1 { return contentType[len(baseContentType)+1 : i] } return contentType[len(baseContentType)+1:] default: return "" } } func GetBody(r *http.Request) []byte { b, _ := io.ReadAll(r.Body) return b }