|
|
@@ -3,10 +3,12 @@ package handler
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"git.ikuban.com/server/kratos-utils/v2/http/middleware"
|
|
|
"github.com/go-kratos/kratos/v2/transport/http"
|
|
|
"google.golang.org/grpc"
|
|
|
"google.golang.org/grpc/metadata"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
|
|
|
@@ -42,10 +44,11 @@ func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Opt
|
|
|
var _ grpc.ServerStream = (*serverStream)(nil)
|
|
|
|
|
|
type serverStream struct {
|
|
|
- ctx context.Context
|
|
|
- dec func(interface{}) error
|
|
|
- w http.ResponseWriter
|
|
|
- metadata metadata.MD
|
|
|
+ ctx context.Context
|
|
|
+ dec func(interface{}) error
|
|
|
+ w http.ResponseWriter
|
|
|
+ metadata metadata.MD
|
|
|
+ isSendHeader bool
|
|
|
}
|
|
|
|
|
|
func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
|
|
|
@@ -58,19 +61,32 @@ func newServerStream(ctx context.Context, dec func(interface{}) error, w http.Re
|
|
|
return s
|
|
|
}
|
|
|
|
|
|
-// 设置header
|
|
|
+// 设置header,SendHeader主动发送或者第一次SendMsg时发送
|
|
|
func (s *serverStream) SetHeader(md metadata.MD) error {
|
|
|
- s.metadata = md
|
|
|
+ if md == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ s.metadata = metadata.Join(s.metadata, md)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// 发送header (只调一次)
|
|
|
+// 主动发送header (只调一次)
|
|
|
func (s *serverStream) SendHeader(md metadata.MD) error {
|
|
|
+ if s.isSendHeader {
|
|
|
+ return errors.New("header has been sent")
|
|
|
+ }
|
|
|
+ s.isSendHeader = true
|
|
|
+
|
|
|
+ if err := s.SetHeader(md); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
for k, v := range s.metadata {
|
|
|
if len(v) == 0 {
|
|
|
continue
|
|
|
}
|
|
|
- s.w.Header().Set(k, v[0])
|
|
|
+ s.w.Header().Set(k, strings.Join(v, "; "))
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
@@ -84,6 +100,12 @@ func (s *serverStream) Context() context.Context {
|
|
|
}
|
|
|
|
|
|
func (s *serverStream) SendMsg(m interface{}) error {
|
|
|
+ if !s.isSendHeader {
|
|
|
+ err := s.SendHeader(nil)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
switch data := m.(type) {
|
|
|
case []byte:
|