generator.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
  1. // Copyright 2020 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. package generator
  16. import (
  17. "fmt"
  18. "log"
  19. "net/url"
  20. "regexp"
  21. "sort"
  22. "strings"
  23. http2 "net/http"
  24. "google.golang.org/protobuf/types/descriptorpb"
  25. "google.golang.org/genproto/googleapis/api/annotations"
  26. status_pb "google.golang.org/genproto/googleapis/rpc/status"
  27. "google.golang.org/protobuf/compiler/protogen"
  28. "google.golang.org/protobuf/proto"
  29. "google.golang.org/protobuf/reflect/protoreflect"
  30. any_pb "google.golang.org/protobuf/types/known/anypb"
  31. "git.ikuban.com/server/kratos-utils/v2/transport/http/handler"
  32. wk "git.ikuban.com/server/swagger-api/v2/generator/wellknown"
  33. v3 "github.com/google/gnostic/openapiv3"
  34. )
// Configuration holds the caller-supplied options that control document
// generation. Every field is a pointer; callers should populate all of them,
// since some (e.g. CircularDepth) are dereferenced by the generator without a
// nil check.
type Configuration struct {
	// Version is copied into the document's info.version.
	Version *string
	// Title is copied into the document's info.title.
	Title *string
	// Description is copied into the document's info.description.
	Description *string
	// Naming is passed to the reflector (NewOpenAPIv3Reflector); presumably it
	// selects the field naming convention used by formatFieldName — TODO confirm
	// accepted values.
	Naming *string
	// FQSchemaNaming is consumed by the reflector, not in this file; presumably
	// it toggles fully-qualified schema names — verify.
	FQSchemaNaming *bool
	// EnumType is consumed outside this file; presumably it controls how enums
	// are rendered — verify.
	EnumType *string
	// CircularDepth bounds how many times a given sub-message field is expanded
	// when message fields are flattened into query parameters.
	CircularDepth *int
	// DefaultResponse, when true, adds a "default" error response referencing
	// the google.rpc.Status schema to every generated operation.
	DefaultResponse *bool
	// OutputMode is not referenced in the portion of the generator reviewed
	// here — TODO document its meaning.
	OutputMode *string
}
  46. const (
  47. infoURL = "git.ikuban.com/server/swagger-api"
  48. )
  49. // In order to dynamically add google.rpc.Status responses we need
  50. // to know the message descriptors for google.rpc.Status as well
  51. // as google.protobuf.Any.
  52. var statusProtoDesc = (&status_pb.Status{}).ProtoReflect().Descriptor()
  53. var anyProtoDesc = (&any_pb.Any{}).ProtoReflect().Descriptor()
