| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- package handler
- import (
- "context"
- "encoding/json"
- "errors"
- "strings"
- "git.ikuban.com/server/kratos-utils/v2/transport/middleware"
- "github.com/go-kratos/kratos/v2/transport/http"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
- )
- func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
- return func(ctx http.Context) error {
- http.SetOperation(ctx, option.Path)
- dec := func(in any) error {
- if err := ctx.Bind(&in); err != nil {
- return err
- }
- if err := ctx.BindQuery(&in); err != nil {
- return err
- }
- return nil
- }
- httpCtx := ctx
- h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {
- err := stream.Handler(srv, newServerStream(ctx, dec, httpCtx.Response()))
- return nil, err
- })
- newCtx := middleware.NewOptionContext(ctx, option)
- _, err := h(newCtx, nil)
- if err != nil {
- return err
- }
- return nil
- }
- }
- var _ grpc.ServerStream = (*serverStream)(nil)
- type serverStream struct {
- ctx context.Context
- dec func(interface{}) error
- w http.ResponseWriter
- metadata metadata.MD
- isSendHeader bool
- }
- func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
- s := &serverStream{
- ctx: ctx,
- dec: dec,
- w: w,
- }
- return s
- }
- // 设置header,SendHeader主动发送或者第一次SendMsg时发送
- func (s *serverStream) SetHeader(md metadata.MD) error {
- if md == nil {
- return nil
- }
- s.metadata = metadata.Join(s.metadata, md)
- return nil
- }
- // 主动发送header (只调一次)
- func (s *serverStream) SendHeader(md metadata.MD) error {
- if s.isSendHeader {
- return errors.New("header has been sent")
- }
- s.isSendHeader = true
- if err := s.SetHeader(md); err != nil {
- return err
- }
- for k, v := range s.metadata {
- if len(v) == 0 {
- continue
- }
- s.w.Header().Set(k, strings.Join(v, "; "))
- }
- return nil
- }
- func (s *serverStream) SetTrailer(md metadata.MD) {
- return
- }
- func (s *serverStream) Context() context.Context {
- return s.ctx
- }
- func (s *serverStream) SendMsg(m interface{}) error {
- if !s.isSendHeader {
- err := s.SendHeader(nil)
- if err != nil {
- return err
- }
- }
- switch data := m.(type) {
- case []byte:
- _, err := s.w.Write(data)
- if err != nil {
- return err
- }
- default:
- j, err := json.Marshal(data)
- if err != nil {
- return err
- }
- _, err = s.w.Write(append(j, byte('\n')))
- if err != nil {
- return err
- }
- }
- s.w.(http.Flusher).Flush()
- return nil
- }
- func (s *serverStream) RecvMsg(m any) error {
- return s.dec(m)
- }
|