package handler import ( "context" "errors" "strings" context2 "git.ikuban.com/server/kratos-utils/v2/transport/http/context" "git.ikuban.com/server/kratos-utils/v2/transport/middleware" "github.com/go-kratos/kratos/v2/transport/http" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc { return func(ctx http.Context) error { http.SetOperation(ctx, option.Path) dec := func(in any) error { if err := ctx.Bind(&in); err != nil { return err } if err := ctx.BindQuery(&in); err != nil { return err } return nil } httpCtx := ctx h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) { streamCtx := context2.NewStreamContext(ctx) err := stream.Handler(srv, newServerStream(streamCtx, dec, httpCtx.Response())) return nil, err }) newCtx := middleware.NewOptionContext(ctx, option) _, err := h(newCtx, nil) if err != nil { return err } return nil } } var _ grpc.ServerStream = (*serverStream)(nil) type serverStream struct { ctx context2.StreamContext dec func(interface{}) error w http.ResponseWriter metadata metadata.MD isSendHeader bool } func newServerStream(ctx context2.StreamContext, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream { s := &serverStream{ ctx: ctx, dec: dec, w: w, } return s } // 设置header,SendHeader主动发送或者第一次SendMsg时发送 func (s *serverStream) SetHeader(md metadata.MD) error { if md == nil { return nil } s.metadata = metadata.Join(s.metadata, md) return nil } // 主动发送header (只调一次) func (s *serverStream) SendHeader(md metadata.MD) error { if s.isSendHeader { return errors.New("header has been sent") } s.isSendHeader = true if err := s.SetHeader(md); err != nil { return err } // 没有指定响应头则设置默认响应头 if s.metadata.Len() == 0 { err := s.SetHeader(metadata.MD{ "Content-Type": []string{"text/event-stream"}, "Cache-Control": []string{"no-cache"}, "Connection": []string{"keep-alive"}, "X-Accel-Buffering": []string{"no"}, }) if err != nil { return err } } for k, v := range s.metadata { if len(v) == 0 { continue } s.w.Header().Set(k, strings.Join(v, "; ")) } return nil } func (s *serverStream) SetTrailer(md metadata.MD) { return } func (s *serverStream) Context() context.Context { return s.ctx } func (s *serverStream) SendMsg(m interface{}) error { if !s.isSendHeader { // 设置流式响应 headers err := s.SendHeader(nil) if err != nil { return err } } switch data := m.(type) { case []byte: _, err := s.w.Write(data) if err != nil { return err } case string: _, err := s.w.Write([]byte(data)) if err != nil { return err } default: // 使用自定义序列化逻辑 serialized, err := s.ctx.Serialize(data) if err != nil { return err } _, err = s.w.Write(append(serialized, byte('\n'))) if err != nil { return err } } s.w.(http.Flusher).Flush() return nil } func (s *serverStream) RecvMsg(m any) error { return s.dec(m) }