فهرست منبع

feat(transport): 引入流式上下文并支持自定义序列化

- 新增 StreamContext 接口及其实现,用于管理流式传输上下文
- 支持通过 SetSerializer 设置全局自定义序列化器
- 在 serverStream 中使用 StreamContext 替代原生 context
-优化 SendMsg 方法,支持字符串直接写入和自定义序列化逻辑
- 发送流式响应时自动设置 SSE 相关 HTTP 头部信息
- 移除对 encoding/json 包的直接依赖,改用封装后的序列化方法
dcsunny 1 هفته پیش
والد
کامیت
b8d20985d6
2فایلهای تغییر یافته به همراه61 افزوده شده و 10 حذف شده
  1. 41 0
      transport/http/context/stream_context.go
  2. 20 10
      transport/http/handler/server_stream_handler.go

+ 41 - 0
transport/http/context/stream_context.go

@@ -0,0 +1,41 @@
+package context
+
+import (
+	"context"
+	"encoding/json"
+)
+
+type StreamContext interface {
+	context.Context
+	Serialize(any) ([]byte, error)
+}
+
+type streamContext struct {
+	context.Context
+}
+
+type Serialize func(any) ([]byte, error)
+
+var _ StreamContext = (*streamContext)(nil)
+
+var (
+	// serialize 是全局的自定义序列化器
+	serialize Serialize
+)
+
+func (s *streamContext) Serialize(v any) ([]byte, error) {
+	if serialize != nil {
+		return serialize(v)
+	}
+	return json.Marshal(v)
+}
+
+func NewStreamContext(ctx context.Context) StreamContext {
+	return &streamContext{
+		Context: ctx,
+	}
+}
+
+func SetSerializer(s Serialize) {
+	serialize = s
+}

+ 20 - 10
transport/http/handler/server_stream_handler.go

@@ -2,10 +2,10 @@ package handler
 
 import (
 	"context"
-	"encoding/json"
 	"errors"
 	"strings"
 
+	context2 "git.ikuban.com/server/kratos-utils/v2/transport/http/context"
 	"git.ikuban.com/server/kratos-utils/v2/transport/middleware"
 	"github.com/go-kratos/kratos/v2/transport/http"
 	"google.golang.org/grpc"
@@ -28,7 +28,8 @@ func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Opt
 
 		httpCtx := ctx
 		h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {
-			err := stream.Handler(srv, newServerStream(ctx, dec, httpCtx.Response()))
+			streamCtx := context2.NewStreamContext(ctx)
+			err := stream.Handler(srv, newServerStream(streamCtx, dec, httpCtx.Response()))
 			return nil, err
 		})
 
@@ -45,14 +46,14 @@ func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Opt
 var _ grpc.ServerStream = (*serverStream)(nil)
 
 type serverStream struct {
-	ctx          context.Context
+	ctx          context2.StreamContext
 	dec          func(interface{}) error
 	w            http.ResponseWriter
 	metadata     metadata.MD
 	isSendHeader bool
 }
 
-func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
+func newServerStream(ctx context2.StreamContext, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
 	s := &serverStream{
 		ctx: ctx,
 		dec: dec,
@@ -102,30 +103,39 @@ func (s *serverStream) Context() context.Context {
 
 func (s *serverStream) SendMsg(m interface{}) error {
 	if !s.isSendHeader {
-		err := s.SendHeader(nil)
+		// 设置流式响应 headers
+		err := s.SendHeader(metadata.MD{
+			"Content-Type":      []string{"text/event-stream"},
+			"Cache-Control":     []string{"no-cache"},
+			"Connection":        []string{"keep-alive"},
+			"X-Accel-Buffering": []string{"no"},
+		})
 		if err != nil {
 			return err
 		}
 	}
-
 	switch data := m.(type) {
 	case []byte:
 		_, err := s.w.Write(data)
 		if err != nil {
 			return err
 		}
-
+	case string:
+		_, err := s.w.Write([]byte(data))
+		if err != nil {
+			return err
+		}
 	default:
-		j, err := json.Marshal(data)
+		// 使用自定义序列化逻辑
+		serialized, err := s.ctx.Serialize(data)
 		if err != nil {
 			return err
 		}
-		_, err = s.w.Write(append(j, byte('\n')))
+		_, err = s.w.Write(append(serialized, byte('\n')))
 		if err != nil {
 			return err
 		}
 	}
-
 	s.w.(http.Flusher).Flush()
 	return nil
 }