tools.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "strings"
  7. annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"
  8. openapi_v3 "github.com/google/gnostic/openapiv3"
  9. mcp2 "github.com/mark3labs/mcp-go/mcp"
  10. "github.com/mark3labs/mcp-go/server"
  11. "google.golang.org/genproto/googleapis/api/annotations"
  12. "google.golang.org/grpc"
  13. "google.golang.org/protobuf/proto"
  14. "google.golang.org/protobuf/reflect/protoreflect"
  15. "google.golang.org/protobuf/reflect/protoregistry"
  16. )
  17. func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {
  18. serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")
  19. handlerMap := make(map[string]grpc.MethodDesc)
  20. for _, _v := range svcDesc.Methods {
  21. v := _v
  22. mapK := serviceName + "_" + v.MethodName
  23. handlerMap[mapK] = v
  24. }
  25. d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))
  26. if err != nil {
  27. return err
  28. }
  29. if d.Services().Len() == 0 {
  30. return nil
  31. }
  32. ser := d.Services().Get(0)
  33. for j := 0; j < ser.Methods().Len(); j++ {
  34. method := ser.Methods().Get(j)
  35. t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)
  36. s.AddTool(*t, h)
  37. }
  38. return nil
  39. }
  40. func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {
  41. methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.McpOptions)
  42. if methodMcpOpts == nil || !methodMcpOpts.Enabled {
  43. return nil, nil
  44. }
  45. methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)
  46. description := ""
  47. if methodOperation != nil {
  48. description = methodOperation.Description
  49. if description == "" {
  50. description = methodOperation.Summary
  51. }
  52. }
  53. toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}
  54. for k := 0; k < method.Input().Fields().Len(); k++ {
  55. input := method.Input().Fields().Get(k)
  56. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  57. inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
  58. inputDescription := ""
  59. if inputOperation != nil {
  60. inputDescription = inputOperation.GetDescription()
  61. }
  62. propertyOption := []mcp2.PropertyOption{mcp2.Description(inputDescription)}
  63. if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
  64. propertyOption = append(propertyOption, mcp2.Required())
  65. }
  66. switch input.Kind() {
  67. case protoreflect.StringKind:
  68. toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
  69. case protoreflect.BoolKind:
  70. toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
  71. case protoreflect.DoubleKind, protoreflect.FloatKind,
  72. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  73. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  74. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  75. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  76. protoreflect.Int64Kind, protoreflect.Int32Kind:
  77. toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
  78. case protoreflect.MessageKind:
  79. propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message())))
  80. toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
  81. }
  82. }
  83. toolName := serviceName + "_" + string(method.Name())
  84. t := mcp2.NewTool(toolName, toolOptions...)
  85. h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {
  86. if _, ok := handlerMap[toolName]; !ok {
  87. return nil, errors.New("没有实现")
  88. }
  89. arg := request.GetArguments()
  90. argJson, _ := json.Marshal(arg)
  91. dec := func(in any) error {
  92. decErr := json.Unmarshal(argJson, &in)
  93. if decErr != nil {
  94. return decErr
  95. }
  96. return nil
  97. }
  98. interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
  99. return handler(ctx, req)
  100. }
  101. handler := handlerMap[toolName]
  102. out, outErr := handler.Handler(srv, ctx, dec, interceptor)
  103. if outErr != nil {
  104. return nil, outErr
  105. }
  106. outJson, _ := json.Marshal(out)
  107. callToolResult := &mcp2.CallToolResult{
  108. Content: []mcp2.Content{
  109. mcp2.TextContent{
  110. Type: "text",
  111. Text: string(outJson),
  112. },
  113. },
  114. }
  115. return callToolResult, nil
  116. }
  117. return &t, h
  118. }
  119. func getFiledMessageParamProperties(message protoreflect.MessageDescriptor) map[string]any {
  120. messageParamMap := make(map[string]any)
  121. for i := 0; i < message.Fields().Len(); i++ {
  122. input := message.Fields().Get(i)
  123. paramMap := make(map[string]any)
  124. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  125. inputDescription := ""
  126. if inputOperation != nil {
  127. inputDescription = inputOperation.GetDescription()
  128. }
  129. paramMap["description"] = inputDescription
  130. switch input.Kind() {
  131. case protoreflect.StringKind:
  132. paramMap["type"] = "string"
  133. case protoreflect.BoolKind:
  134. paramMap["type"] = "boolean"
  135. case protoreflect.DoubleKind, protoreflect.FloatKind,
  136. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  137. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  138. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  139. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  140. protoreflect.Int64Kind, protoreflect.Int32Kind:
  141. paramMap["type"] = "number"
  142. default:
  143. break
  144. }
  145. messageParamMap[string(input.Name())] = paramMap
  146. }
  147. return messageParamMap
  148. }