server_stream_handler.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package handler
  2. import (
  3. "context"
  4. "encoding/json"
  5. "git.ikuban.com/server/kratos-utils/v2/http/middleware"
  6. "github.com/go-kratos/kratos/v2/transport/http"
  7. "google.golang.org/grpc"
  8. "google.golang.org/grpc/metadata"
  9. )
  10. func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
  11. return func(ctx http.Context) error {
  12. http.SetOperation(ctx, option.Path)
  13. dec := func(in any) error {
  14. if err := ctx.Bind(&in); err != nil {
  15. return err
  16. }
  17. if err := ctx.BindQuery(&in); err != nil {
  18. return err
  19. }
  20. return nil
  21. }
  22. httpCtx := ctx
  23. h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {
  24. err := stream.Handler(srv, newServerStream(ctx, dec, httpCtx.Response()))
  25. return nil, err
  26. })
  27. newCtx := middleware.NewOptionContext(ctx, option)
  28. _, err := h(newCtx, nil)
  29. if err != nil {
  30. return err
  31. }
  32. return nil
  33. }
  34. }
  35. var _ grpc.ServerStream = (*serverStream)(nil)
  36. type serverStream struct {
  37. ctx context.Context
  38. dec func(interface{}) error
  39. w http.ResponseWriter
  40. metadata metadata.MD
  41. }
  42. func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
  43. s := &serverStream{
  44. ctx: ctx,
  45. dec: dec,
  46. w: w,
  47. }
  48. return s
  49. }
  50. // 设置header
  51. func (s *serverStream) SetHeader(md metadata.MD) error {
  52. s.metadata = md
  53. return nil
  54. }
  55. // 发送header (只调一次)
  56. func (s *serverStream) SendHeader(md metadata.MD) error {
  57. for k, v := range s.metadata {
  58. if len(v) == 0 {
  59. continue
  60. }
  61. s.w.Header().Set(k, v[0])
  62. }
  63. return nil
  64. }
  65. func (s *serverStream) SetTrailer(md metadata.MD) {
  66. return
  67. }
  68. func (s *serverStream) Context() context.Context {
  69. return s.ctx
  70. }
  71. func (s *serverStream) SendMsg(m interface{}) error {
  72. switch data := m.(type) {
  73. case []byte:
  74. _, err := s.w.Write(data)
  75. if err != nil {
  76. return err
  77. }
  78. default:
  79. j, err := json.Marshal(data)
  80. if err != nil {
  81. return err
  82. }
  83. _, err = s.w.Write(append(j, byte('\n')))
  84. if err != nil {
  85. return err
  86. }
  87. }
  88. s.w.(http.Flusher).Flush()
  89. return nil
  90. }
  91. func (s *serverStream) RecvMsg(m any) error {
  92. return s.dec(m)
  93. }