package http import ( "bytes" json2 "encoding/json" "io/ioutil" "net/http" "strings" "github.com/go-kratos/kratos/v2/errors" "google.golang.org/protobuf/encoding/protojson" "github.com/go-kratos/kratos/v2/encoding" "google.golang.org/protobuf/types/known/emptypb" "git.ikuban.com/server/kratos-utils/http/binding" "git.ikuban.com/server/kratos-utils/http/encoding/json" _ "github.com/go-kratos/kratos/v2/encoding/proto" ) // decodeRequest decodes the request body to object. func DecodeRequest(req *http.Request, v interface{}) error { method := strings.ToUpper(req.Method) if method == "POST" || method == "PUT" || method == "DELETE" { contextType := req.Header.Get(ContentTypeHeader) if strings.HasPrefix(contextType, "multipart/form-data") { return parseForm(req, v) } if _, ok := v.(*emptypb.Empty); ok { return binding.BindForm(req, v) } subtype := contentSubtype(contextType) if codec := encoding.GetCodec(subtype); codec != nil { data, err := ioutil.ReadAll(req.Body) if err != nil { return err } return codec.Unmarshal(data, v) } } return binding.BindForm(req, v) } func parseForm(req *http.Request, v interface{}) error { err := req.ParseMultipartForm(32 << 20) if err != nil { return err } if req.MultipartForm == nil { return nil } value := make(map[string]interface{}) if req.MultipartForm.File != nil { for k1, v1 := range req.MultipartForm.File { f, err := v1[0].Open() if err != nil { return err } var buf bytes.Buffer _, err = buf.ReadFrom(f) if err != nil { return err } value[k1] = buf.Bytes() value[k1+"Filename"] = v1[0].Filename } } if req.MultipartForm.Value != nil { for k1, v1 := range req.MultipartForm.Value { value[k1] = v1 } } j, err := json2.Marshal(value) if err != nil { return err } err = json2.Unmarshal(j, v) return err } // encodeResponse encodes the object to the HTTP response. func EncodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error { codec := codecForRequest(r) data, err := codec.Marshal(v) if err != nil { return err } w.Header().Set(ContentTypeHeader, contentType(codec.Name())) _, _ = w.Write(data) return nil } func ErrHandle(w http.ResponseWriter, r *http.Request, err error) { st := errors.FromError(err) if st == nil { st = errors.New(10500, "", "", err.Error()) } status := st.GRPCStatus() code := status.Proto().Code if code == 1000 { return } w.Header().Set(ContentTypeHeader, "application/json; charset=utf-8") if code == 0 { w.WriteHeader(200) } else if code >= 301 && code <= 307 { w.WriteHeader(int(code)) http.Redirect(w, r, status.Message(), int(code)) return } else if code == 401 { w.WriteHeader(401) } else { if code < 10000 { code = 10000 + code } if code < 10100 && status.Message() == "" { status.Proto().Message = "系统错误" } w.WriteHeader(400) } data, err := protojson.Marshal(status.Proto()) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Write(data) //se := errors.FromError(err) //if !ok { // se = &errors.StatusError{ // Code: 10500, // Message: err.Error(), // } //} //if se.Code == -1 { // return //} //codec := codecForRequest(r) //w.Header().Set(ContentTypeHeader, contentType(codec.Name())) //if se.Code == 0 { // w.WriteHeader(200) //} else if se.Code >= 301 && se.Code <= 307 { // w.WriteHeader(int(se.Code)) // http.Redirect(w, r, se.Message, int(se.Code)) // return //} else { // if se.Code < 10000 { // se.Code = 10000 + se.Code // } // if se.Code < 10100 && se.Message == "" { // se.Message = "系统错误" // } // w.WriteHeader(400) //} //data, _ := codec.Marshal(se) //_, _ = w.Write(data) } const baseContentType = "application" var ( acceptHeader = http.CanonicalHeaderKey("Accept") ContentTypeHeader = http.CanonicalHeaderKey("Content-Type") ) func contentType(subtype string) string { return strings.Join([]string{baseContentType, subtype}, "/") } // codecForRequest get encoding.Codec via http.Request func codecForRequest(r *http.Request) encoding.Codec { var codec encoding.Codec for _, accept := range r.Header[acceptHeader] { codeName := contentSubtype(accept) if codeName == "json" { codec = encoding.GetCodec(json.Name) break } if codec = encoding.GetCodec(codeName); codec != nil { break } } if codec == nil { codec = encoding.GetCodec(json.Name) } return codec } func contentSubtype(contentType string) string { if contentType == baseContentType { return "" } if !strings.HasPrefix(contentType, baseContentType) { return "" } switch contentType[len(baseContentType)] { case '/', ';': if i := strings.Index(contentType, ";"); i != -1 { return contentType[len(baseContentType)+1 : i] } return contentType[len(baseContentType)+1:] default: return "" } }