main.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package main
  2. import (
  3. "flag"
  4. "fmt"
  5. "os"
  6. "strings"
  7. "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/codegenerator"
  8. "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/descriptor"
  9. "git.ikuban.com/server/swagger-api/protoc-gen-openapiv2/internal/genopenapi"
  10. "github.com/golang/glog"
  11. "google.golang.org/protobuf/proto"
  12. "google.golang.org/protobuf/types/pluginpb"
  13. )
  14. var (
  15. importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files")
  16. file = flag.String("file", "-", "where to load data from")
  17. allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body")
  18. grpcAPIConfiguration = flag.String("grpc_api_configuration", "", "path to file which describes the gRPC API Configuration in YAML format")
  19. allowMerge = flag.Bool("allow_merge", false, "if set, generation one OpenAPI file out of multiple protos")
  20. mergeFileName = flag.String("merge_file_name", "apidocs", "target OpenAPI file name prefix after merge")
  21. useJSONNamesForFields = flag.Bool("json_names_for_fields", true, "if disabled, the original proto name will be used for generating OpenAPI definitions")
  22. repeatedPathParamSeparator = flag.String("repeated_path_param_separator", "csv", "configures how repeated fields should be split. Allowed values are `csv`, `pipes`, `ssv` and `tsv`")
  23. versionFlag = flag.Bool("version", false, "print the current version")
  24. allowRepeatedFieldsInBody = flag.Bool("allow_repeated_fields_in_body", false, "allows to use repeated field in `body` and `response_body` field of `google.api.http` annotation option")
  25. includePackageInTags = flag.Bool("include_package_in_tags", false, "if unset, the gRPC service name is added to the `Tags` field of each operation. If set and the `package` directive is shown in the proto file, the package name will be prepended to the service name")
  26. useFQNForOpenAPIName = flag.Bool("fqn_for_openapi_name", false, "if set, the object's OpenAPI names will use the fully qualified names from the proto definition (ie my.package.MyMessage.MyInnerMessage")
  27. useGoTemplate = flag.Bool("use_go_templates", false, "if set, you can use Go templates in protofile comments")
  28. disableDefaultErrors = flag.Bool("disable_default_errors", false, "if set, disables generation of default errors. This is useful if you have defined custom error handling")
  29. enumsAsInts = flag.Bool("enums_as_ints", false, "whether to render enum values as integers, as opposed to string values")
  30. simpleOperationIDs = flag.Bool("simple_operation_ids", false, "whether to remove the service prefix in the operationID generation. Can introduce duplicate operationIDs, use with caution.")
  31. proto3OptionalNullable = flag.Bool("proto3_optional_nullable", false, "whether Proto3 Optional fields should be marked as x-nullable")
  32. openAPIConfiguration = flag.String("openapi_configuration", "", "path to file which describes the OpenAPI Configuration in YAML format")
  33. generateUnboundMethods = flag.Bool("generate_unbound_methods", false, "generate swagger metadata even for RPC methods that have no HttpRule annotation")
  34. generateRPCMethods = flag.Bool("generate_rpc_methods", false, "generate swagger metadata even for RPC methods without HttpRule annotation")
  35. recursiveDepth = flag.Int("recursive-depth", 1000, "maximum recursion count allowed for a field type")
  36. )
  37. // Variables set by goreleaser at build time
  38. var (
  39. version = "dev123"
  40. commit = "unknown"
  41. date = "unknown"
  42. )
  43. func main() {
  44. flag.Parse()
  45. defer glog.Flush()
  46. if *versionFlag {
  47. fmt.Printf("Version %v, commit %v, built at %v\n", version, commit, date)
  48. os.Exit(0)
  49. }
  50. reg := descriptor.NewRegistry()
  51. glog.V(1).Info("Processing code generator request")
  52. f := os.Stdin
  53. if *file != "-" {
  54. var err error
  55. f, err = os.Open(*file)
  56. if err != nil {
  57. glog.Fatal(err)
  58. }
  59. }
  60. glog.V(1).Info("Parsing code generator request")
  61. req, err := codegenerator.ParseRequest(f)
  62. if err != nil {
  63. glog.Fatal(err)
  64. }
  65. glog.V(1).Info("Parsed code generator request")
  66. pkgMap := make(map[string]string)
  67. if req.Parameter != nil {
  68. err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap)
  69. if err != nil {
  70. glog.Fatalf("Error parsing flags: %v", err)
  71. }
  72. }
  73. reg.SetPrefix(*importPrefix)
  74. reg.SetAllowDeleteBody(*allowDeleteBody)
  75. reg.SetAllowMerge(*allowMerge)
  76. reg.SetMergeFileName(*mergeFileName)
  77. reg.SetUseJSONNamesForFields(*useJSONNamesForFields)
  78. reg.SetAllowRepeatedFieldsInBody(*allowRepeatedFieldsInBody)
  79. reg.SetIncludePackageInTags(*includePackageInTags)
  80. reg.SetUseFQNForOpenAPIName(*useFQNForOpenAPIName)
  81. reg.SetUseGoTemplate(*useGoTemplate)
  82. reg.SetEnumsAsInts(*enumsAsInts)
  83. reg.SetDisableDefaultErrors(*disableDefaultErrors)
  84. reg.SetSimpleOperationIDs(*simpleOperationIDs)
  85. reg.SetProto3OptionalNullable(*proto3OptionalNullable)
  86. reg.SetGenerateUnboundMethods(*generateUnboundMethods)
  87. reg.SetGenerateRPCMethods(*generateRPCMethods)
  88. reg.SetRecursiveDepth(*recursiveDepth)
  89. if err := reg.SetRepeatedPathParamSeparator(*repeatedPathParamSeparator); err != nil {
  90. emitError(err)
  91. return
  92. }
  93. for k, v := range pkgMap {
  94. reg.AddPkgMap(k, v)
  95. }
  96. if *grpcAPIConfiguration != "" {
  97. if err := reg.LoadGrpcAPIServiceFromYAML(*grpcAPIConfiguration); err != nil {
  98. emitError(err)
  99. return
  100. }
  101. }
  102. g := genopenapi.New(reg)
  103. if err := genopenapi.AddErrorDefs(reg); err != nil {
  104. emitError(err)
  105. return
  106. }
  107. if err := reg.Load(req); err != nil {
  108. emitError(err)
  109. return
  110. }
  111. if *openAPIConfiguration != "" {
  112. if err := reg.LoadOpenAPIConfigFromYAML(*openAPIConfiguration); err != nil {
  113. emitError(err)
  114. return
  115. }
  116. }
  117. var targets []*descriptor.File
  118. for _, target := range req.FileToGenerate {
  119. f, err := reg.LookupFile(target)
  120. if err != nil {
  121. glog.Fatal(err)
  122. }
  123. targets = append(targets, f)
  124. }
  125. out, err := g.Generate(targets)
  126. glog.V(1).Info("Processed code generator request")
  127. if err != nil {
  128. emitError(err)
  129. return
  130. }
  131. emitFiles(out)
  132. }
  133. func emitFiles(out []*descriptor.ResponseFile) {
  134. files := make([]*pluginpb.CodeGeneratorResponse_File, len(out))
  135. for idx, item := range out {
  136. files[idx] = item.CodeGeneratorResponse_File
  137. }
  138. resp := &pluginpb.CodeGeneratorResponse{File: files}
  139. codegenerator.SetSupportedFeaturesOnCodeGeneratorResponse(resp)
  140. emitResp(resp)
  141. }
  142. func emitError(err error) {
  143. emitResp(&pluginpb.CodeGeneratorResponse{Error: proto.String(err.Error())})
  144. }
  145. func emitResp(resp *pluginpb.CodeGeneratorResponse) {
  146. buf, err := proto.Marshal(resp)
  147. if err != nil {
  148. glog.Fatal(err)
  149. }
  150. if _, err := os.Stdout.Write(buf); err != nil {
  151. glog.Fatal(err)
  152. }
  153. }
  154. // parseReqParam parses a CodeGeneratorRequest parameter and adds the
  155. // extracted values to the given FlagSet and pkgMap. Returns a non-nil
  156. // error if setting a flag failed.
  157. func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) error {
  158. if param == "" {
  159. return nil
  160. }
  161. for _, p := range strings.Split(param, ",") {
  162. spec := strings.SplitN(p, "=", 2)
  163. if len(spec) == 1 {
  164. if spec[0] == "allow_delete_body" {
  165. err := f.Set(spec[0], "true")
  166. if err != nil {
  167. return fmt.Errorf("cannot set flag %s: %v", p, err)
  168. }
  169. continue
  170. }
  171. if spec[0] == "allow_merge" {
  172. err := f.Set(spec[0], "true")
  173. if err != nil {
  174. return fmt.Errorf("cannot set flag %s: %v", p, err)
  175. }
  176. continue
  177. }
  178. if spec[0] == "allow_repeated_fields_in_body" {
  179. err := f.Set(spec[0], "true")
  180. if err != nil {
  181. return fmt.Errorf("cannot set flag %s: %v", p, err)
  182. }
  183. continue
  184. }
  185. if spec[0] == "include_package_in_tags" {
  186. err := f.Set(spec[0], "true")
  187. if err != nil {
  188. return fmt.Errorf("cannot set flag %s: %v", p, err)
  189. }
  190. continue
  191. }
  192. err := f.Set(spec[0], "")
  193. if err != nil {
  194. return fmt.Errorf("cannot set flag %s: %v", p, err)
  195. }
  196. continue
  197. }
  198. name, value := spec[0], spec[1]
  199. if strings.HasPrefix(name, "M") {
  200. pkgMap[name[1:]] = value
  201. continue
  202. }
  203. if err := f.Set(name, value); err != nil {
  204. return fmt.Errorf("cannot set flag %s: %v", p, err)
  205. }
  206. }
  207. return nil
  208. }