| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- package handler
- import (
- "context"
- "errors"
- "strings"
- context2 "git.ikuban.com/server/kratos-utils/v2/transport/http/context"
- "git.ikuban.com/server/kratos-utils/v2/transport/middleware"
- "github.com/go-kratos/kratos/v2/transport/http"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
- )
- func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
- return func(ctx http.Context) error {
- http.SetOperation(ctx, option.Path)
- dec := func(in any) error {
- if err := ctx.Bind(&in); err != nil {
- return err
- }
- if err := ctx.BindQuery(&in); err != nil {
- return err
- }
- return nil
- }
- httpCtx := ctx
- h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {
- streamCtx := context2.NewStreamContext(ctx)
- err := stream.Handler(srv, newServerStream(streamCtx, dec, httpCtx.Response()))
- return nil, err
- })
- newCtx := middleware.NewOptionContext(ctx, option)
- _, err := h(newCtx, nil)
- if err != nil {
- return err
- }
- return nil
- }
- }
- var _ grpc.ServerStream = (*serverStream)(nil)
- type serverStream struct {
- ctx context2.StreamContext
- dec func(interface{}) error
- w http.ResponseWriter
- metadata metadata.MD
- isSendHeader bool
- }
- func newServerStream(ctx context2.StreamContext, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
- s := &serverStream{
- ctx: ctx,
- dec: dec,
- w: w,
- }
- return s
- }
- // 设置header,SendHeader主动发送或者第一次SendMsg时发送
- func (s *serverStream) SetHeader(md metadata.MD) error {
- if md == nil {
- return nil
- }
- s.metadata = metadata.Join(s.metadata, md)
- return nil
- }
- // 主动发送header (只调一次)
- func (s *serverStream) SendHeader(md metadata.MD) error {
- if s.isSendHeader {
- return errors.New("header has been sent")
- }
- s.isSendHeader = true
- if err := s.SetHeader(md); err != nil {
- return err
- }
- // 没有指定响应头则设置默认响应头
- if s.metadata.Len() == 0 {
- err := s.SetHeader(metadata.MD{
- "Content-Type": []string{"text/event-stream"},
- "Cache-Control": []string{"no-cache"},
- "Connection": []string{"keep-alive"},
- "X-Accel-Buffering": []string{"no"},
- })
- if err != nil {
- return err
- }
- }
- for k, v := range s.metadata {
- if len(v) == 0 {
- continue
- }
- s.w.Header().Set(k, strings.Join(v, "; "))
- }
- return nil
- }
- func (s *serverStream) SetTrailer(md metadata.MD) {
- return
- }
- func (s *serverStream) Context() context.Context {
- return s.ctx
- }
- func (s *serverStream) SendMsg(m interface{}) error {
- if !s.isSendHeader {
- // 设置流式响应 headers
- err := s.SendHeader(nil)
- if err != nil {
- return err
- }
- }
- switch data := m.(type) {
- case []byte:
- _, err := s.w.Write(data)
- if err != nil {
- return err
- }
- case string:
- _, err := s.w.Write([]byte(data))
- if err != nil {
- return err
- }
- default:
- // 使用自定义序列化逻辑
- serialized, err := s.ctx.Serialize(data)
- if err != nil {
- return err
- }
- _, err = s.w.Write(append(serialized, byte('\n')))
- if err != nil {
- return err
- }
- }
- s.w.(http.Flusher).Flush()
- return nil
- }
- func (s *serverStream) RecvMsg(m any) error {
- return s.dec(m)
- }
|