generator.go 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. // Copyright 2020 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. package generator
  16. import (
  17. "fmt"
  18. "log"
  19. "net/url"
  20. "regexp"
  21. "sort"
  22. "strings"
  23. http2 "net/http"
  24. "google.golang.org/protobuf/types/descriptorpb"
  25. "google.golang.org/genproto/googleapis/api/annotations"
  26. status_pb "google.golang.org/genproto/googleapis/rpc/status"
  27. "google.golang.org/protobuf/compiler/protogen"
  28. "google.golang.org/protobuf/proto"
  29. "google.golang.org/protobuf/reflect/protoreflect"
  30. any_pb "google.golang.org/protobuf/types/known/anypb"
  31. wk "git.ikuban.com/server/swagger-api/v2/generator/wellknown"
  32. v3 "github.com/google/gnostic/openapiv3"
  33. )
  34. type Configuration struct {
  35. Version *string
  36. Title *string
  37. Description *string
  38. Naming *string
  39. FQSchemaNaming *bool
  40. EnumType *string
  41. CircularDepth *int
  42. DefaultResponse *bool
  43. OutputMode *string
  44. }
  45. const (
  46. infoURL = "git.ikuban.com/server/swagger-api"
  47. )
  48. // In order to dynamically add google.rpc.Status responses we need
  49. // to know the message descriptors for google.rpc.Status as well
  50. // as google.protobuf.Any.
  51. var statusProtoDesc = (&status_pb.Status{}).ProtoReflect().Descriptor()
  52. var anyProtoDesc = (&any_pb.Any{}).ProtoReflect().Descriptor()
  53. // OpenAPIv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service.
  54. type OpenAPIv3Generator struct {
  55. conf Configuration
  56. plugin *protogen.Plugin
  57. inputFiles []*protogen.File
  58. reflect *OpenAPIv3Reflector
  59. generatedSchemas []string // Names of schemas that have already been generated.
  60. linterRulePattern *regexp.Regexp
  61. pathPattern *regexp.Regexp
  62. namedPathPattern *regexp.Regexp
  63. }
  64. // NewOpenAPIv3Generator creates a new generator for a protoc plugin invocation.
  65. func NewOpenAPIv3Generator(plugin *protogen.Plugin, conf Configuration, inputFiles []*protogen.File) *OpenAPIv3Generator {
  66. return &OpenAPIv3Generator{
  67. conf: conf,
  68. plugin: plugin,
  69. inputFiles: inputFiles,
  70. reflect: NewOpenAPIv3Reflector(conf),
  71. generatedSchemas: make([]string, 0),
  72. linterRulePattern: regexp.MustCompile(`\(-- .* --\)`),
  73. pathPattern: regexp.MustCompile("{([^=}]+)}"),
  74. namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"),
  75. }
  76. }
  77. // Run runs the generator.
  78. func (g *OpenAPIv3Generator) Run(outputFile *protogen.GeneratedFile) error {
  79. d := g.buildDocumentV3()
  80. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  81. if err != nil {
  82. return fmt.Errorf("failed to marshal yaml: %s", err.Error())
  83. }
  84. if _, err = outputFile.Write(bytes); err != nil {
  85. return fmt.Errorf("failed to write yaml: %s", err.Error())
  86. }
  87. return nil
  88. }
  89. func (g *OpenAPIv3Generator) RunV2() ([]byte, error) {
  90. d := g.buildDocumentV3()
  91. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  92. if err != nil {
  93. return bytes, fmt.Errorf("failed to marshal yaml: %s", err.Error())
  94. }
  95. return bytes, nil
  96. }
  97. // buildDocumentV3 builds an OpenAPIv3 document for a plugin request.
  98. func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
  99. d := &v3.Document{}
  100. d.Openapi = "3.0.3"
  101. d.Info = &v3.Info{
  102. Version: *g.conf.Version,
  103. Title: *g.conf.Title,
  104. Description: *g.conf.Description,
  105. }
  106. d.Paths = &v3.Paths{}
  107. d.Components = &v3.Components{
  108. Schemas: &v3.SchemasOrReferences{
  109. AdditionalProperties: []*v3.NamedSchemaOrReference{},
  110. },
  111. }
  112. // Go through the files and add the services to the documents, keeping
  113. // track of which schemas are referenced in the response so we can
  114. // add them later.
  115. for _, file := range g.inputFiles {
  116. if file.Generate {
  117. // Merge any `Document` annotations with the current
  118. extDocument := proto.GetExtension(file.Desc.Options(), v3.E_Document)
  119. if extDocument != nil {
  120. proto.Merge(d, extDocument.(*v3.Document))
  121. }
  122. g.addPathsToDocumentV3(d, file.Services)
  123. }
  124. }
  125. // While we have required schemas left to generate, go through the files again
  126. // looking for the related message and adding them to the document if required.
  127. for len(g.reflect.requiredSchemas) > 0 {
  128. count := len(g.reflect.requiredSchemas)
  129. for _, file := range g.plugin.Files {
  130. g.addSchemasForMessagesToDocumentV3(d, file.Messages, file.Proto.GetEdition())
  131. }
  132. g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
  133. }
  134. // If there is only 1 service, then use it's title for the
  135. // document, if the document is missing it.
  136. if len(d.Tags) == 1 {
  137. if d.Info.Title == "" && d.Tags[0].Name != "" {
  138. d.Info.Title = d.Tags[0].Name + " API"
  139. }
  140. if d.Info.Description == "" {
  141. d.Info.Description = d.Tags[0].Description
  142. }
  143. d.Tags[0].Description = ""
  144. }
  145. allServers := []string{}
  146. // If paths methods has servers, but they're all the same, then move servers to path level
  147. for _, path := range d.Paths.Path {
  148. servers := []string{}
  149. // Only 1 server will ever be set, per method, by the generator
  150. if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 {
  151. servers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  152. allServers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  153. }
  154. if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 {
  155. servers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  156. allServers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  157. }
  158. if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 {
  159. servers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  160. allServers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  161. }
  162. if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 {
  163. servers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  164. allServers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  165. }
  166. if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 {
  167. servers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  168. allServers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  169. }
  170. if path.Value.Head != nil && len(path.Value.Head.Servers) == 1 {
  171. servers = appendUnique(servers, path.Value.Head.Servers[0].Url)
  172. allServers = appendUnique(servers, path.Value.Head.Servers[0].Url)
  173. }
  174. if path.Value.Options != nil && len(path.Value.Options.Servers) == 1 {
  175. servers = appendUnique(servers, path.Value.Options.Servers[0].Url)
  176. allServers = appendUnique(servers, path.Value.Options.Servers[0].Url)
  177. }
  178. if path.Value.Trace != nil && len(path.Value.Trace.Servers) == 1 {
  179. servers = appendUnique(servers, path.Value.Trace.Servers[0].Url)
  180. allServers = appendUnique(servers, path.Value.Trace.Servers[0].Url)
  181. }
  182. if len(servers) == 1 {
  183. path.Value.Servers = []*v3.Server{{Url: servers[0]}}
  184. if path.Value.Get != nil {
  185. path.Value.Get.Servers = nil
  186. }
  187. if path.Value.Post != nil {
  188. path.Value.Post.Servers = nil
  189. }
  190. if path.Value.Put != nil {
  191. path.Value.Put.Servers = nil
  192. }
  193. if path.Value.Delete != nil {
  194. path.Value.Delete.Servers = nil
  195. }
  196. if path.Value.Patch != nil {
  197. path.Value.Patch.Servers = nil
  198. }
  199. if path.Value.Head != nil {
  200. path.Value.Head.Servers = nil
  201. }
  202. if path.Value.Options != nil {
  203. path.Value.Options.Servers = nil
  204. }
  205. if path.Value.Trace != nil {
  206. path.Value.Trace.Servers = nil
  207. }
  208. }
  209. }
  210. // Set all servers on API level
  211. if len(allServers) > 0 {
  212. d.Servers = []*v3.Server{}
  213. for _, server := range allServers {
  214. d.Servers = append(d.Servers, &v3.Server{Url: server})
  215. }
  216. }
  217. // If there is only 1 server, we can safely remove all path level servers
  218. if len(allServers) == 1 {
  219. for _, path := range d.Paths.Path {
  220. path.Value.Servers = nil
  221. }
  222. }
  223. // Sort the tags.
  224. {
  225. pairs := d.Tags
  226. sort.Slice(pairs, func(i, j int) bool {
  227. return pairs[i].Name < pairs[j].Name
  228. })
  229. d.Tags = pairs
  230. }
  231. // Sort the paths.
  232. {
  233. pairs := d.Paths.Path
  234. sort.Slice(pairs, func(i, j int) bool {
  235. return pairs[i].Name < pairs[j].Name
  236. })
  237. d.Paths.Path = pairs
  238. }
  239. // Sort the schemas.
  240. {
  241. pairs := d.Components.Schemas.AdditionalProperties
  242. sort.Slice(pairs, func(i, j int) bool {
  243. return pairs[i].Name < pairs[j].Name
  244. })
  245. d.Components.Schemas.AdditionalProperties = pairs
  246. }
  247. return d
  248. }
  249. // filterCommentString removes linter rules from comments.
  250. func (g *OpenAPIv3Generator) filterCommentString(c protogen.Comments) string {
  251. comment := g.linterRulePattern.ReplaceAllString(string(c), "")
  252. return strings.TrimSpace(comment)
  253. }
  254. func (g *OpenAPIv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field {
  255. for _, field := range inMessage.Fields {
  256. if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name {
  257. return field
  258. }
  259. }
  260. return nil
  261. }
  262. func (g *OpenAPIv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string {
  263. field := g.findField(name, inMessage)
  264. if field != nil {
  265. return g.reflect.formatFieldName(field.Desc)
  266. }
  267. return name
  268. }
  269. // Note that fields which are mapped to URL query parameters must have a primitive type
  270. // or a repeated primitive type or a non-repeated message type.
  271. // In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A&param=B.
  272. // In the case of a message type, each field of the message is mapped to a separate parameter,
  273. // such as ...?foo.a=A&foo.b=B&foo.c=C.
  274. // There are exceptions:
  275. // - for wrapper types it will use the same representation as the wrapped primitive type in JSON
  276. // - for google.protobuf.timestamp type it will be serialized as a string
  277. //
  278. // maps, Struct and Empty can NOT be used
  279. // messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id)
  280. // buildQueryParamsV3 extracts any valid query params, including sub and recursive messages
  281. func (g *OpenAPIv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference {
  282. depths := map[string]int{}
  283. return g._buildQueryParamsV3(field, depths)
  284. }
  285. // depths are used to keep track of how many times a message's fields has been seen
  286. func (g *OpenAPIv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference {
  287. parameters := []*v3.ParameterOrReference{}
  288. queryFieldName := g.reflect.formatFieldName(field.Desc)
  289. fieldDescription := g.filterCommentString(field.Comments.Leading)
  290. if field.Desc.IsMap() {
  291. // Map types are not allowed in query parameteres
  292. return parameters
  293. } else if field.Desc.Kind() == protoreflect.MessageKind {
  294. typeName := g.reflect.fullMessageTypeName(field.Desc.Message())
  295. switch typeName {
  296. case ".google.protobuf.Value":
  297. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  298. parameters = append(parameters,
  299. &v3.ParameterOrReference{
  300. Oneof: &v3.ParameterOrReference_Parameter{
  301. Parameter: &v3.Parameter{
  302. Name: queryFieldName,
  303. In: "query",
  304. Description: fieldDescription,
  305. Required: false,
  306. Schema: fieldSchema,
  307. },
  308. },
  309. })
  310. return parameters
  311. case ".google.protobuf.BoolValue", ".google.protobuf.BytesValue", ".google.protobuf.Int32Value", ".google.protobuf.UInt32Value",
  312. ".google.protobuf.StringValue", ".google.protobuf.Int64Value", ".google.protobuf.UInt64Value", ".google.protobuf.FloatValue",
  313. ".google.protobuf.DoubleValue":
  314. valueField := getValueField(field.Message.Desc)
  315. fieldSchema := g.reflect.schemaOrReferenceForField(valueField)
  316. parameters = append(parameters,
  317. &v3.ParameterOrReference{
  318. Oneof: &v3.ParameterOrReference_Parameter{
  319. Parameter: &v3.Parameter{
  320. Name: queryFieldName,
  321. In: "query",
  322. Description: fieldDescription,
  323. Required: false,
  324. Schema: fieldSchema,
  325. },
  326. },
  327. })
  328. return parameters
  329. case ".google.protobuf.Timestamp":
  330. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  331. parameters = append(parameters,
  332. &v3.ParameterOrReference{
  333. Oneof: &v3.ParameterOrReference_Parameter{
  334. Parameter: &v3.Parameter{
  335. Name: queryFieldName,
  336. In: "query",
  337. Description: fieldDescription,
  338. Required: false,
  339. Schema: fieldSchema,
  340. },
  341. },
  342. })
  343. return parameters
  344. case ".google.protobuf.Duration":
  345. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  346. parameters = append(parameters,
  347. &v3.ParameterOrReference{
  348. Oneof: &v3.ParameterOrReference_Parameter{
  349. Parameter: &v3.Parameter{
  350. Name: queryFieldName,
  351. In: "query",
  352. Description: fieldDescription,
  353. Required: false,
  354. Schema: fieldSchema,
  355. },
  356. },
  357. })
  358. return parameters
  359. }
  360. if field.Desc.IsList() {
  361. // Only non-repeated message types are valid
  362. return parameters
  363. }
  364. // Represent field masks directly as strings (don't expand them).
  365. if typeName == ".google.protobuf.FieldMask" {
  366. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  367. parameters = append(parameters,
  368. &v3.ParameterOrReference{
  369. Oneof: &v3.ParameterOrReference_Parameter{
  370. Parameter: &v3.Parameter{
  371. Name: queryFieldName,
  372. In: "query",
  373. Description: fieldDescription,
  374. Required: false,
  375. Schema: fieldSchema,
  376. },
  377. },
  378. })
  379. return parameters
  380. }
  381. // Sub messages are allowed, even circular, as long as the final type is a primitive.
  382. // Go through each of the sub message fields
  383. for _, subField := range field.Message.Fields {
  384. subFieldFullName := string(subField.Desc.FullName())
  385. seen, ok := depths[subFieldFullName]
  386. if !ok {
  387. depths[subFieldFullName] = 0
  388. }
  389. if seen < *g.conf.CircularDepth {
  390. depths[subFieldFullName]++
  391. subParams := g._buildQueryParamsV3(subField, depths)
  392. for _, subParam := range subParams {
  393. if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok {
  394. param.Parameter.Name = queryFieldName + "." + param.Parameter.Name
  395. parameters = append(parameters, subParam)
  396. }
  397. }
  398. }
  399. }
  400. } else if field.Desc.Kind() != protoreflect.GroupKind {
  401. // schemaOrReferenceForField also handles array types
  402. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  403. parameters = append(parameters,
  404. &v3.ParameterOrReference{
  405. Oneof: &v3.ParameterOrReference_Parameter{
  406. Parameter: &v3.Parameter{
  407. Name: queryFieldName,
  408. In: "query",
  409. Description: fieldDescription,
  410. Required: false,
  411. Schema: fieldSchema,
  412. },
  413. },
  414. })
  415. }
  416. return parameters
  417. }
  418. // buildOperationV3 constructs an operation for a set of values.
  419. func (g *OpenAPIv3Generator) buildOperationV3(
  420. d *v3.Document,
  421. operationID string,
  422. tagName string,
  423. description string,
  424. defaultHost string,
  425. path string,
  426. bodyField string,
  427. inputMessage *protogen.Message,
  428. outputMessage *protogen.Message,
  429. ) (*v3.Operation, string) {
  430. // coveredParameters tracks the parameters that have been used in the body or path.
  431. coveredParameters := make([]string, 0)
  432. if bodyField != "" {
  433. coveredParameters = append(coveredParameters, bodyField)
  434. }
  435. // Initialize the list of operation parameters.
  436. parameters := []*v3.ParameterOrReference{}
  437. // Find simple path parameters like {id}
  438. if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil {
  439. for _, matches := range allMatches {
  440. // Add the value to the list of covered parameters.
  441. coveredParameters = append(coveredParameters, matches[1])
  442. pathParameter := g.findAndFormatFieldName(matches[1], inputMessage)
  443. path = strings.Replace(path, matches[1], pathParameter, 1)
  444. // Add the path parameters to the operation parameters.
  445. var fieldSchema *v3.SchemaOrReference
  446. var fieldDescription string
  447. field := g.findField(pathParameter, inputMessage)
  448. if field != nil {
  449. fieldSchema = g.reflect.schemaOrReferenceForField(field.Desc)
  450. fieldDescription = g.filterCommentString(field.Comments.Leading)
  451. } else {
  452. // If field does not exist, it is safe to set it to string, as it is ignored downstream
  453. fieldSchema = &v3.SchemaOrReference{
  454. Oneof: &v3.SchemaOrReference_Schema{
  455. Schema: &v3.Schema{
  456. Type: "string",
  457. },
  458. },
  459. }
  460. }
  461. parameters = append(parameters,
  462. &v3.ParameterOrReference{
  463. Oneof: &v3.ParameterOrReference_Parameter{
  464. Parameter: &v3.Parameter{
  465. Name: pathParameter,
  466. In: "path",
  467. Description: fieldDescription,
  468. Required: true,
  469. Schema: fieldSchema,
  470. },
  471. },
  472. })
  473. }
  474. }
  475. // Find named path parameters like {name=shelves/*}
  476. if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil {
  477. // Build a list of named path parameters.
  478. namedPathParameters := make([]string, 0)
  479. // Add the "name=" "name" value to the list of covered parameters.
  480. coveredParameters = append(coveredParameters, matches[1])
  481. // Convert the path from the starred form to use named path parameters.
  482. starredPath := matches[2]
  483. parts := strings.Split(starredPath, "/")
  484. // The starred path is assumed to be in the form "things/*/otherthings/*".
  485. // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}".
  486. for i := 0; i < len(parts)-1; i += 2 {
  487. section := parts[i]
  488. namedPathParameter := g.findAndFormatFieldName(section, inputMessage)
  489. namedPathParameter = singular(namedPathParameter)
  490. parts[i+1] = "{" + namedPathParameter + "}"
  491. namedPathParameters = append(namedPathParameters, namedPathParameter)
  492. }
  493. // Rewrite the path to use the path parameters.
  494. newPath := strings.Join(parts, "/")
  495. path = strings.Replace(path, matches[0], newPath, 1)
  496. // Add the named path parameters to the operation parameters.
  497. for _, namedPathParameter := range namedPathParameters {
  498. parameters = append(parameters,
  499. &v3.ParameterOrReference{
  500. Oneof: &v3.ParameterOrReference_Parameter{
  501. Parameter: &v3.Parameter{
  502. Name: namedPathParameter,
  503. In: "path",
  504. Required: true,
  505. Description: "The " + namedPathParameter + " id.",
  506. Schema: &v3.SchemaOrReference{
  507. Oneof: &v3.SchemaOrReference_Schema{
  508. Schema: &v3.Schema{
  509. Type: "string",
  510. },
  511. },
  512. },
  513. },
  514. },
  515. })
  516. }
  517. }
  518. // Add any unhandled fields in the request message as query parameters.
  519. if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
  520. for _, field := range inputMessage.Fields {
  521. fieldName := string(field.Desc.Name())
  522. if !contains(coveredParameters, fieldName) && fieldName != bodyField {
  523. fieldParams := g.buildQueryParamsV3(field)
  524. parameters = append(parameters, fieldParams...)
  525. }
  526. }
  527. }
  528. // Create the response.
  529. name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
  530. responses := &v3.Responses{
  531. ResponseOrReference: []*v3.NamedResponseOrReference{
  532. {
  533. Name: name,
  534. Value: &v3.ResponseOrReference{
  535. Oneof: &v3.ResponseOrReference_Response{
  536. Response: &v3.Response{
  537. Description: "OK",
  538. Content: content,
  539. },
  540. },
  541. },
  542. },
  543. },
  544. }
  545. // Add the default reponse if needed
  546. if *g.conf.DefaultResponse {
  547. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  548. anySchema := wk.NewGoogleProtobufAnySchema(anySchemaName)
  549. g.addSchemaToDocumentV3(d, anySchema)
  550. statusSchemaName := g.reflect.formatMessageName(statusProtoDesc)
  551. statusSchema := wk.NewGoogleRpcStatusSchema(statusSchemaName, anySchemaName)
  552. g.addSchemaToDocumentV3(d, statusSchema)
  553. defaultResponse := &v3.NamedResponseOrReference{
  554. Name: "default",
  555. Value: &v3.ResponseOrReference{
  556. Oneof: &v3.ResponseOrReference_Response{
  557. Response: &v3.Response{
  558. Description: "Default error response",
  559. Content: wk.NewApplicationJsonMediaType(&v3.SchemaOrReference{
  560. Oneof: &v3.SchemaOrReference_Reference{
  561. Reference: &v3.Reference{XRef: "#/components/schemas/" + statusSchemaName}}}),
  562. },
  563. },
  564. },
  565. }
  566. responses.ResponseOrReference = append(responses.ResponseOrReference, defaultResponse)
  567. }
  568. // Create the operation.
  569. op := &v3.Operation{
  570. Tags: []string{tagName},
  571. Description: description,
  572. OperationId: operationID,
  573. Parameters: parameters,
  574. Responses: responses,
  575. }
  576. if defaultHost != "" {
  577. hostURL, err := url.Parse(defaultHost)
  578. if err == nil {
  579. hostURL.Scheme = "https"
  580. op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()})
  581. }
  582. }
  583. // If a body field is specified, we need to pass a message as the request body.
  584. if bodyField != "" {
  585. var requestSchema *v3.SchemaOrReference
  586. if bodyField == "*" {
  587. // Pass the entire request message as the request body.
  588. requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)
  589. } else {
  590. // If body refers to a message field, use that type.
  591. for _, field := range inputMessage.Fields {
  592. if string(field.Desc.Name()) == bodyField {
  593. switch field.Desc.Kind() {
  594. case protoreflect.StringKind:
  595. requestSchema = &v3.SchemaOrReference{
  596. Oneof: &v3.SchemaOrReference_Schema{
  597. Schema: &v3.Schema{
  598. Type: "string",
  599. },
  600. },
  601. }
  602. case protoreflect.MessageKind:
  603. requestSchema = g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  604. default:
  605. log.Printf("unsupported field type %+v", field.Desc)
  606. }
  607. break
  608. }
  609. }
  610. }
  611. op.RequestBody = &v3.RequestBodyOrReference{
  612. Oneof: &v3.RequestBodyOrReference_RequestBody{
  613. RequestBody: &v3.RequestBody{
  614. Required: true,
  615. Content: &v3.MediaTypes{
  616. AdditionalProperties: []*v3.NamedMediaType{
  617. {
  618. Name: "application/json",
  619. Value: &v3.MediaType{
  620. Schema: requestSchema,
  621. },
  622. },
  623. },
  624. },
  625. },
  626. },
  627. }
  628. }
  629. return op, path
  630. }
  631. // addOperationToDocumentV3 adds an operation to the specified path/method.
  632. func (g *OpenAPIv3Generator) addOperationToDocumentV3(d *v3.Document, op *v3.Operation, path string, methodName string) {
  633. var selectedPathItem *v3.NamedPathItem
  634. for _, namedPathItem := range d.Paths.Path {
  635. if namedPathItem.Name == path {
  636. selectedPathItem = namedPathItem
  637. break
  638. }
  639. }
  640. // If we get here, we need to create a path item.
  641. if selectedPathItem == nil {
  642. selectedPathItem = &v3.NamedPathItem{Name: path, Value: &v3.PathItem{}}
  643. d.Paths.Path = append(d.Paths.Path, selectedPathItem)
  644. }
  645. // Set the operation on the specified method.
  646. switch methodName {
  647. case "GET":
  648. selectedPathItem.Value.Get = op
  649. case "POST":
  650. selectedPathItem.Value.Post = op
  651. case "PUT":
  652. selectedPathItem.Value.Put = op
  653. case "DELETE":
  654. selectedPathItem.Value.Delete = op
  655. case "PATCH":
  656. selectedPathItem.Value.Patch = op
  657. case http2.MethodHead:
  658. selectedPathItem.Value.Head = op
  659. case http2.MethodOptions:
  660. selectedPathItem.Value.Options = op
  661. case http2.MethodTrace:
  662. selectedPathItem.Value.Trace = op
  663. }
  664. }
  665. // addPathsToDocumentV3 adds paths from a specified file descriptor.
  666. func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*protogen.Service) {
  667. for _, service := range services {
  668. annotationsCount := 0
  669. for _, method := range service.Methods {
  670. comment := g.filterCommentString(method.Comments.Leading)
  671. inputMessage := method.Input
  672. outputMessage := method.Output
  673. operationID := service.GoName + "_" + method.GoName
  674. extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
  675. if extOperation == nil || extOperation == v3.E_Operation.InterfaceOf(v3.E_Operation.Zero()) {
  676. continue
  677. }
  678. httpOperation := proto.GetExtension(method.Desc.Options(), annotations.E_Http)
  679. if httpOperation == nil || httpOperation == annotations.E_Http.InterfaceOf(annotations.E_Http.Zero()) {
  680. continue
  681. }
  682. annotationsCount++
  683. _httpOperation := httpOperation.(*annotations.HttpRule)
  684. var path string
  685. var httpMethod string
  686. var bodyField string
  687. switch httpRule := _httpOperation.GetPattern().(type) {
  688. case *annotations.HttpRule_Post:
  689. path = httpRule.Post
  690. httpMethod = http2.MethodPost
  691. bodyField = _httpOperation.GetBody()
  692. case *annotations.HttpRule_Get:
  693. path = httpRule.Get
  694. httpMethod = http2.MethodGet
  695. bodyField = ""
  696. case *annotations.HttpRule_Delete:
  697. path = httpRule.Delete
  698. httpMethod = http2.MethodDelete
  699. bodyField = ""
  700. case *annotations.HttpRule_Put:
  701. path = httpRule.Put
  702. httpMethod = http2.MethodPut
  703. bodyField = _httpOperation.GetBody()
  704. case *annotations.HttpRule_Patch:
  705. path = httpRule.Patch
  706. httpMethod = http2.MethodPatch
  707. bodyField = _httpOperation.GetBody()
  708. case *annotations.HttpRule_Custom:
  709. path = httpRule.Custom.Path
  710. httpMethod = httpRule.Custom.Kind
  711. bodyField = _httpOperation.GetBody()
  712. }
  713. if path == "" {
  714. path = fmt.Sprintf("/api/%s/%s", service.Desc.FullName(), method.GoName)
  715. }
  716. if httpMethod == "" {
  717. httpMethod = http2.MethodPost
  718. }
  719. if bodyField == "" && (httpMethod == http2.MethodPost || httpMethod == http2.MethodPut || httpMethod == http2.MethodPatch) {
  720. bodyField = "*"
  721. }
  722. defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)
  723. op, path2 := g.buildOperationV3(
  724. d, operationID, service.GoName, comment, defaultHost, path, bodyField, inputMessage, outputMessage)
  725. // Merge any `Operation` annotations with the current
  726. proto.Merge(op, extOperation.(*v3.Operation))
  727. g.addOperationToDocumentV3(d, op, path2, httpMethod)
  728. }
  729. if annotationsCount > 0 {
  730. comment := g.filterCommentString(service.Comments.Leading)
  731. d.Tags = append(d.Tags, &v3.Tag{Name: service.GoName, Description: comment})
  732. }
  733. }
  734. }
  735. // addSchemaForMessageToDocumentV3 adds the schema to the document if required
  736. func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
  737. if contains(g.generatedSchemas, schema.Name) {
  738. return
  739. }
  740. g.generatedSchemas = append(g.generatedSchemas, schema.Name)
  741. d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
  742. }
  743. // addSchemasForMessagesToDocumentV3 adds info from one file descriptor.
  744. func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, messages []*protogen.Message, edition descriptorpb.Edition) {
  745. // For each message, generate a definition.
  746. for _, message := range messages {
  747. if message.Messages != nil {
  748. g.addSchemasForMessagesToDocumentV3(d, message.Messages, edition)
  749. }
  750. schemaName := g.reflect.formatMessageName(message.Desc)
  751. // Only generate this if we need it and haven't already generated it.
  752. if !contains(g.reflect.requiredSchemas, schemaName) ||
  753. contains(g.generatedSchemas, schemaName) {
  754. continue
  755. }
  756. typeName := g.reflect.fullMessageTypeName(message.Desc)
  757. messageDescription := g.filterCommentString(message.Comments.Leading)
  758. // `google.protobuf.Value` and `google.protobuf.Any` have special JSON transcoding
  759. // so we can't just reflect on the message descriptor.
  760. if typeName == ".google.protobuf.Value" {
  761. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufValueSchema(schemaName))
  762. continue
  763. } else if typeName == ".google.protobuf.Any" {
  764. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(schemaName))
  765. continue
  766. } else if typeName == ".google.rpc.Status" {
  767. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  768. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(anySchemaName))
  769. g.addSchemaToDocumentV3(d, wk.NewGoogleRpcStatusSchema(schemaName, anySchemaName))
  770. continue
  771. }
  772. // Build an array holding the fields of the message.
  773. definitionProperties := &v3.Properties{
  774. AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0),
  775. }
  776. var required []string
  777. for _, field := range message.Fields {
  778. // Get the field description from the comments.
  779. description := g.filterCommentString(field.Comments.Leading)
  780. // Check the field annotations to see if this is a readonly or writeonly field.
  781. inputOnly := false
  782. outputOnly := false
  783. isRequired := true
  784. extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior)
  785. if extension != nil {
  786. switch v := extension.(type) {
  787. case []annotations.FieldBehavior:
  788. for _, vv := range v {
  789. switch vv {
  790. case annotations.FieldBehavior_OUTPUT_ONLY:
  791. outputOnly = true
  792. case annotations.FieldBehavior_INPUT_ONLY:
  793. inputOnly = true
  794. case annotations.FieldBehavior_OPTIONAL:
  795. isRequired = false
  796. }
  797. }
  798. default:
  799. log.Printf("unsupported extension type %T", extension)
  800. }
  801. }
  802. if edition == descriptorpb.Edition_EDITION_2023 {
  803. if fieldOptions, ok := field.Desc.Options().(*descriptorpb.FieldOptions); ok {
  804. if fieldOptions.GetFeatures().GetFieldPresence() == descriptorpb.FeatureSet_EXPLICIT {
  805. isRequired = false
  806. }
  807. }
  808. }
  809. if isRequired {
  810. required = append(required, g.reflect.formatFieldName(field.Desc))
  811. }
  812. // The field is either described by a reference or a schema.
  813. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  814. if fieldSchema == nil {
  815. continue
  816. }
  817. // If this field has siblings and is a $ref now, create a new schema use `allOf` to wrap it
  818. wrapperNeeded := inputOnly || outputOnly || description != ""
  819. if wrapperNeeded {
  820. if _, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Reference); ok {
  821. fieldSchema = &v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{
  822. AllOf: []*v3.SchemaOrReference{fieldSchema},
  823. }}}
  824. }
  825. }
  826. if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok {
  827. schema.Schema.Description = description
  828. schema.Schema.ReadOnly = outputOnly
  829. schema.Schema.WriteOnly = inputOnly
  830. // Merge any `Property` annotations with the current
  831. extProperty := proto.GetExtension(field.Desc.Options(), v3.E_Property)
  832. if extProperty != nil {
  833. proto.Merge(schema.Schema, extProperty.(*v3.Schema))
  834. }
  835. }
  836. definitionProperties.AdditionalProperties = append(
  837. definitionProperties.AdditionalProperties,
  838. &v3.NamedSchemaOrReference{
  839. Name: g.reflect.formatFieldName(field.Desc),
  840. Value: fieldSchema,
  841. },
  842. )
  843. }
  844. schema := &v3.Schema{
  845. Type: "object",
  846. Description: messageDescription,
  847. Properties: definitionProperties,
  848. Required: required,
  849. }
  850. // Merge any `Schema` annotations with the current
  851. extSchema := proto.GetExtension(message.Desc.Options(), v3.E_Schema)
  852. if extSchema != nil {
  853. proto.Merge(schema, extSchema.(*v3.Schema))
  854. }
  855. // Add the schema to the components.schema list.
  856. g.addSchemaToDocumentV3(d, &v3.NamedSchemaOrReference{
  857. Name: schemaName,
  858. Value: &v3.SchemaOrReference{
  859. Oneof: &v3.SchemaOrReference_Schema{
  860. Schema: schema,
  861. },
  862. },
  863. })
  864. }
  865. }