server.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package mcp
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "github.com/go-kratos/kratos/v2/log"
  9. "github.com/go-kratos/kratos/v2/transport"
  10. "github.com/mark3labs/mcp-go/server"
  11. )
  12. var (
  13. _ transport.Server = (*Server)(nil)
  14. _ transport.Endpointer = (*Server)(nil)
  15. _ http.Handler = (*Server)(nil)
  16. )
  // MiddlewareFunc decorates an HTTP handler: it receives the next handler in
  // the chain and returns the handler that should be served in its place.
  type MiddlewareFunc func(next http.Handler) http.Handler
  19. // ServerOption is an HTTP server option.
  20. type ServerOption func(*Server)
  21. // Address with server address.
  22. func Address(addr string) ServerOption {
  23. return func(s *Server) {
  24. s.address = addr
  25. }
  26. }
  27. // Endpoint with server address.
  28. func Endpoint(endpoint *url.URL) ServerOption {
  29. return func(s *Server) {
  30. s.endpoint = endpoint
  31. }
  32. }
  33. // Middleware with server middleware.
  34. func Middleware(m MiddlewareFunc) ServerOption {
  35. return func(s *Server) {
  36. s.middleware = m
  37. }
  38. }
  39. // SrvOptions with server options.
  40. func SrvOptions(opts ...server.ServerOption) ServerOption {
  41. return func(s *Server) {
  42. s.srvOpts = append(s.srvOpts, opts...)
  43. }
  44. }
  45. func StreamableHTTPOptions(opts ...server.StreamableHTTPOption) ServerOption {
  46. return func(s *Server) {
  47. s.streamableHTTPOpts = append(s.streamableHTTPOpts, opts...)
  48. }
  49. }
  50. // Server is a MCP server.
  51. type Server struct {
  52. *server.MCPServer
  53. srv *http.Server
  54. streamableHttpServer *server.StreamableHTTPServer
  55. middleware MiddlewareFunc
  56. address string
  57. endpoint *url.URL
  58. srvOpts []server.ServerOption
  59. streamableHTTPOpts []server.StreamableHTTPOption
  60. mux *http.ServeMux
  61. }
  62. // NewServer creates a new MCP server.
  63. func NewServer(name, version string, opts ...ServerOption) *Server {
  64. srv := &Server{
  65. middleware: func(next http.Handler) http.Handler { return next },
  66. }
  67. for _, o := range opts {
  68. o(srv)
  69. }
  70. if srv.mux == nil {
  71. srv.mux = http.NewServeMux()
  72. }
  73. srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
  74. if srv.srv == nil {
  75. srv.srv = &http.Server{}
  76. }
  77. srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
  78. srv.mux.Handle("/mcp", srv.streamableHttpServer)
  79. srv.srv.Handler = srv.middleware(srv.mux)
  80. return srv
  81. }
  82. // ServeHTTP implements the http.Handler interface.
  83. func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
  84. s.streamableHttpServer.ServeHTTP(res, req)
  85. }
  86. // Endpoint return a real address to registry endpoint.
  87. // examples:
  88. // - http://127.0.0.1:8000
  89. func (s *Server) Endpoint() (*url.URL, error) {
  90. if s.endpoint != nil {
  91. return s.endpoint, nil
  92. }
  93. return url.Parse(fmt.Sprintf("http://%s", s.address))
  94. }
  95. // Start start the MCP server.
  96. func (s *Server) Start(_ context.Context) error {
  97. log.Infof("[MCP] server listening on: %s", s.address)
  98. if err := s.streamableHttpServer.Start(s.address); err != nil {
  99. if !errors.Is(err, http.ErrServerClosed) {
  100. return err
  101. }
  102. }
  103. return nil
  104. }
  105. // Stop stop the MCP server.
  106. func (s *Server) Stop(ctx context.Context) error {
  107. defer func() {
  108. log.Info("[MCP] server stopping")
  109. }()
  110. err := s.streamableHttpServer.Shutdown(ctx)
  111. if err != nil {
  112. if ctx.Err() != nil {
  113. log.Warn("[MCP] server couldn't stop gracefully in time, doing force stop")
  114. err = s.srv.Close()
  115. }
  116. }
  117. return err
  118. }