services.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package descriptor
  2. import (
  3. "fmt"
  4. "strings"
  5. "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/httprule"
  6. "github.com/golang/glog"
  7. options "google.golang.org/genproto/googleapis/api/annotations"
  8. "google.golang.org/protobuf/proto"
  9. "google.golang.org/protobuf/types/descriptorpb"
  10. )
  11. // loadServices registers services and their methods from "targetFile" to "r".
  12. // It must be called after loadFile is called for all files so that loadServices
  13. // can resolve names of message types and their fields.
  14. func (r *Registry) loadServices(file *File) error {
  15. glog.V(1).Infof("Loading services from %s", file.GetName())
  16. var svcs []*Service
  17. for _, sd := range file.GetService() {
  18. glog.V(2).Infof("Registering %s", sd.GetName())
  19. svc := &Service{
  20. File: file,
  21. ServiceDescriptorProto: sd,
  22. ForcePrefixedName: r.standalone,
  23. }
  24. for _, md := range sd.GetMethod() {
  25. glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName())
  26. opts, err := extractAPIOptions(md)
  27. if err != nil {
  28. glog.Errorf("Failed to extract HttpRule from %s.%s: %v", svc.GetName(), md.GetName(), err)
  29. return err
  30. }
  31. var optsList []*options.HttpRule
  32. if r.generateRPCMethods {
  33. defaultOpts, err := defaultAPIOptions(svc, md)
  34. if err != nil {
  35. glog.Errorf("Failed to generate default HttpRule from %s.%s: %v", svc.GetName(), md.GetName(), err)
  36. return err
  37. }
  38. optsList = append(optsList, defaultOpts)
  39. } else {
  40. optsList = r.LookupExternalHTTPRules((&Method{Service: svc, MethodDescriptorProto: md}).FQMN())
  41. if opts != nil {
  42. optsList = append(optsList, opts)
  43. }
  44. }
  45. if len(optsList) == 0 {
  46. if r.generateUnboundMethods {
  47. defaultOpts, err := defaultAPIOptions(svc, md)
  48. if err != nil {
  49. glog.Errorf("Failed to generate default HttpRule from %s.%s: %v", svc.GetName(), md.GetName(), err)
  50. return err
  51. }
  52. optsList = append(optsList, defaultOpts)
  53. } else {
  54. logFn := glog.V(1).Infof
  55. if r.warnOnUnboundMethods {
  56. logFn = glog.Warningf
  57. }
  58. logFn("No HttpRule found for method: %s.%s", svc.GetName(), md.GetName())
  59. }
  60. }
  61. meth, err := r.newMethod(svc, md, optsList)
  62. if err != nil {
  63. return err
  64. }
  65. svc.Methods = append(svc.Methods, meth)
  66. }
  67. if len(svc.Methods) == 0 {
  68. continue
  69. }
  70. glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods))
  71. svcs = append(svcs, svc)
  72. }
  73. file.Services = svcs
  74. return nil
  75. }
  76. func (r *Registry) newMethod(svc *Service, md *descriptorpb.MethodDescriptorProto, optsList []*options.HttpRule) (*Method, error) {
  77. requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
  78. if err != nil {
  79. return nil, err
  80. }
  81. responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType())
  82. if err != nil {
  83. return nil, err
  84. }
  85. meth := &Method{
  86. Service: svc,
  87. MethodDescriptorProto: md,
  88. RequestType: requestType,
  89. ResponseType: responseType,
  90. }
  91. newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
  92. var (
  93. httpMethod string
  94. pathTemplate string
  95. )
  96. switch {
  97. case opts.GetGet() != "":
  98. httpMethod = "GET"
  99. pathTemplate = opts.GetGet()
  100. if opts.Body != "" {
  101. return nil, fmt.Errorf("must not set request body when http method is GET: %s", md.GetName())
  102. }
  103. case opts.GetPut() != "":
  104. httpMethod = "PUT"
  105. pathTemplate = opts.GetPut()
  106. case opts.GetPost() != "":
  107. httpMethod = "POST"
  108. pathTemplate = opts.GetPost()
  109. case opts.GetDelete() != "":
  110. httpMethod = "DELETE"
  111. pathTemplate = opts.GetDelete()
  112. if opts.Body != "" && !r.allowDeleteBody {
  113. return nil, fmt.Errorf("must not set request body when http method is DELETE except allow_delete_body option is true: %s", md.GetName())
  114. }
  115. case opts.GetPatch() != "":
  116. httpMethod = "PATCH"
  117. pathTemplate = opts.GetPatch()
  118. case opts.GetCustom() != nil:
  119. custom := opts.GetCustom()
  120. httpMethod = custom.Kind
  121. pathTemplate = custom.Path
  122. default:
  123. glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName())
  124. return nil, nil
  125. }
  126. parsed, err := httprule.Parse(pathTemplate)
  127. if err != nil {
  128. return nil, err
  129. }
  130. tmpl := parsed.Compile()
  131. if md.GetClientStreaming() && len(tmpl.Fields) > 0 {
  132. return nil, fmt.Errorf("cannot use path parameter in client streaming")
  133. }
  134. b := &Binding{
  135. Method: meth,
  136. Index: idx,
  137. PathTmpl: tmpl,
  138. HTTPMethod: httpMethod,
  139. }
  140. for _, f := range tmpl.Fields {
  141. param, err := r.newParam(meth, f)
  142. if err != nil {
  143. return nil, err
  144. }
  145. b.PathParams = append(b.PathParams, param)
  146. }
  147. // TODO(yugui) Handle query params
  148. b.Body, err = r.newBody(meth, opts.Body)
  149. if err != nil {
  150. return nil, err
  151. }
  152. b.ResponseBody, err = r.newResponse(meth, opts.ResponseBody)
  153. if err != nil {
  154. return nil, err
  155. }
  156. return b, nil
  157. }
  158. applyOpts := func(opts *options.HttpRule) error {
  159. b, err := newBinding(opts, len(meth.Bindings))
  160. if err != nil {
  161. return err
  162. }
  163. if b != nil {
  164. meth.Bindings = append(meth.Bindings, b)
  165. }
  166. for _, additional := range opts.GetAdditionalBindings() {
  167. if len(additional.AdditionalBindings) > 0 {
  168. return fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
  169. }
  170. b, err := newBinding(additional, len(meth.Bindings))
  171. if err != nil {
  172. return err
  173. }
  174. meth.Bindings = append(meth.Bindings, b)
  175. }
  176. return nil
  177. }
  178. for _, opts := range optsList {
  179. if err := applyOpts(opts); err != nil {
  180. return nil, err
  181. }
  182. }
  183. return meth, nil
  184. }
  185. func extractAPIOptions(meth *descriptorpb.MethodDescriptorProto) (*options.HttpRule, error) {
  186. if meth.Options == nil {
  187. return nil, nil
  188. }
  189. if !proto.HasExtension(meth.Options, options.E_Http) {
  190. return nil, nil
  191. }
  192. ext := proto.GetExtension(meth.Options, options.E_Http)
  193. opts, ok := ext.(*options.HttpRule)
  194. if !ok {
  195. return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
  196. }
  197. return opts, nil
  198. }
  199. func defaultAPIOptions(svc *Service, md *descriptorpb.MethodDescriptorProto) (*options.HttpRule, error) {
  200. // FQSN prefixes the service's full name with a '.', e.g.: '.example.ExampleService'
  201. fqsn := strings.TrimPrefix(svc.FQSN(), ".")
  202. // This generates an HttpRule that matches the gRPC mapping to HTTP/2 described in
  203. // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
  204. // i.e.:
  205. // * method is POST
  206. // * path is "/<service name>/<method name>"
  207. // * body should contain the serialized request message
  208. rule := &options.HttpRule{
  209. Pattern: &options.HttpRule_Post{
  210. Post: fmt.Sprintf("/%s/%s", fqsn, md.GetName()),
  211. },
  212. Body: "*",
  213. }
  214. return rule, nil
  215. }
  216. func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
  217. msg := meth.RequestType
  218. fields, err := r.resolveFieldPath(msg, path, true)
  219. if err != nil {
  220. return Parameter{}, err
  221. }
  222. l := len(fields)
  223. if l == 0 {
  224. return Parameter{}, fmt.Errorf("invalid field access list for %s", path)
  225. }
  226. target := fields[l-1].Target
  227. switch target.GetType() {
  228. case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP:
  229. glog.V(2).Infoln("found aggregate type:", target, target.TypeName)
  230. if IsWellKnownType(*target.TypeName) {
  231. glog.V(2).Infoln("found well known aggregate type:", target)
  232. } else {
  233. return Parameter{}, fmt.Errorf("%s.%s: %s is a protobuf message type. Protobuf message types cannot be used as path parameters, use a scalar value type (such as string) instead", meth.Service.GetName(), meth.GetName(), path)
  234. }
  235. }
  236. return Parameter{
  237. FieldPath: FieldPath(fields),
  238. Method: meth,
  239. Target: fields[l-1].Target,
  240. }, nil
  241. }
  242. func (r *Registry) newBody(meth *Method, path string) (*Body, error) {
  243. msg := meth.RequestType
  244. switch path {
  245. case "":
  246. return nil, nil
  247. case "*":
  248. return &Body{FieldPath: nil}, nil
  249. }
  250. fields, err := r.resolveFieldPath(msg, path, false)
  251. if err != nil {
  252. return nil, err
  253. }
  254. return &Body{FieldPath: FieldPath(fields)}, nil
  255. }
  256. func (r *Registry) newResponse(meth *Method, path string) (*Body, error) {
  257. msg := meth.ResponseType
  258. switch path {
  259. case "", "*":
  260. return nil, nil
  261. }
  262. fields, err := r.resolveFieldPath(msg, path, false)
  263. if err != nil {
  264. return nil, err
  265. }
  266. return &Body{FieldPath: FieldPath(fields)}, nil
  267. }
  268. // lookupField looks up a field named "name" within "msg".
  269. // It returns nil if no such field found.
  270. func lookupField(msg *Message, name string) *Field {
  271. for _, f := range msg.Fields {
  272. if f.GetName() == name {
  273. return f
  274. }
  275. }
  276. return nil
  277. }
  278. // resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg".
  279. func (r *Registry) resolveFieldPath(msg *Message, path string, isPathParam bool) ([]FieldPathComponent, error) {
  280. if path == "" {
  281. return nil, nil
  282. }
  283. root := msg
  284. var result []FieldPathComponent
  285. for i, c := range strings.Split(path, ".") {
  286. if i > 0 {
  287. f := result[i-1].Target
  288. switch f.GetType() {
  289. case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP:
  290. var err error
  291. msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName())
  292. if err != nil {
  293. return nil, err
  294. }
  295. default:
  296. return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path)
  297. }
  298. }
  299. glog.V(2).Infof("Lookup %s in %s", c, msg.FQMN())
  300. f := lookupField(msg, c)
  301. if f == nil {
  302. return nil, fmt.Errorf("no field %q found in %s", path, root.GetName())
  303. }
  304. if !(isPathParam || r.allowRepeatedFieldsInBody) && f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REPEATED {
  305. return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path)
  306. }
  307. if isPathParam && f.GetProto3Optional() {
  308. return nil, fmt.Errorf("optional field not allowed in field path: %s in %s", f.GetName(), path)
  309. }
  310. result = append(result, FieldPathComponent{Name: c, Target: f})
  311. }
  312. return result, nil
  313. }