Sfoglia il codice sorgente

feat(http): 优化 serverStreamHandler 中的 header 设置和发送逻辑- 增加 isSendHeader 标志位,用于跟踪 header 是否已发送
- 重构 SetHeader 方法,支持合并元数据- 优化 SendHeader 方法,确保 header 只发送一次
- 在 SendMsg 方法中检查是否需要发送 header
- 修复了一些潜在的错误和逻辑问题

lihf 5 mesi fa
parent
commit
6433d9f9c4
1 ha cambiato i file con 30 aggiunte e 8 eliminazioni
  1. 30 8
      http/handler/server_stream_handler.go

+ 30 - 8
http/handler/server_stream_handler.go

@@ -3,10 +3,12 @@ package handler
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"git.ikuban.com/server/kratos-utils/v2/http/middleware"
 	"github.com/go-kratos/kratos/v2/transport/http"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
+	"strings"
 )
 
 func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
@@ -42,10 +44,11 @@ func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Opt
 var _ grpc.ServerStream = (*serverStream)(nil)
 
 type serverStream struct {
-	ctx      context.Context
-	dec      func(interface{}) error
-	w        http.ResponseWriter
-	metadata metadata.MD
+	ctx          context.Context
+	dec          func(interface{}) error
+	w            http.ResponseWriter
+	metadata     metadata.MD
+	isSendHeader bool
 }
 
 func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
@@ -58,19 +61,32 @@ func newServerStream(ctx context.Context, dec func(interface{}) error, w http.Re
 	return s
 }
 
-// 设置header
+// 设置header,SendHeader主动发送或者第一次SendMsg时发送
 func (s *serverStream) SetHeader(md metadata.MD) error {
-	s.metadata = md
+	if md == nil {
+		return nil
+	}
+
+	s.metadata = metadata.Join(s.metadata, md)
 	return nil
 }
 
-// 发送header (只调一次)
+// 主动发送header (只调一次)
 func (s *serverStream) SendHeader(md metadata.MD) error {
+	if s.isSendHeader {
+		return errors.New("header has been sent")
+	}
+	s.isSendHeader = true
+
+	if err := s.SetHeader(md); err != nil {
+		return err
+	}
+
 	for k, v := range s.metadata {
 		if len(v) == 0 {
 			continue
 		}
-		s.w.Header().Set(k, v[0])
+		s.w.Header().Set(k, strings.Join(v, "; "))
 	}
 	return nil
 }
@@ -84,6 +100,12 @@ func (s *serverStream) Context() context.Context {
 }
 
 func (s *serverStream) SendMsg(m interface{}) error {
+	if !s.isSendHeader {
+		err := s.SendHeader(nil)
+		if err != nil {
+			return err
+		}
+	}
 
 	switch data := m.(type) {
 	case []byte: