| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 | package handlerimport (	"context"	"errors"	"strings"	context2 "git.ikuban.com/server/kratos-utils/v2/transport/http/context"	"git.ikuban.com/server/kratos-utils/v2/transport/middleware"	"github.com/go-kratos/kratos/v2/transport/http"	"google.golang.org/grpc"	"google.golang.org/grpc/metadata")func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {	return func(ctx http.Context) error {		http.SetOperation(ctx, option.Path)		dec := func(in any) error {			if err := ctx.Bind(&in); err != nil {				return err			}			if err := ctx.BindQuery(&in); err != nil {				return err			}			return nil		}		httpCtx := ctx		h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {			streamCtx := context2.NewStreamContext(ctx)			err := stream.Handler(srv, newServerStream(streamCtx, dec, httpCtx.Response()))			return nil, err		})		newCtx := middleware.NewOptionContext(ctx, option)		_, err := h(newCtx, nil)		if err != nil {			return err		}		return nil	}}var _ grpc.ServerStream = (*serverStream)(nil)type serverStream struct {	ctx          context2.StreamContext	dec          func(interface{}) error	w            http.ResponseWriter	metadata     metadata.MD	isSendHeader bool}func newServerStream(ctx context2.StreamContext, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {	s := &serverStream{		ctx: ctx,		dec: dec,		w:   w,	}	return s}// 设置header,SendHeader主动发送或者第一次SendMsg时发送func (s *serverStream) SetHeader(md metadata.MD) error {	if md == nil {		return nil	}	s.metadata = metadata.Join(s.metadata, md)	return nil}// 主动发送header (只调一次)func (s *serverStream) SendHeader(md metadata.MD) error {	if s.isSendHeader {		return errors.New("header has been sent")	}	s.isSendHeader = true	if err := s.SetHeader(md); err != nil {		return err	}	// 没有指定响应头则设置默认响应头	if s.metadata.Len() == 0 {		err := s.SetHeader(metadata.MD{			"Content-Type":      []string{"text/event-stream"},			"Cache-Control":     []string{"no-cache"},			"Connection":        []string{"keep-alive"},			"X-Accel-Buffering": []string{"no"},		})		if err != nil {			return err		}	}	for k, v := range s.metadata {		if len(v) == 0 {			continue		}		s.w.Header().Set(k, strings.Join(v, "; "))	}	return nil}func (s *serverStream) SetTrailer(md metadata.MD) {	return}func (s *serverStream) Context() context.Context {	return s.ctx}func (s *serverStream) SendMsg(m interface{}) error {	if !s.isSendHeader {		// 设置流式响应 headers		err := s.SendHeader(nil)		if err != nil {			return err		}	}	switch data := m.(type) {	case []byte:		_, err := s.w.Write(data)		if err != nil {			return err		}	case string:		_, err := s.w.Write([]byte(data))		if err != nil {			return err		}	default:		// 使用自定义序列化逻辑		serialized, err := s.ctx.Serialize(data)		if err != nil {			return err		}		_, err = s.w.Write(append(serialized, byte('\n')))		if err != nil {			return err		}	}	s.w.(http.Flusher).Flush()	return nil}func (s *serverStream) RecvMsg(m any) error {	return s.dec(m)}
 |