tools.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"
  9. openapi_v3 "github.com/google/gnostic/openapiv3"
  10. mcp2 "github.com/mark3labs/mcp-go/mcp"
  11. "github.com/mark3labs/mcp-go/server"
  12. "google.golang.org/genproto/googleapis/api/annotations"
  13. "google.golang.org/grpc"
  14. "google.golang.org/protobuf/proto"
  15. "google.golang.org/protobuf/reflect/protoreflect"
  16. "google.golang.org/protobuf/reflect/protoregistry"
  17. )
  18. func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {
  19. serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")
  20. handlerMap := make(map[string]grpc.MethodDesc)
  21. for _, _v := range svcDesc.Methods {
  22. v := _v
  23. mapK := serviceName + "_" + v.MethodName
  24. handlerMap[mapK] = v
  25. }
  26. d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))
  27. if err != nil {
  28. return err
  29. }
  30. if d.Services().Len() == 0 {
  31. return nil
  32. }
  33. ser := d.Services().Get(0)
  34. for j := 0; j < ser.Methods().Len(); j++ {
  35. method := ser.Methods().Get(j)
  36. t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)
  37. if t == nil || h == nil {
  38. continue
  39. }
  40. s.AddTool(*t, h)
  41. }
  42. return nil
  43. }
  44. func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {
  45. methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.Options)
  46. if methodMcpOpts == nil || methodMcpOpts.McpOptions == nil || !methodMcpOpts.GetMcpOptions().Enabled {
  47. return nil, nil
  48. }
  49. methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)
  50. description := ""
  51. if methodOperation != nil {
  52. description = methodOperation.Description
  53. if description == "" {
  54. description = methodOperation.Summary
  55. }
  56. }
  57. toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}
  58. for k := 0; k < method.Input().Fields().Len(); k++ {
  59. input := method.Input().Fields().Get(k)
  60. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  61. inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
  62. inputDescription := ""
  63. if inputOperation != nil {
  64. inputDescription = inputOperation.GetDescription()
  65. }
  66. propertyOption := make([]mcp2.PropertyOption, 0)
  67. if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
  68. propertyOption = append(propertyOption, mcp2.Required())
  69. }
  70. if input.IsList() {
  71. propertyOption = append(propertyOption, mcp2.Items(getInputArrayItems(input)))
  72. propertyOption = append(propertyOption, mcp2.Description(inputDescription))
  73. toolOptions = append(toolOptions, mcp2.WithArray(string(input.Name()), propertyOption...))
  74. } else if input.IsMap() {
  75. additionalProperties, descriptionSuffix := getInputMapProperties(input)
  76. inputDescription = inputDescription + descriptionSuffix
  77. propertyOption = append(propertyOption, mcp2.Description(inputDescription))
  78. propertyOption = append(propertyOption, mcp2.AdditionalProperties(additionalProperties))
  79. toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
  80. } else {
  81. propertyOption = append(propertyOption, mcp2.Description(inputDescription))
  82. switch input.Kind() {
  83. case protoreflect.StringKind:
  84. toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
  85. case protoreflect.BoolKind:
  86. toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
  87. case protoreflect.DoubleKind, protoreflect.FloatKind,
  88. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  89. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  90. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  91. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  92. protoreflect.Int64Kind, protoreflect.Int32Kind:
  93. toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
  94. case protoreflect.MessageKind:
  95. propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message(), false)))
  96. toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
  97. }
  98. }
  99. }
  100. toolName := serviceName + "_" + string(method.Name())
  101. t := mcp2.NewTool(toolName, toolOptions...)
  102. h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {
  103. if _, ok := handlerMap[toolName]; !ok {
  104. return nil, errors.New("没有实现")
  105. }
  106. arg := request.GetArguments()
  107. argJson, _ := json.Marshal(arg)
  108. dec := func(in any) error {
  109. decErr := json.Unmarshal(argJson, &in)
  110. if decErr != nil {
  111. return decErr
  112. }
  113. return nil
  114. }
  115. interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
  116. return handler(ctx, req)
  117. }
  118. handler := handlerMap[toolName]
  119. out, outErr := handler.Handler(srv, ctx, dec, interceptor)
  120. if outErr != nil {
  121. return nil, outErr
  122. }
  123. outJson, _ := json.Marshal(out)
  124. callToolResult := &mcp2.CallToolResult{
  125. Content: []mcp2.Content{
  126. mcp2.TextContent{
  127. Type: "text",
  128. Text: string(outJson),
  129. },
  130. },
  131. }
  132. return callToolResult, nil
  133. }
  134. return &t, h
  135. }
  136. func getInputArrayItems(input protoreflect.FieldDescriptor) map[string]any {
  137. inputMap := make(map[string]any)
  138. switch input.Kind() {
  139. case protoreflect.StringKind:
  140. inputMap["type"] = "string"
  141. case protoreflect.BoolKind:
  142. inputMap["type"] = "boolean"
  143. case protoreflect.DoubleKind, protoreflect.FloatKind,
  144. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  145. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  146. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  147. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  148. protoreflect.Int64Kind, protoreflect.Int32Kind:
  149. inputMap["type"] = "number"
  150. case protoreflect.MessageKind:
  151. inputMap["type"] = "object"
  152. propertiesMap := getFiledMessageParamProperties(input.Message(), true)
  153. inputMap["properties"] = propertiesMap
  154. }
  155. return inputMap
  156. }
  157. func getInputMapProperties(input protoreflect.FieldDescriptor) (map[string]any, string) {
  158. messageParamMap := make(map[string]any)
  159. paramMap := make(map[string]any)
  160. switch input.MapValue().Kind() {
  161. case protoreflect.StringKind:
  162. paramMap["type"] = "string"
  163. case protoreflect.BoolKind:
  164. paramMap["type"] = "boolean"
  165. case protoreflect.DoubleKind, protoreflect.FloatKind,
  166. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  167. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  168. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  169. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  170. protoreflect.Int64Kind, protoreflect.Int32Kind:
  171. paramMap["type"] = "number"
  172. case protoreflect.MessageKind:
  173. paramMap["type"] = "object"
  174. paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
  175. default:
  176. break
  177. }
  178. var keyType string
  179. switch input.MapKey().Kind() {
  180. case protoreflect.StringKind:
  181. keyType = "string"
  182. case protoreflect.BoolKind:
  183. keyType = "boolean"
  184. case protoreflect.DoubleKind, protoreflect.FloatKind,
  185. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  186. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  187. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  188. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  189. protoreflect.Int64Kind, protoreflect.Int32Kind:
  190. keyType = "number"
  191. }
  192. descriptionSuffix := fmt.Sprintf("(type of key: %s, type of value: %s)", keyType, paramMap["type"])
  193. messageParamMap[string(input.Name())] = paramMap
  194. return messageParamMap, descriptionSuffix
  195. }
  196. func getFiledMessageParamProperties(message protoreflect.MessageDescriptor, needRequired bool) map[string]any {
  197. messageParamMap := make(map[string]any)
  198. for i := 0; i < message.Fields().Len(); i++ {
  199. input := message.Fields().Get(i)
  200. paramMap := make(map[string]any)
  201. inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
  202. inputDescription := ""
  203. if inputOperation != nil {
  204. inputDescription = inputOperation.GetDescription()
  205. }
  206. paramMap["description"] = inputDescription
  207. if needRequired {
  208. inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
  209. if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
  210. paramMap["required"] = true
  211. }
  212. }
  213. if input.IsList() {
  214. paramMap["type"] = "array"
  215. paramMap["items"] = getInputArrayItems(input)
  216. } else if input.IsMap() {
  217. paramMap["type"] = "object"
  218. var descriptionSuffix string
  219. paramMap["additionalProperties"], descriptionSuffix = getInputMapProperties(input)
  220. paramMap["description"] = inputDescription + descriptionSuffix
  221. } else {
  222. switch input.Kind() {
  223. case protoreflect.StringKind:
  224. paramMap["type"] = "string"
  225. case protoreflect.BoolKind:
  226. paramMap["type"] = "boolean"
  227. case protoreflect.DoubleKind, protoreflect.FloatKind,
  228. protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
  229. protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
  230. protoreflect.Sint64Kind, protoreflect.Sint32Kind,
  231. protoreflect.Uint64Kind, protoreflect.Uint32Kind,
  232. protoreflect.Int64Kind, protoreflect.Int32Kind:
  233. paramMap["type"] = "number"
  234. case protoreflect.MessageKind:
  235. paramMap["type"] = "object"
  236. paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
  237. default:
  238. break
  239. }
  240. }
  241. messageParamMap[string(input.Name())] = paramMap
  242. }
  243. return messageParamMap
  244. }