// OpenAPIv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service.
type OpenAPIv3Generator struct {
	// conf is the generation configuration supplied by the caller.
	conf Configuration
	// plugin is the protoc plugin invocation; all of its files are scanned when
	// emitting schemas that the document still requires.
	plugin *protogen.Plugin
	// inputFiles are the candidate files; those with Generate set contribute
	// their services as paths.
	inputFiles []*protogen.File
	// reflect turns protobuf descriptors into schemas and keeps the list of
	// schemas still required by the document (requiredSchemas).
	reflect *OpenAPIv3Reflector
	generatedSchemas []string // Names of schemas that have already been generated.
	// linterRulePattern matches "(-- ... --)" linter directives, which are
	// stripped from comments.
	linterRulePattern *regexp.Regexp
	// pathPattern matches simple path variables such as {id}.
	pathPattern *regexp.Regexp
	// namedPathPattern matches named path variables such as {name=shelves/*}.
	namedPathPattern *regexp.Regexp
}
  65. // NewOpenAPIv3Generator creates a new generator for a protoc plugin invocation.
  66. func NewOpenAPIv3Generator(plugin *protogen.Plugin, conf Configuration, inputFiles []*protogen.File) *OpenAPIv3Generator {
  67. return &OpenAPIv3Generator{
  68. conf: conf,
  69. plugin: plugin,
  70. inputFiles: inputFiles,
  71. reflect: NewOpenAPIv3Reflector(conf),
  72. generatedSchemas: make([]string, 0),
  73. linterRulePattern: regexp.MustCompile(`\(-- .* --\)`),
  74. pathPattern: regexp.MustCompile("{([^=}]+)}"),
  75. namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"),
  76. }
  77. }
  78. // Run runs the generator.
  79. func (g *OpenAPIv3Generator) Run(outputFile *protogen.GeneratedFile) error {
  80. d := g.buildDocumentV3()
  81. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  82. if err != nil {
  83. return fmt.Errorf("failed to marshal yaml: %s", err.Error())
  84. }
  85. if _, err = outputFile.Write(bytes); err != nil {
  86. return fmt.Errorf("failed to write yaml: %s", err.Error())
  87. }
  88. return nil
  89. }
  90. func (g *OpenAPIv3Generator) RunV2() ([]byte, error) {
  91. d := g.buildDocumentV3()
  92. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  93. if err != nil {
  94. return bytes, fmt.Errorf("failed to marshal yaml: %s", err.Error())
  95. }
  96. return bytes, nil
  97. }
  98. // buildDocumentV3 builds an OpenAPIv3 document for a plugin request.
  99. func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
  100. d := &v3.Document{}
  101. d.Openapi = "3.0.3"
  102. d.Info = &v3.Info{
  103. Version: *g.conf.Version,
  104. Title: *g.conf.Title,
  105. Description: *g.conf.Description,
  106. }
  107. d.Paths = &v3.Paths{}
  108. d.Components = &v3.Components{
  109. Schemas: &v3.SchemasOrReferences{
  110. AdditionalProperties: []*v3.NamedSchemaOrReference{},
  111. },
  112. }
  113. // Go through the files and add the services to the documents, keeping
  114. // track of which schemas are referenced in the response so we can
  115. // add them later.
  116. for _, file := range g.inputFiles {
  117. if file.Generate {
  118. // Merge any `Document` annotations with the current
  119. extDocument := proto.GetExtension(file.Desc.Options(), v3.E_Document)
  120. if extDocument != nil {
  121. proto.Merge(d, extDocument.(*v3.Document))
  122. }
  123. g.addPathsToDocumentV3(d, file.Services)
  124. }
  125. }
  126. // While we have required schemas left to generate, go through the files again
  127. // looking for the related message and adding them to the document if required.
  128. for len(g.reflect.requiredSchemas) > 0 {
  129. count := len(g.reflect.requiredSchemas)
  130. for _, file := range g.plugin.Files {
  131. g.addSchemasForMessagesToDocumentV3(d, file.Messages, file.Proto.GetEdition())
  132. }
  133. g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
  134. }
  135. // If there is only 1 service, then use it's title for the
  136. // document, if the document is missing it.
  137. if len(d.Tags) == 1 {
  138. if d.Info.Title == "" && d.Tags[0].Name != "" {
  139. d.Info.Title = d.Tags[0].Name + " API"
  140. }
  141. if d.Info.Description == "" {
  142. d.Info.Description = d.Tags[0].Description
  143. }
  144. d.Tags[0].Description = ""
  145. }
  146. allServers := []string{}
  147. // If paths methods has servers, but they're all the same, then move servers to path level
  148. for _, path := range d.Paths.Path {
  149. servers := []string{}
  150. // Only 1 server will ever be set, per method, by the generator
  151. if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 {
  152. servers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  153. allServers = appendUnique(allServers, path.Value.Get.Servers[0].Url)
  154. }
  155. if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 {
  156. servers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  157. allServers = appendUnique(allServers, path.Value.Post.Servers[0].Url)
  158. }
  159. if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 {
  160. servers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  161. allServers = appendUnique(allServers, path.Value.Put.Servers[0].Url)
  162. }
  163. if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 {
  164. servers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  165. allServers = appendUnique(allServers, path.Value.Delete.Servers[0].Url)
  166. }
  167. if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 {
  168. servers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  169. allServers = appendUnique(allServers, path.Value.Patch.Servers[0].Url)
  170. }
  171. if path.Value.Head != nil && len(path.Value.Head.Servers) == 1 {
  172. servers = appendUnique(servers, path.Value.Head.Servers[0].Url)
  173. allServers = appendUnique(allServers, path.Value.Head.Servers[0].Url)
  174. }
  175. if path.Value.Options != nil && len(path.Value.Options.Servers) == 1 {
  176. servers = appendUnique(servers, path.Value.Options.Servers[0].Url)
  177. allServers = appendUnique(allServers, path.Value.Options.Servers[0].Url)
  178. }
  179. if path.Value.Trace != nil && len(path.Value.Trace.Servers) == 1 {
  180. servers = appendUnique(servers, path.Value.Trace.Servers[0].Url)
  181. allServers = appendUnique(allServers, path.Value.Trace.Servers[0].Url)
  182. }
  183. if len(servers) == 1 {
  184. path.Value.Servers = []*v3.Server{{Url: servers[0]}}
  185. if path.Value.Get != nil {
  186. path.Value.Get.Servers = nil
  187. }
  188. if path.Value.Post != nil {
  189. path.Value.Post.Servers = nil
  190. }
  191. if path.Value.Put != nil {
  192. path.Value.Put.Servers = nil
  193. }
  194. if path.Value.Delete != nil {
  195. path.Value.Delete.Servers = nil
  196. }
  197. if path.Value.Patch != nil {
  198. path.Value.Patch.Servers = nil
  199. }
  200. if path.Value.Head != nil {
  201. path.Value.Head.Servers = nil
  202. }
  203. if path.Value.Options != nil {
  204. path.Value.Options.Servers = nil
  205. }
  206. if path.Value.Trace != nil {
  207. path.Value.Trace.Servers = nil
  208. }
  209. }
  210. }
  211. // Set all servers on API level
  212. if len(allServers) > 0 {
  213. d.Servers = []*v3.Server{}
  214. for _, server := range allServers {
  215. d.Servers = append(d.Servers, &v3.Server{Url: server})
  216. }
  217. }
  218. // If there is only 1 server, we can safely remove all path level servers
  219. if len(allServers) == 1 {
  220. for _, path := range d.Paths.Path {
  221. path.Value.Servers = nil
  222. }
  223. }
  224. // Sort the tags.
  225. {
  226. pairs := d.Tags
  227. sort.Slice(pairs, func(i, j int) bool {
  228. return pairs[i].Name < pairs[j].Name
  229. })
  230. d.Tags = pairs
  231. }
  232. // Sort the paths.
  233. {
  234. pairs := d.Paths.Path
  235. sort.Slice(pairs, func(i, j int) bool {
  236. return pairs[i].Name < pairs[j].Name
  237. })
  238. d.Paths.Path = pairs
  239. }
  240. // Sort the schemas.
  241. {
  242. pairs := d.Components.Schemas.AdditionalProperties
  243. sort.Slice(pairs, func(i, j int) bool {
  244. return pairs[i].Name < pairs[j].Name
  245. })
  246. d.Components.Schemas.AdditionalProperties = pairs
  247. }
  248. // Deduplicate and sort security schemes.
  249. if d.Components.SecuritySchemes != nil && d.Components.SecuritySchemes.AdditionalProperties != nil {
  250. seen := make(map[string]bool)
  251. unique := make([]*v3.NamedSecuritySchemeOrReference, 0)
  252. for _, scheme := range d.Components.SecuritySchemes.AdditionalProperties {
  253. if !seen[scheme.Name] {
  254. seen[scheme.Name] = true
  255. unique = append(unique, scheme)
  256. }
  257. }
  258. sort.Slice(unique, func(i, j int) bool {
  259. return unique[i].Name < unique[j].Name
  260. })
  261. d.Components.SecuritySchemes.AdditionalProperties = unique
  262. }
  263. // Deduplicate and sort responses.
  264. if d.Components.Responses != nil && d.Components.Responses.AdditionalProperties != nil {
  265. seen := make(map[string]bool)
  266. unique := make([]*v3.NamedResponseOrReference, 0)
  267. for _, resp := range d.Components.Responses.AdditionalProperties {
  268. if !seen[resp.Name] {
  269. seen[resp.Name] = true
  270. unique = append(unique, resp)
  271. }
  272. }
  273. sort.Slice(unique, func(i, j int) bool {
  274. return unique[i].Name < unique[j].Name
  275. })
  276. d.Components.Responses.AdditionalProperties = unique
  277. }
  278. // Deduplicate and sort parameters.
  279. if d.Components.Parameters != nil && d.Components.Parameters.AdditionalProperties != nil {
  280. seen := make(map[string]bool)
  281. unique := make([]*v3.NamedParameterOrReference, 0)
  282. for _, param := range d.Components.Parameters.AdditionalProperties {
  283. if !seen[param.Name] {
  284. seen[param.Name] = true
  285. unique = append(unique, param)
  286. }
  287. }
  288. sort.Slice(unique, func(i, j int) bool {
  289. return unique[i].Name < unique[j].Name
  290. })
  291. d.Components.Parameters.AdditionalProperties = unique
  292. }
  293. // Deduplicate and sort request bodies.
  294. if d.Components.RequestBodies != nil && d.Components.RequestBodies.AdditionalProperties != nil {
  295. seen := make(map[string]bool)
  296. unique := make([]*v3.NamedRequestBodyOrReference, 0)
  297. for _, body := range d.Components.RequestBodies.AdditionalProperties {
  298. if !seen[body.Name] {
  299. seen[body.Name] = true
  300. unique = append(unique, body)
  301. }
  302. }
  303. sort.Slice(unique, func(i, j int) bool {
  304. return unique[i].Name < unique[j].Name
  305. })
  306. d.Components.RequestBodies.AdditionalProperties = unique
  307. }
  308. // Deduplicate and sort headers.
  309. if d.Components.Headers != nil && d.Components.Headers.AdditionalProperties != nil {
  310. seen := make(map[string]bool)
  311. unique := make([]*v3.NamedHeaderOrReference, 0)
  312. for _, header := range d.Components.Headers.AdditionalProperties {
  313. if !seen[header.Name] {
  314. seen[header.Name] = true
  315. unique = append(unique, header)
  316. }
  317. }
  318. sort.Slice(unique, func(i, j int) bool {
  319. return unique[i].Name < unique[j].Name
  320. })
  321. d.Components.Headers.AdditionalProperties = unique
  322. }
  323. // Deduplicate servers by URL.
  324. if d.Servers != nil && len(d.Servers) > 0 {
  325. seen := make(map[string]bool)
  326. unique := make([]*v3.Server, 0)
  327. for _, server := range d.Servers {
  328. if !seen[server.Url] {
  329. seen[server.Url] = true
  330. unique = append(unique, server)
  331. }
  332. }
  333. d.Servers = unique
  334. }
  335. return d
  336. }
  337. // filterCommentString removes linter rules from comments.
  338. func (g *OpenAPIv3Generator) filterCommentString(c protogen.Comments) string {
  339. comment := g.linterRulePattern.ReplaceAllString(string(c), "")
  340. return strings.TrimSpace(comment)
  341. }
  342. func (g *OpenAPIv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field {
  343. for _, field := range inMessage.Fields {
  344. if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name {
  345. return field
  346. }
  347. }
  348. return nil
  349. }
  350. func (g *OpenAPIv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string {
  351. field := g.findField(name, inMessage)
  352. if field != nil {
  353. return g.reflect.formatFieldName(field.Desc)
  354. }
  355. return name
  356. }
  357. // Note that fields which are mapped to URL query parameters must have a primitive type
  358. // or a repeated primitive type or a non-repeated message type.
  359. // In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A&param=B.
  360. // In the case of a message type, each field of the message is mapped to a separate parameter,
  361. // such as ...?foo.a=A&foo.b=B&foo.c=C.
  362. // There are exceptions:
  363. // - for wrapper types it will use the same representation as the wrapped primitive type in JSON
  364. // - for google.protobuf.timestamp type it will be serialized as a string
  365. //
  366. // maps, Struct and Empty can NOT be used
  367. // messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id)
  368. // buildQueryParamsV3 extracts any valid query params, including sub and recursive messages
  369. func (g *OpenAPIv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference {
  370. depths := map[string]int{}
  371. return g._buildQueryParamsV3(field, depths)
  372. }
  373. // depths are used to keep track of how many times a message's fields has been seen
  374. func (g *OpenAPIv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference {
  375. parameters := []*v3.ParameterOrReference{}
  376. queryFieldName := g.reflect.formatFieldName(field.Desc)
  377. fieldDescription := g.filterCommentString(field.Comments.Leading)
  378. if field.Desc.IsMap() {
  379. // Map types are not allowed in query parameteres
  380. return parameters
  381. } else if field.Desc.Kind() == protoreflect.MessageKind {
  382. typeName := g.reflect.fullMessageTypeName(field.Desc.Message())
  383. switch typeName {
  384. case ".google.protobuf.Value":
  385. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  386. parameters = append(parameters,
  387. &v3.ParameterOrReference{
  388. Oneof: &v3.ParameterOrReference_Parameter{
  389. Parameter: &v3.Parameter{
  390. Name: queryFieldName,
  391. In: "query",
  392. Description: fieldDescription,
  393. Required: false,
  394. Schema: fieldSchema,
  395. },
  396. },
  397. })
  398. return parameters
  399. case ".google.protobuf.BoolValue", ".google.protobuf.BytesValue", ".google.protobuf.Int32Value", ".google.protobuf.UInt32Value",
  400. ".google.protobuf.StringValue", ".google.protobuf.Int64Value", ".google.protobuf.UInt64Value", ".google.protobuf.FloatValue",
  401. ".google.protobuf.DoubleValue":
  402. valueField := getValueField(field.Message.Desc)
  403. fieldSchema := g.reflect.schemaOrReferenceForField(valueField)
  404. parameters = append(parameters,
  405. &v3.ParameterOrReference{
  406. Oneof: &v3.ParameterOrReference_Parameter{
  407. Parameter: &v3.Parameter{
  408. Name: queryFieldName,
  409. In: "query",
  410. Description: fieldDescription,
  411. Required: false,
  412. Schema: fieldSchema,
  413. },
  414. },
  415. })
  416. return parameters
  417. case ".google.protobuf.Timestamp":
  418. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  419. parameters = append(parameters,
  420. &v3.ParameterOrReference{
  421. Oneof: &v3.ParameterOrReference_Parameter{
  422. Parameter: &v3.Parameter{
  423. Name: queryFieldName,
  424. In: "query",
  425. Description: fieldDescription,
  426. Required: false,
  427. Schema: fieldSchema,
  428. },
  429. },
  430. })
  431. return parameters
  432. case ".google.protobuf.Duration":
  433. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  434. parameters = append(parameters,
  435. &v3.ParameterOrReference{
  436. Oneof: &v3.ParameterOrReference_Parameter{
  437. Parameter: &v3.Parameter{
  438. Name: queryFieldName,
  439. In: "query",
  440. Description: fieldDescription,
  441. Required: false,
  442. Schema: fieldSchema,
  443. },
  444. },
  445. })
  446. return parameters
  447. }
  448. if field.Desc.IsList() {
  449. // Only non-repeated message types are valid
  450. return parameters
  451. }
  452. // Represent field masks directly as strings (don't expand them).
  453. if typeName == ".google.protobuf.FieldMask" {
  454. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  455. parameters = append(parameters,
  456. &v3.ParameterOrReference{
  457. Oneof: &v3.ParameterOrReference_Parameter{
  458. Parameter: &v3.Parameter{
  459. Name: queryFieldName,
  460. In: "query",
  461. Description: fieldDescription,
  462. Required: false,
  463. Schema: fieldSchema,
  464. },
  465. },
  466. })
  467. return parameters
  468. }
  469. // Sub messages are allowed, even circular, as long as the final type is a primitive.
  470. // Go through each of the sub message fields
  471. for _, subField := range field.Message.Fields {
  472. subFieldFullName := string(subField.Desc.FullName())
  473. seen, ok := depths[subFieldFullName]
  474. if !ok {
  475. depths[subFieldFullName] = 0
  476. }
  477. if seen < *g.conf.CircularDepth {
  478. depths[subFieldFullName]++
  479. subParams := g._buildQueryParamsV3(subField, depths)
  480. for _, subParam := range subParams {
  481. if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok {
  482. param.Parameter.Name = queryFieldName + "." + param.Parameter.Name
  483. parameters = append(parameters, subParam)
  484. }
  485. }
  486. }
  487. }
  488. } else if field.Desc.Kind() != protoreflect.GroupKind {
  489. // schemaOrReferenceForField also handles array types
  490. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  491. parameters = append(parameters,
  492. &v3.ParameterOrReference{
  493. Oneof: &v3.ParameterOrReference_Parameter{
  494. Parameter: &v3.Parameter{
  495. Name: queryFieldName,
  496. In: "query",
  497. Description: fieldDescription,
  498. Required: false,
  499. Schema: fieldSchema,
  500. },
  501. },
  502. })
  503. }
  504. return parameters
  505. }
  506. // buildOperationV3 constructs an operation for a set of values.
  507. func (g *OpenAPIv3Generator) buildOperationV3(
  508. d *v3.Document,
  509. operationID string,
  510. tagName string,
  511. description string,
  512. defaultHost string,
  513. path string,
  514. bodyField string,
  515. inputMessage *protogen.Message,
  516. outputMessage *protogen.Message,
  517. ) (*v3.Operation, string) {
  518. // coveredParameters tracks the parameters that have been used in the body or path.
  519. coveredParameters := make([]string, 0)
  520. if bodyField != "" {
  521. coveredParameters = append(coveredParameters, bodyField)
  522. }
  523. // Initialize the list of operation parameters.
  524. parameters := []*v3.ParameterOrReference{}
  525. // Find simple path parameters like {id}
  526. if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil {
  527. for _, matches := range allMatches {
  528. // Add the value to the list of covered parameters.
  529. coveredParameters = append(coveredParameters, matches[1])
  530. pathParameter := g.findAndFormatFieldName(matches[1], inputMessage)
  531. path = strings.Replace(path, matches[1], pathParameter, 1)
  532. // Add the path parameters to the operation parameters.
  533. var fieldSchema *v3.SchemaOrReference
  534. var fieldDescription string
  535. field := g.findField(pathParameter, inputMessage)
  536. if field != nil {
  537. fieldSchema = g.reflect.schemaOrReferenceForField(field.Desc)
  538. fieldDescription = g.filterCommentString(field.Comments.Leading)
  539. } else {
  540. // If field does not exist, it is safe to set it to string, as it is ignored downstream
  541. fieldSchema = &v3.SchemaOrReference{
  542. Oneof: &v3.SchemaOrReference_Schema{
  543. Schema: &v3.Schema{
  544. Type: "string",
  545. },
  546. },
  547. }
  548. }
  549. parameters = append(parameters,
  550. &v3.ParameterOrReference{
  551. Oneof: &v3.ParameterOrReference_Parameter{
  552. Parameter: &v3.Parameter{
  553. Name: pathParameter,
  554. In: "path",
  555. Description: fieldDescription,
  556. Required: true,
  557. Schema: fieldSchema,
  558. },
  559. },
  560. })
  561. }
  562. }
  563. // Find named path parameters like {name=shelves/*}
  564. if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil {
  565. // Build a list of named path parameters.
  566. namedPathParameters := make([]string, 0)
  567. // Add the "name=" "name" value to the list of covered parameters.
  568. coveredParameters = append(coveredParameters, matches[1])
  569. // Convert the path from the starred form to use named path parameters.
  570. starredPath := matches[2]
  571. parts := strings.Split(starredPath, "/")
  572. // The starred path is assumed to be in the form "things/*/otherthings/*".
  573. // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}".
  574. for i := 0; i < len(parts)-1; i += 2 {
  575. section := parts[i]
  576. namedPathParameter := g.findAndFormatFieldName(section, inputMessage)
  577. namedPathParameter = singular(namedPathParameter)
  578. parts[i+1] = "{" + namedPathParameter + "}"
  579. namedPathParameters = append(namedPathParameters, namedPathParameter)
  580. }
  581. // Rewrite the path to use the path parameters.
  582. newPath := strings.Join(parts, "/")
  583. path = strings.Replace(path, matches[0], newPath, 1)
  584. // Add the named path parameters to the operation parameters.
  585. for _, namedPathParameter := range namedPathParameters {
  586. parameters = append(parameters,
  587. &v3.ParameterOrReference{
  588. Oneof: &v3.ParameterOrReference_Parameter{
  589. Parameter: &v3.Parameter{
  590. Name: namedPathParameter,
  591. In: "path",
  592. Required: true,
  593. Description: "The " + namedPathParameter + " id.",
  594. Schema: &v3.SchemaOrReference{
  595. Oneof: &v3.SchemaOrReference_Schema{
  596. Schema: &v3.Schema{
  597. Type: "string",
  598. },
  599. },
  600. },
  601. },
  602. },
  603. })
  604. }
  605. }
  606. // Add any unhandled fields in the request message as query parameters.
  607. if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
  608. for _, field := range inputMessage.Fields {
  609. fieldName := string(field.Desc.Name())
  610. if !contains(coveredParameters, fieldName) && fieldName != bodyField {
  611. fieldParams := g.buildQueryParamsV3(field)
  612. parameters = append(parameters, fieldParams...)
  613. }
  614. }
  615. }
  616. // Create the response.
  617. name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
  618. responses := &v3.Responses{
  619. ResponseOrReference: []*v3.NamedResponseOrReference{
  620. {
  621. Name: name,
  622. Value: &v3.ResponseOrReference{
  623. Oneof: &v3.ResponseOrReference_Response{
  624. Response: &v3.Response{
  625. Description: "OK",
  626. Content: content,
  627. },
  628. },
  629. },
  630. },
  631. },
  632. }
  633. // Add the default reponse if needed
  634. if *g.conf.DefaultResponse {
  635. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  636. anySchema := wk.NewGoogleProtobufAnySchema(anySchemaName)
  637. g.addSchemaToDocumentV3(d, anySchema)
  638. statusSchemaName := g.reflect.formatMessageName(statusProtoDesc)
  639. statusSchema := wk.NewGoogleRpcStatusSchema(statusSchemaName, anySchemaName)
  640. g.addSchemaToDocumentV3(d, statusSchema)
  641. defaultResponse := &v3.NamedResponseOrReference{
  642. Name: "default",
  643. Value: &v3.ResponseOrReference{
  644. Oneof: &v3.ResponseOrReference_Response{
  645. Response: &v3.Response{
  646. Description: "Default error response",
  647. Content: wk.NewApplicationJsonMediaType(&v3.SchemaOrReference{
  648. Oneof: &v3.SchemaOrReference_Reference{
  649. Reference: &v3.Reference{XRef: "#/components/schemas/" + statusSchemaName}}}),
  650. },
  651. },
  652. },
  653. }
  654. responses.ResponseOrReference = append(responses.ResponseOrReference, defaultResponse)
  655. }
  656. // Create the operation.
  657. op := &v3.Operation{
  658. Tags: []string{tagName},
  659. Description: description,
  660. OperationId: operationID,
  661. Parameters: parameters,
  662. Responses: responses,
  663. }
  664. if defaultHost != "" {
  665. hostURL, err := url.Parse(defaultHost)
  666. if err == nil {
  667. hostURL.Scheme = "https"
  668. op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()})
  669. }
  670. }
  671. // If a body field is specified, we need to pass a message as the request body.
  672. if bodyField != "" {
  673. var requestSchema *v3.SchemaOrReference
  674. if bodyField == "*" {
  675. // Pass the entire request message as the request body.
  676. requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)
  677. } else {
  678. // If body refers to a message field, use that type.
  679. for _, field := range inputMessage.Fields {
  680. if string(field.Desc.Name()) == bodyField {
  681. switch field.Desc.Kind() {
  682. case protoreflect.StringKind:
  683. requestSchema = &v3.SchemaOrReference{
  684. Oneof: &v3.SchemaOrReference_Schema{
  685. Schema: &v3.Schema{
  686. Type: "string",
  687. },
  688. },
  689. }
  690. case protoreflect.MessageKind:
  691. requestSchema = g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  692. default:
  693. log.Printf("unsupported field type %+v", field.Desc)
  694. }
  695. break
  696. }
  697. }
  698. }
  699. op.RequestBody = &v3.RequestBodyOrReference{
  700. Oneof: &v3.RequestBodyOrReference_RequestBody{
  701. RequestBody: &v3.RequestBody{
  702. Required: true,
  703. Content: &v3.MediaTypes{
  704. AdditionalProperties: []*v3.NamedMediaType{
  705. {
  706. Name: "application/json",
  707. Value: &v3.MediaType{
  708. Schema: requestSchema,
  709. },
  710. },
  711. },
  712. },
  713. },
  714. },
  715. }
  716. }
  717. return op, path
  718. }
  719. // addOperationToDocumentV3 adds an operation to the specified path/method.
  720. func (g *OpenAPIv3Generator) addOperationToDocumentV3(d *v3.Document, op *v3.Operation, path string, methodName string) {
  721. var selectedPathItem *v3.NamedPathItem
  722. for _, namedPathItem := range d.Paths.Path {
  723. if namedPathItem.Name == path {
  724. selectedPathItem = namedPathItem
  725. break
  726. }
  727. }
  728. // If we get here, we need to create a path item.
  729. if selectedPathItem == nil {
  730. selectedPathItem = &v3.NamedPathItem{Name: path, Value: &v3.PathItem{}}
  731. d.Paths.Path = append(d.Paths.Path, selectedPathItem)
  732. }
  733. // Set the operation on the specified method.
  734. switch methodName {
  735. case "GET":
  736. selectedPathItem.Value.Get = op
  737. case "POST":
  738. selectedPathItem.Value.Post = op
  739. case "PUT":
  740. selectedPathItem.Value.Put = op
  741. case "DELETE":
  742. selectedPathItem.Value.Delete = op
  743. case "PATCH":
  744. selectedPathItem.Value.Patch = op
  745. case http2.MethodHead:
  746. selectedPathItem.Value.Head = op
  747. case http2.MethodOptions:
  748. selectedPathItem.Value.Options = op
  749. case http2.MethodTrace:
  750. selectedPathItem.Value.Trace = op
  751. }
  752. }
  753. // addPathsToDocumentV3 adds paths from a specified file descriptor.
  754. func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*protogen.Service) {
  755. for _, service := range services {
  756. annotationsCount := 0
  757. for _, method := range service.Methods {
  758. comment := g.filterCommentString(method.Comments.Leading)
  759. inputMessage := method.Input
  760. outputMessage := method.Output
  761. operationID := service.GoName + "_" + method.GoName
  762. extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
  763. if extOperation == nil || extOperation == v3.E_Operation.InterfaceOf(v3.E_Operation.Zero()) {
  764. continue
  765. }
  766. var path string
  767. var httpMethod string
  768. var bodyField string
  769. httpOperation := proto.GetExtension(method.Desc.Options(), annotations.E_Http)
  770. if httpOperation != nil && httpOperation != annotations.E_Http.InterfaceOf(annotations.E_Http.Zero()) {
  771. _httpOperation := httpOperation.(*annotations.HttpRule)
  772. switch httpRule := _httpOperation.GetPattern().(type) {
  773. case *annotations.HttpRule_Post:
  774. path = httpRule.Post
  775. httpMethod = http2.MethodPost
  776. bodyField = _httpOperation.GetBody()
  777. case *annotations.HttpRule_Get:
  778. path = httpRule.Get
  779. httpMethod = http2.MethodGet
  780. bodyField = ""
  781. case *annotations.HttpRule_Delete:
  782. path = httpRule.Delete
  783. httpMethod = http2.MethodDelete
  784. bodyField = ""
  785. case *annotations.HttpRule_Put:
  786. path = httpRule.Put
  787. httpMethod = http2.MethodPut
  788. bodyField = _httpOperation.GetBody()
  789. case *annotations.HttpRule_Patch:
  790. path = httpRule.Patch
  791. httpMethod = http2.MethodPatch
  792. bodyField = _httpOperation.GetBody()
  793. case *annotations.HttpRule_Custom:
  794. path = httpRule.Custom.Path
  795. httpMethod = httpRule.Custom.Kind
  796. bodyField = _httpOperation.GetBody()
  797. }
  798. }
  799. annotationsCount++
  800. if path == "" {
  801. path = handler.PathGenerator(string(service.Desc.FullName()), method.GoName)
  802. }
  803. if httpMethod == "" {
  804. httpMethod = http2.MethodPost
  805. }
  806. if bodyField == "" && (httpMethod == http2.MethodPost || httpMethod == http2.MethodPut || httpMethod == http2.MethodPatch) {
  807. bodyField = "*"
  808. }
  809. defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)
  810. op, path2 := g.buildOperationV3(
  811. d, operationID, service.GoName, comment, defaultHost, path, bodyField, inputMessage, outputMessage)
  812. // Merge any `Operation` annotations with the current
  813. proto.Merge(op, extOperation.(*v3.Operation))
  814. g.addOperationToDocumentV3(d, op, path2, httpMethod)
  815. }
  816. if annotationsCount > 0 {
  817. comment := g.filterCommentString(service.Comments.Leading)
  818. d.Tags = append(d.Tags, &v3.Tag{Name: service.GoName, Description: comment})
  819. }
  820. }
  821. }
  822. // addSchemaForMessageToDocumentV3 adds the schema to the document if required
  823. func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
  824. if contains(g.generatedSchemas, schema.Name) {
  825. return
  826. }
  827. g.generatedSchemas = append(g.generatedSchemas, schema.Name)
  828. d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
  829. }
  830. // addSchemasForMessagesToDocumentV3 adds info from one file descriptor.
  831. func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, messages []*protogen.Message, edition descriptorpb.Edition) {
  832. // For each message, generate a definition.
  833. for _, message := range messages {
  834. if message.Messages != nil {
  835. g.addSchemasForMessagesToDocumentV3(d, message.Messages, edition)
  836. }
  837. schemaName := g.reflect.formatMessageName(message.Desc)
  838. // Only generate this if we need it and haven't already generated it.
  839. if !contains(g.reflect.requiredSchemas, schemaName) ||
  840. contains(g.generatedSchemas, schemaName) {
  841. continue
  842. }
  843. typeName := g.reflect.fullMessageTypeName(message.Desc)
  844. messageDescription := g.filterCommentString(message.Comments.Leading)
  845. // `google.protobuf.Value` and `google.protobuf.Any` have special JSON transcoding
  846. // so we can't just reflect on the message descriptor.
  847. if typeName == ".google.protobuf.Value" {
  848. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufValueSchema(schemaName))
  849. continue
  850. } else if typeName == ".google.protobuf.Any" {
  851. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(schemaName))
  852. continue
  853. } else if typeName == ".google.rpc.Status" {
  854. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  855. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(anySchemaName))
  856. g.addSchemaToDocumentV3(d, wk.NewGoogleRpcStatusSchema(schemaName, anySchemaName))
  857. continue
  858. }
  859. // Build an array holding the fields of the message.
  860. definitionProperties := &v3.Properties{
  861. AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0),
  862. }
  863. var required []string
  864. for _, field := range message.Fields {
  865. // Get the field description from the comments.
  866. description := g.filterCommentString(field.Comments.Leading)
  867. // Check the field annotations to see if this is a readonly or writeonly field.
  868. inputOnly := false
  869. outputOnly := false
  870. isRequired := true
  871. extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior)
  872. if extension != nil {
  873. switch v := extension.(type) {
  874. case []annotations.FieldBehavior:
  875. for _, vv := range v {
  876. switch vv {
  877. case annotations.FieldBehavior_OUTPUT_ONLY:
  878. outputOnly = true
  879. case annotations.FieldBehavior_INPUT_ONLY:
  880. inputOnly = true
  881. case annotations.FieldBehavior_OPTIONAL:
  882. isRequired = false
  883. }
  884. }
  885. default:
  886. log.Printf("unsupported extension type %T", extension)
  887. }
  888. }
  889. if edition == descriptorpb.Edition_EDITION_2023 {
  890. if fieldOptions, ok := field.Desc.Options().(*descriptorpb.FieldOptions); ok {
  891. if fieldOptions.GetFeatures().GetFieldPresence() == descriptorpb.FeatureSet_EXPLICIT {
  892. isRequired = false
  893. }
  894. }
  895. }
  896. if isRequired {
  897. required = append(required, g.reflect.formatFieldName(field.Desc))
  898. }
  899. // The field is either described by a reference or a schema.
  900. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  901. if fieldSchema == nil {
  902. continue
  903. }
  904. // If this field has siblings and is a $ref now, create a new schema use `allOf` to wrap it
  905. wrapperNeeded := inputOnly || outputOnly || description != ""
  906. if wrapperNeeded {
  907. if _, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Reference); ok {
  908. fieldSchema = &v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{
  909. AllOf: []*v3.SchemaOrReference{fieldSchema},
  910. }}}
  911. }
  912. }
  913. if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok {
  914. schema.Schema.Description = description
  915. schema.Schema.ReadOnly = outputOnly
  916. schema.Schema.WriteOnly = inputOnly
  917. // Merge any `Property` annotations with the current
  918. extProperty := proto.GetExtension(field.Desc.Options(), v3.E_Property)
  919. if extProperty != nil {
  920. proto.Merge(schema.Schema, extProperty.(*v3.Schema))
  921. }
  922. }
  923. definitionProperties.AdditionalProperties = append(
  924. definitionProperties.AdditionalProperties,
  925. &v3.NamedSchemaOrReference{
  926. Name: g.reflect.formatFieldName(field.Desc),
  927. Value: fieldSchema,
  928. },
  929. )
  930. }
  931. schema := &v3.Schema{
  932. Type: "object",
  933. Description: messageDescription,
  934. Properties: definitionProperties,
  935. Required: required,
  936. }
  937. // Merge any `Schema` annotations with the current
  938. extSchema := proto.GetExtension(message.Desc.Options(), v3.E_Schema)
  939. if extSchema != nil {
  940. proto.Merge(schema, extSchema.(*v3.Schema))
  941. }
  942. // Add the schema to the components.schema list.
  943. g.addSchemaToDocumentV3(d, &v3.NamedSchemaOrReference{
  944. Name: schemaName,
  945. Value: &v3.SchemaOrReference{
  946. Oneof: &v3.SchemaOrReference_Schema{
  947. Schema: schema,
  948. },
  949. },
  950. })
  951. }
  952. }