generator.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package genopenapi
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "path/filepath"
  8. "reflect"
  9. "strings"
  10. "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/descriptor"
  11. gen "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/generator"
  12. "github.com/golang/glog"
  13. anypb "github.com/golang/protobuf/ptypes/any"
  14. openapi_options "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options"
  15. statuspb "google.golang.org/genproto/googleapis/rpc/status"
  16. "google.golang.org/protobuf/proto"
  17. "google.golang.org/protobuf/types/descriptorpb"
  18. "google.golang.org/protobuf/types/pluginpb"
  19. //nolint:staticcheck // Known issue, will be replaced when possible
  20. legacydescriptor "github.com/golang/protobuf/descriptor"
  21. )
  // errNoTargetService is the sentinel Generate checks for after calling
  // applyTemplate: a file reporting it is logged at V(1) and skipped instead of
  // aborting generation.
  var errNoTargetService = errors.New("no target service defined in the file")
  25. type generator struct {
  26. reg *descriptor.Registry
  27. }
  28. type wrapper struct {
  29. fileName string
  30. swagger *openapiSwaggerObject
  31. }
  32. type GeneratorOptions struct {
  33. Registry *descriptor.Registry
  34. RecursiveDepth int
  35. }
  36. // New returns a new generator which generates grpc gateway files.
  37. func New(reg *descriptor.Registry) gen.Generator {
  38. return &generator{reg: reg}
  39. }
  40. // Merge a lot of OpenAPI file (wrapper) to single one OpenAPI file
  41. func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
  42. var mergedTarget *wrapper
  43. for _, f := range targets {
  44. if mergedTarget == nil {
  45. mergedTarget = &wrapper{
  46. fileName: mergeFileName,
  47. swagger: f.swagger,
  48. }
  49. } else {
  50. for k, v := range f.swagger.Definitions {
  51. mergedTarget.swagger.Definitions[k] = v
  52. }
  53. for k, v := range f.swagger.Paths {
  54. mergedTarget.swagger.Paths[k] = v
  55. }
  56. for k, v := range f.swagger.SecurityDefinitions {
  57. mergedTarget.swagger.SecurityDefinitions[k] = v
  58. }
  59. mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
  60. }
  61. }
  62. return mergedTarget
  63. }
  64. // Q: What's up with the alias types here?
  65. // A: We don't want to completely override how these structs are marshaled into
  66. //
  67. // JSON, we only want to add fields (see below, extensionMarshalJSON).
  68. // An infinite recursion would happen if we'd call json.Marshal on the struct
  69. // that has swaggerObject as an embedded field. To avoid that, we'll create
  70. // type aliases, and those don't have the custom MarshalJSON methods defined
  71. // on them. See http://choly.ca/post/go-json-marshalling/ (or, if it ever
  72. // goes away, use
  73. // https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/.
  74. func (so openapiSwaggerObject) MarshalJSON() ([]byte, error) {
  75. type alias openapiSwaggerObject
  76. return extensionMarshalJSON(alias(so), so.extensions)
  77. }
  78. func (so openapiInfoObject) MarshalJSON() ([]byte, error) {
  79. type alias openapiInfoObject
  80. return extensionMarshalJSON(alias(so), so.extensions)
  81. }
  82. func (so openapiSecuritySchemeObject) MarshalJSON() ([]byte, error) {
  83. type alias openapiSecuritySchemeObject
  84. return extensionMarshalJSON(alias(so), so.extensions)
  85. }
  86. func (so openapiOperationObject) MarshalJSON() ([]byte, error) {
  87. type alias openapiOperationObject
  88. return extensionMarshalJSON(alias(so), so.extensions)
  89. }
  90. func (so openapiResponseObject) MarshalJSON() ([]byte, error) {
  91. type alias openapiResponseObject
  92. return extensionMarshalJSON(alias(so), so.extensions)
  93. }
  94. func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) {
  95. // To append arbitrary keys to the struct we'll render into json,
  96. // we're creating another struct that embeds the original one, and
  97. // its extra fields:
  98. //
  99. // The struct will look like
  100. // struct {
  101. // *openapiCore
  102. // XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"`
  103. // XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"`
  104. // }
  105. // and thus render into what we want -- the JSON of openapiCore with the
  106. // extensions appended.
  107. fields := []reflect.StructField{
  108. { // embedded
  109. Name: "Embedded",
  110. Type: reflect.TypeOf(so),
  111. Anonymous: true,
  112. },
  113. }
  114. for _, ext := range extensions {
  115. fields = append(fields, reflect.StructField{
  116. Name: fieldName(ext.key),
  117. Type: reflect.TypeOf(ext.value),
  118. Tag: reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)),
  119. })
  120. }
  121. t := reflect.StructOf(fields)
  122. s := reflect.New(t).Elem()
  123. s.Field(0).Set(reflect.ValueOf(so))
  124. for _, ext := range extensions {
  125. s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value))
  126. }
  127. return json.Marshal(s.Interface())
  128. }
  129. // encodeOpenAPI converts OpenAPI file obj to pluginpb.CodeGeneratorResponse_File
  130. func encodeOpenAPI(file *wrapper) (*descriptor.ResponseFile, error) {
  131. var formatted bytes.Buffer
  132. enc := json.NewEncoder(&formatted)
  133. enc.SetIndent("", " ")
  134. if err := enc.Encode(*file.swagger); err != nil {
  135. return nil, err
  136. }
  137. name := file.fileName
  138. ext := filepath.Ext(name)
  139. base := strings.TrimSuffix(name, ext)
  140. output := fmt.Sprintf("%s.swagger.json", base)
  141. return &descriptor.ResponseFile{
  142. CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{
  143. Name: proto.String(output),
  144. Content: proto.String(formatted.String()),
  145. },
  146. }, nil
  147. }
  148. func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error) {
  149. var files []*descriptor.ResponseFile
  150. if g.reg.IsAllowMerge() {
  151. var mergedTarget *descriptor.File
  152. // try to find proto leader
  153. for _, f := range targets {
  154. if proto.HasExtension(f.Options, openapi_options.E_Openapiv2Swagger) {
  155. mergedTarget = f
  156. break
  157. }
  158. }
  159. // merge protos to leader
  160. for _, f := range targets {
  161. if mergedTarget == nil {
  162. mergedTarget = f
  163. } else if mergedTarget != f {
  164. mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...)
  165. mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...)
  166. mergedTarget.Services = append(mergedTarget.Services, f.Services...)
  167. }
  168. }
  169. targets = nil
  170. targets = append(targets, mergedTarget)
  171. }
  172. var openapis []*wrapper
  173. for _, file := range targets {
  174. glog.V(1).Infof("Processing %s", file.GetName())
  175. swagger, err := applyTemplate(param{File: file, reg: g.reg})
  176. if err == errNoTargetService {
  177. glog.V(1).Infof("%s: %v", file.GetName(), err)
  178. continue
  179. }
  180. if err != nil {
  181. return nil, err
  182. }
  183. openapis = append(openapis, &wrapper{
  184. fileName: file.GetName(),
  185. swagger: swagger,
  186. })
  187. }
  188. if g.reg.IsAllowMerge() {
  189. targetOpenAPI := mergeTargetFile(openapis, g.reg.GetMergeFileName())
  190. f, err := encodeOpenAPI(targetOpenAPI)
  191. if err != nil {
  192. return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", g.reg.GetMergeFileName(), err)
  193. }
  194. files = append(files, f)
  195. glog.V(1).Infof("New OpenAPI file will emit")
  196. } else {
  197. for _, file := range openapis {
  198. f, err := encodeOpenAPI(file)
  199. if err != nil {
  200. return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", file.fileName, err)
  201. }
  202. files = append(files, f)
  203. glog.V(1).Infof("New OpenAPI file will emit")
  204. }
  205. }
  206. return files, nil
  207. }
  208. // AddErrorDefs Adds google.rpc.Status and google.protobuf.Any
  209. // to registry (used for error-related API responses)
  210. func AddErrorDefs(reg *descriptor.Registry) error {
  211. // load internal protos
  212. any, _ := legacydescriptor.MessageDescriptorProto(&anypb.Any{})
  213. any.SourceCodeInfo = new(descriptorpb.SourceCodeInfo)
  214. status, _ := legacydescriptor.MessageDescriptorProto(&statuspb.Status{})
  215. status.SourceCodeInfo = new(descriptorpb.SourceCodeInfo)
  216. // TODO(johanbrandhorst): Use new conversion later when possible
  217. // any := protodesc.ToFileDescriptorProto((&anypb.Any{}).ProtoReflect().Descriptor().ParentFile())
  218. // status := protodesc.ToFileDescriptorProto((&statuspb.Status{}).ProtoReflect().Descriptor().ParentFile())
  219. return reg.Load(&pluginpb.CodeGeneratorRequest{
  220. ProtoFile: []*descriptorpb.FileDescriptorProto{
  221. any,
  222. status,
  223. },
  224. })
  225. }