tools.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "strings"
  7. annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"
  8. openapi_v3 "github.com/google/gnostic/openapiv3"
  9. mcp2 "github.com/mark3labs/mcp-go/mcp"
  10. "github.com/mark3labs/mcp-go/server"
  11. "google.golang.org/genproto/googleapis/api/annotations"
  12. "google.golang.org/grpc"
  13. "google.golang.org/protobuf/proto"
  14. "google.golang.org/protobuf/reflect/protoreflect"
  15. "google.golang.org/protobuf/reflect/protoregistry"
  16. )
  17. func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {
  18. serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")
  19. handlerMap := make(map[string]grpc.MethodDesc)
  20. for _, _v := range svcDesc.Methods {
  21. v := _v
  22. mapK := serviceName + "_" + v.MethodName
  23. handlerMap[mapK] = v
  24. }
  25. d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))
  26. if err != nil {
  27. return err
  28. }
  29. if d.Services().Len() == 0 {
  30. return nil
  31. }
  32. ser := d.Services().Get(0)
  33. for j := 0; j < ser.Methods().Len(); j++ {
  34. method := ser.Methods().Get(j)
  35. t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)
  36. if t == nil || h == nil {
  37. continue
  38. }
  39. s.AddTool(*t, h)
  40. }
  41. return nil
  42. }
  43. func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {
  44. methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.Options)
  45. if methodMcpOpts == nil || methodMcpOpts.McpOptions == nil || !methodMcpOpts.GetMcpOptions().Enabled {
  46. return nil, nil
  47. }
  48. methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)
  49. description := ""
  50. if methodOperation != nil {
  51. description = methodOperation.Description
  52. if description == "" {
  53. description = methodOperation.Summary
  54. }
  55. }
  56. toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}
  57. for k := 0; k < method.Input().Fields().Len(); k++ {
  58. input := method.Input().Fields().Get(k)
  59. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  60. inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
  61. inputDescription := ""
  62. if inputOperation != nil {
  63. inputDescription = inputOperation.GetDescription()
  64. }
  65. propertyOption := []mcp2.PropertyOption{mcp2.Description(inputDescription)}
  66. if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
  67. propertyOption = append(propertyOption, mcp2.Required())
  68. }
  69. switch input.Kind() {
  70. case protoreflect.StringKind:
  71. toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
  72. case protoreflect.BoolKind:
  73. toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
  74. case protoreflect.DoubleKind, protoreflect.FloatKind,
  75. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  76. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  77. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  78. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  79. protoreflect.Int64Kind, protoreflect.Int32Kind:
  80. toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
  81. case protoreflect.MessageKind:
  82. propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message())))
  83. toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
  84. }
  85. }
  86. toolName := serviceName + "_" + string(method.Name())
  87. t := mcp2.NewTool(toolName, toolOptions...)
  88. h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {
  89. if _, ok := handlerMap[toolName]; !ok {
  90. return nil, errors.New("没有实现")
  91. }
  92. arg := request.GetArguments()
  93. argJson, _ := json.Marshal(arg)
  94. dec := func(in any) error {
  95. decErr := json.Unmarshal(argJson, &in)
  96. if decErr != nil {
  97. return decErr
  98. }
  99. return nil
  100. }
  101. interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
  102. return handler(ctx, req)
  103. }
  104. handler := handlerMap[toolName]
  105. out, outErr := handler.Handler(srv, ctx, dec, interceptor)
  106. if outErr != nil {
  107. return nil, outErr
  108. }
  109. outJson, _ := json.Marshal(out)
  110. callToolResult := &mcp2.CallToolResult{
  111. Content: []mcp2.Content{
  112. mcp2.TextContent{
  113. Type: "text",
  114. Text: string(outJson),
  115. },
  116. },
  117. }
  118. return callToolResult, nil
  119. }
  120. return &t, h
  121. }
  122. func getFiledMessageParamProperties(message protoreflect.MessageDescriptor) map[string]any {
  123. messageParamMap := make(map[string]any)
  124. for i := 0; i < message.Fields().Len(); i++ {
  125. input := message.Fields().Get(i)
  126. paramMap := make(map[string]any)
  127. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  128. inputDescription := ""
  129. if inputOperation != nil {
  130. inputDescription = inputOperation.GetDescription()
  131. }
  132. paramMap["description"] = inputDescription
  133. switch input.Kind() {
  134. case protoreflect.StringKind:
  135. paramMap["type"] = "string"
  136. case protoreflect.BoolKind:
  137. paramMap["type"] = "boolean"
  138. case protoreflect.DoubleKind, protoreflect.FloatKind,
  139. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  140. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  141. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  142. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  143. protoreflect.Int64Kind, protoreflect.Int32Kind:
  144. paramMap["type"] = "number"
  145. default:
  146. break
  147. }
  148. messageParamMap[string(input.Name())] = paramMap
  149. }
  150. return messageParamMap
  151. }