| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- package handler
- import (
- "context"
- "encoding/json"
- "git.ikuban.com/server/kratos-utils/v2/http/middleware"
- "github.com/go-kratos/kratos/v2/transport/http"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
- )
- func serverStreamHandler(srv any, stream grpc.StreamDesc, option *middleware.Option) http.HandlerFunc {
- return func(ctx http.Context) error {
- http.SetOperation(ctx, option.Path)
- dec := func(in any) error {
- if err := ctx.Bind(&in); err != nil {
- return err
- }
- if err := ctx.BindQuery(&in); err != nil {
- return err
- }
- return nil
- }
- httpCtx := ctx
- h := ctx.Middleware(func(ctx context.Context, _ interface{}) (interface{}, error) {
- err := stream.Handler(srv, newServerStream(ctx, dec, httpCtx.Response()))
- return nil, err
- })
- newCtx := middleware.NewOptionContext(ctx, option)
- _, err := h(newCtx, nil)
- if err != nil {
- return err
- }
- return nil
- }
- }
- var _ grpc.ServerStream = (*serverStream)(nil)
- type serverStream struct {
- ctx context.Context
- dec func(interface{}) error
- w http.ResponseWriter
- metadata metadata.MD
- }
- func newServerStream(ctx context.Context, dec func(interface{}) error, w http.ResponseWriter) grpc.ServerStream {
- s := &serverStream{
- ctx: ctx,
- dec: dec,
- w: w,
- }
- return s
- }
- // 设置header
- func (s *serverStream) SetHeader(md metadata.MD) error {
- s.metadata = md
- return nil
- }
- // 发送header (只调一次)
- func (s *serverStream) SendHeader(md metadata.MD) error {
- for k, v := range s.metadata {
- if len(v) == 0 {
- continue
- }
- s.w.Header().Set(k, v[0])
- }
- return nil
- }
- func (s *serverStream) SetTrailer(md metadata.MD) {
- return
- }
- func (s *serverStream) Context() context.Context {
- return s.ctx
- }
- func (s *serverStream) SendMsg(m interface{}) error {
- switch data := m.(type) {
- case []byte:
- _, err := s.w.Write(data)
- if err != nil {
- return err
- }
- default:
- j, err := json.Marshal(data)
- if err != nil {
- return err
- }
- _, err = s.w.Write(append(j, byte('\n')))
- if err != nil {
- return err
- }
- }
- s.w.(http.Flusher).Flush()
- return nil
- }
- func (s *serverStream) RecvMsg(m any) error {
- return s.dec(m)
- }
|