package handler

import (
	"context"
	"encoding/json"
	"errors"

	"git.ikuban.com/server/kratos-utils/v2/transport/middleware"
	"github.com/go-kratos/kratos/v2/transport/http"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// serverStreamHandler adapts a gRPC server-streaming handler (stream.Handler)
// to a kratos HTTP handler.
//
// The operation is set to option.Path, option is attached to the context via
// middleware.NewOptionContext, and stream.Handler runs inside the HTTP
// context's middleware chain. When the handler calls RecvMsg, the message is
// bound from the HTTP request body and then from the query string; every
// message it sends is written to the HTTP response (see serverStream).
func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
	return func(httpCtx http.Context) error {
		http.SetOperation(httpCtx, option.Path)

		// dec fills the value given to RecvMsg: body first, then query
		// parameters. RecvMsg's argument must already be a pointer for the
		// decoded data to reach the caller, so it is passed to the binders
		// directly instead of being wrapped in a pointer to an interface.
		dec := func(in any) error {
			if err := httpCtx.Bind(in); err != nil {
				return err
			}
			return httpCtx.BindQuery(in)
		}

		h := httpCtx.Middleware(func(ctx context.Context, _ any) (any, error) {
			// Build the stream on the context produced by the middleware
			// chain so values added by middleware are visible to the handler.
			return nil, stream.Handler(srv, newServerStream(ctx, dec, httpCtx.Response()))
		})

		_, err := h(middleware.NewOptionContext(httpCtx, option), nil)
		return err
	}
}

var _ grpc.ServerStream = (*serverStream)(nil)

// serverStream is a grpc.ServerStream backed by an HTTP request/response
// pair: incoming messages are produced by dec and outgoing messages are
// written to w. It has no internal locking, so it must not be used from
// multiple goroutines concurrently.
type serverStream struct {
	ctx context.Context
	dec func(any) error
	w   http.ResponseWriter

	// metadata accumulates header metadata until it is copied into w's
	// headers by SendHeader.
	metadata metadata.MD
	// isSendHeader reports whether SendHeader has already run; afterwards
	// header metadata can no longer be changed.
	isSendHeader bool
}

// newServerStream returns a serverStream that reports ctx from Context,
// decodes RecvMsg arguments with dec and writes outgoing messages to w.
func newServerStream(ctx context.Context, dec func(any) error, w http.ResponseWriter) grpc.ServerStream {
	return &serverStream{
		ctx: ctx,
		dec: dec,
		w:   w,
	}
}

// SetHeader merges md into the pending header metadata. It may be called
// multiple times; the merged metadata is written to the HTTP response
// headers when SendHeader is called explicitly or on the first SendMsg.
// After that point, SetHeader with non-empty md returns an error (as
// grpc.ServerStream does) rather than silently dropping the metadata.
func (s *serverStream) SetHeader(md metadata.MD) error {
	if md.Len() == 0 {
		return nil
	}
	if s.isSendHeader {
		return errors.New("header has been sent")
	}
	s.metadata = metadata.Join(s.metadata, md)
	return nil
}

// SendHeader merges md into the pending header metadata and copies all of it
// into the HTTP response headers. It may be called at most once; SendMsg
// calls it implicitly before the first message.
//
// Each metadata value is emitted as its own header field line rather than
// being joined into a single value: joined values cannot be split reliably,
// and some headers (e.g. Set-Cookie) must never be combined. Existing
// response headers with the same name are replaced.
func (s *serverStream) SendHeader(md metadata.MD) error {
	if s.isSendHeader {
		return errors.New("header has been sent")
	}
	// Merge before marking as sent, so SetHeader's "already sent" check does
	// not reject this call's own metadata.
	s.metadata = metadata.Join(s.metadata, md)
	s.isSendHeader = true

	header := s.w.Header()
	for key, values := range s.metadata {
		if len(values) == 0 {
			continue
		}
		header.Del(key)
		for _, v := range values {
			header.Add(key, v)
		}
	}
	return nil
}

// SetTrailer discards md: trailer metadata is not supported by this HTTP
// stream.
func (s *serverStream) SetTrailer(_ metadata.MD) {}

// Context returns the context the stream was created with (in
// serverStreamHandler, the context produced by the middleware chain).
func (s *serverStream) Context() context.Context {
	return s.ctx
}

// SendMsg writes m to the HTTP response and flushes it so the client can
// receive it immediately. Pending header metadata is sent before the first
// message. A []byte is written verbatim; any other value is encoded with
// encoding/json and followed by a newline, so a sequence of messages forms
// newline-delimited JSON.
//
// NOTE(review): encoding/json does not follow the protobuf JSON mapping
// (protojson). If m is a generated proto message, field names and types such
// as int64/enums may be rendered differently than by protojson — confirm this
// is intended.
func (s *serverStream) SendMsg(m any) error {
	if !s.isSendHeader {
		if err := s.SendHeader(nil); err != nil {
			return err
		}
	}

	switch data := m.(type) {
	case []byte:
		if _, err := s.w.Write(data); err != nil {
			return err
		}
	default:
		// Encode marshals completely before its single Write and appends
		// '\n', so a marshalling error writes nothing to the response.
		if err := json.NewEncoder(s.w).Encode(data); err != nil {
			return err
		}
	}

	// Flushing is best-effort: a writer that does not implement Flusher (for
	// example a wrapping ResponseWriter) still receives the data, it is just
	// not pushed to the client immediately. An unchecked assertion would
	// panic here instead.
	if f, ok := s.w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}

// RecvMsg fills m using the stream's decoder (in serverStreamHandler: the
// request body, then the query string). m should be a pointer to the message.
func (s *serverStream) RecvMsg(m any) error {
	return s.dec(m)
}