| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- package mcp
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "net/url"
- "github.com/go-kratos/kratos/v2/log"
- "github.com/go-kratos/kratos/v2/transport"
- "github.com/mark3labs/mcp-go/server"
- )
- var (
- _ transport.Server = (*Server)(nil)
- _ transport.Endpointer = (*Server)(nil)
- _ http.Handler = (*Server)(nil)
- )
- // MiddlewareFunc is a function that takes an http.Handler and returns an http.Handler.
- type MiddlewareFunc func(http.Handler) http.Handler
- // ServerOption is an HTTP server option.
- type ServerOption func(*Server)
- // Address with server address.
- func Address(addr string) ServerOption {
- return func(s *Server) {
- s.address = addr
- }
- }
- // Endpoint with server address.
- func Endpoint(endpoint *url.URL) ServerOption {
- return func(s *Server) {
- s.endpoint = endpoint
- }
- }
- // Middleware with server middleware.
- func Middleware(m MiddlewareFunc) ServerOption {
- return func(s *Server) {
- s.middleware = m
- }
- }
- // SrvOptions with server options.
- func SrvOptions(opts ...server.ServerOption) ServerOption {
- return func(s *Server) {
- s.srvOpts = append(s.srvOpts, opts...)
- }
- }
- func StreamableHTTPOptions(opts ...server.StreamableHTTPOption) ServerOption {
- return func(s *Server) {
- s.streamableHTTPOpts = append(s.streamableHTTPOpts, opts...)
- }
- }
// Server is a MCP server.
//
// It embeds the mcp-go *server.MCPServer, so tools, resources, and prompts
// can be registered on it directly. It serves the streamable HTTP transport
// through its own *http.Server, and implements Kratos' transport.Server and
// transport.Endpointer.
type Server struct {
	// Built by NewServer from name, version, and srvOpts.
	*server.MCPServer
	// Underlying HTTP server. NewServer sets its Handler to middleware(mux),
	// and Stop closes it if graceful shutdown runs out of time.
	srv *http.Server
	// Streamable HTTP transport. It is mounted at "/mcp" on mux, run by
	// Start, and shut down by Stop.
	streamableHttpServer *server.StreamableHTTPServer
	// Wraps mux to form srv.Handler. NewServer defaults it to the identity.
	middleware MiddlewareFunc
	// Listen address passed to the transport's Start. It is also used by
	// Endpoint when endpoint is nil.
	address string
	// Explicit registry URL returned by Endpoint, if set.
	endpoint *url.URL
	// Options forwarded to server.NewMCPServer.
	srvOpts []server.ServerOption
	// Options forwarded to server.NewStreamableHTTPServer.
	streamableHTTPOpts []server.StreamableHTTPOption
	// Router for srv. NewServer creates it when nil and registers "/mcp".
	mux *http.ServeMux
}
- // NewServer creates a new MCP server.
- func NewServer(name, version string, opts ...ServerOption) *Server {
- srv := &Server{
- middleware: func(next http.Handler) http.Handler { return next },
- }
- for _, o := range opts {
- o(srv)
- }
- if srv.mux == nil {
- srv.mux = http.NewServeMux()
- }
- srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
- if srv.srv == nil {
- srv.srv = &http.Server{}
- }
- srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
- srv.mux.Handle("/mcp", srv.streamableHttpServer)
- srv.srv.Handler = srv.middleware(srv.mux)
- return srv
- }
- // ServeHTTP implements the http.Handler interface.
- func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
- s.streamableHttpServer.ServeHTTP(res, req)
- }
- // Endpoint return a real address to registry endpoint.
- // examples:
- // - http://127.0.0.1:8000
- func (s *Server) Endpoint() (*url.URL, error) {
- if s.endpoint != nil {
- return s.endpoint, nil
- }
- return url.Parse(fmt.Sprintf("http://%s", s.address))
- }
- // Start start the MCP server.
- func (s *Server) Start(_ context.Context) error {
- log.Infof("[MCP] server listening on: %s", s.address)
- if err := s.streamableHttpServer.Start(s.address); err != nil {
- if !errors.Is(err, http.ErrServerClosed) {
- return err
- }
- }
- return nil
- }
- // Stop stop the MCP server.
- func (s *Server) Stop(ctx context.Context) error {
- defer func() {
- log.Info("[MCP] server stopping")
- }()
- err := s.streamableHttpServer.Shutdown(ctx)
- if err != nil {
- if ctx.Err() != nil {
- log.Warn("[MCP] server couldn't stop gracefully in time, doing force stop")
- err = s.srv.Close()
- }
- }
- return err
- }
|