Skip to content

Instantly share code, notes, and snippets.

@bruth
Created June 19, 2017 16:59
Show Gist options
  • Select an option

  • Save bruth/b63b5c48df3007dd7aeee42de09f58a2 to your computer and use it in GitHub Desktop.

Select an option

Save bruth/b63b5c48df3007dd7aeee42de09f58a2 to your computer and use it in GitHub Desktop.

Revisions

  1. bruth created this gist Jun 19, 2017.
    31 changes: 31 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,31 @@
    # NATS-RPC Generator

    `go generate` command for creating a client interface, CLI, and serve function for a service interface.

    ## Usage

    ```go
    //go:generate nats-rpc -type=Service -client=client.go -cli=./cmd/cli/main.go
    package main

    type Req struct {
    Left int
    Right int
    }

    type Resp struct {
    Sum int
    }

    type Service interface {
    Add(context.Context, *Req) (*Rep, error)
    }
    ```

    ## Options

    - `type` - Name of the service interface type. All methods are expected to have the same signature. `(context.Context, <RequestType>) (<ResponseType>, error)` where the request and response types can be user-defined.
    - `client` - Name of the output file to write the client type and serve function.
    - `cli` - Name of the output file to write the CLI.
    - `group` - Name of the NATS queue group for the serve subscription handlers. Defaults to `svc.<pkg-name>`.
    - `prefix` - Prefix to all NATS subjects used. Defaults to no prefix.
    8 changes: 8 additions & 0 deletions example_service.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,8 @@
    //go:generate nats-rpc -type=Service -client=client.go -cli=./cmd/cli/main.go
    package example

    import "context"

    type Service interface {
    Add(context.Context, *Req) (*Rep, error)
    }
    13 changes: 13 additions & 0 deletions example_service.proto
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,13 @@
    syntax = "proto3";

    package example;


    message Req {
    int32 left = 1;
    int32 right = 2;
    }

    message Rep {
    int32 sum = 1;
    }
    124 changes: 124 additions & 0 deletions gen.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,124 @@
    package main

    import (
    "fmt"
    "go/ast"
    "go/build"
    "go/importer"
    "go/parser"
    "go/token"
    "go/types"
    "path/filepath"
    "strings"
    )

    func defaultImporter() types.Importer {
    return importer.Default()
    }

    // prefixDirectory places the directory name on the beginning of each name in the list.
    func prefixDirectory(directory string, names []string) []string {
    if directory == "." {
    return names
    }

    ret := make([]string, len(names))
    for i, name := range names {
    ret[i] = filepath.Join(directory, name)
    }

    return ret
    }

    // File holds a single parsed file and associated data.
    type File struct {
    pkg *Package
    // Parsed AST.
    file *ast.File
    }

    type Package struct {
    dir string
    name string
    files []*File
    // objects defined in the AST.
    defs map[*ast.Ident]types.Object
    typesPkg *types.Package
    }

    // check type-checks the package. The package must be OK to proceed.
    func (p *Package) check(fs *token.FileSet, astFiles []*ast.File) error {
    p.defs = make(map[*ast.Ident]types.Object)

    config := types.Config{Importer: defaultImporter(), FakeImportC: true}
    info := &types.Info{
    Defs: p.defs,
    }

    typesPkg, err := config.Check(p.dir, fs, astFiles, info)
    if err != nil {
    return err
    }

    p.typesPkg = typesPkg
    return nil
    }

    // ParsePackageDir parses the package residing in the directory.
    func ParsePackageDir(d string) (*Package, error) {
    pkg, err := build.Default.ImportDir(d, 0)
    if err != nil {
    return nil, fmt.Errorf("cannot process directory %s: %s", d, err)
    }

    var names []string

    names = append(names, pkg.GoFiles...)
    names = prefixDirectory(d, names)

    return parsePackage(d, names, nil)
    }

    // parsePackage analyzes the single package constructed from the named files.
    // If text is non-nil, it is a string to be used instead of the content of the file,
    // to be used for testing. parsePackage exits if there is an error.
    func parsePackage(directory string, names []string, text interface{}) (*Package, error) {
    var (
    pkg Package
    astFiles []*ast.File
    )

    fs := token.NewFileSet()

    for _, name := range names {
    if !strings.HasSuffix(name, ".go") {
    continue
    }

    parsedFile, err := parser.ParseFile(fs, name, text, 0)
    if err != nil {
    return nil, err
    }

    astFiles = append(astFiles, parsedFile)
    pkg.files = append(pkg.files, &File{
    file: parsedFile,
    pkg: &pkg,
    })
    }

    if len(astFiles) == 0 {
    return nil, fmt.Errorf("%s: no buildable Go files", directory)
    }

    pkg.name = astFiles[0].Name.Name
    pkg.dir = directory

    // Type check the package.
    err := pkg.check(fs, astFiles)
    if err != nil {
    return nil, err
    }

    return &pkg, nil
    }
    134 changes: 134 additions & 0 deletions main.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,134 @@
    package main

    import (
    "bytes"
    "flag"
    "fmt"
    "go/format"
    "go/types"
    "io/ioutil"
    "log"
    "text/template"
    )

    func init() {
    log.SetFlags(0)
    log.SetPrefix("nats-rpc: ")
    }

    func main() {
    var (
    typeName string
    fileName string
    cliFileName string
    serviceGroup string
    subjectPrefix string
    )

    flag.StringVar(&typeName, "type", "", "Type name.")
    flag.StringVar(&fileName, "client", "", "Output file name client interface.")
    flag.StringVar(&cliFileName, "cli", "", "Output file name for CLI.")
    flag.StringVar(&serviceGroup, "group", "", "Name of the NATS queue group.")
    flag.StringVar(&subjectPrefix, "prefix", "", "Prefix to all subjects.")

    flag.Parse()

    if typeName == "" {
    log.Fatal("type required")
    }

    if fileName == "" {
    log.Fatal("file name required")
    }

    if cliFileName == "" {
    log.Fatal("cli file name required")
    }

    args := flag.Args()

    // Default to current directory.
    if len(args) == 0 {
    args = []string{"."}
    }

    pkg, err := ParsePackageDir(args[0])
    if err != nil {
    log.Fatal(err)
    }

    var (
    ok bool
    obj types.Object
    inf *types.Interface
    )

    for _, obj = range pkg.defs {
    if obj == nil {
    continue
    }

    // Ignore objects that don't have the target name.
    if obj.Name() != typeName {
    continue
    }

    // Looking for an interface type..
    inf, ok = obj.Type().Underlying().(*types.Interface)
    if !ok {
    continue
    }

    break
    }

    meta := reflectInterface(inf)
    meta.Name = typeName
    meta.Pkg = obj.Pkg().Name()

    if serviceGroup == "" {
    serviceGroup = fmt.Sprintf("%#v", fmt.Sprintf("%s.svc", meta.Pkg))
    }

    for _, m := range meta.Methods {
    m.Pkg = meta.Pkg
    m.Topic = fmt.Sprintf("%#v", fmt.Sprintf("%s%s.%s", subjectPrefix, meta.Pkg, m.Name))
    m.ServiceGroup = serviceGroup
    }

    // Compile and generate files.
    var buf bytes.Buffer

    t := template.Must(template.New("client").Parse(fileTmpl))
    if err := t.Execute(&buf, meta); err != nil {
    log.Fatal(err)
    }

    // Format the output.
    src, err := format.Source(buf.Bytes())
    if err != nil {
    log.Fatal(err)
    }

    if err = ioutil.WriteFile(fileName, src, 0644); err != nil {
    log.Fatalf("writing output: %s", err)
    }

    // Reuse buffer.
    buf.Reset()

    t = template.Must(template.New("cli").Parse(cliTmpl))
    if err := t.Execute(&buf, meta); err != nil {
    log.Fatal(err)
    }

    // Format the output.
    src, err = format.Source(buf.Bytes())
    if err != nil {
    log.Fatal(err)
    }

    if err = ioutil.WriteFile(cliFileName, src, 0644); err != nil {
    log.Fatalf("writing output: %s", err)
    }
    }
    95 changes: 95 additions & 0 deletions meta.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,95 @@
    package main

    import (
    "go/types"
    "log"
    )

    type Interface struct {
    Pkg string
    Name string
    Methods []*Method
    }

    type Method struct {
    Pkg string
    Name string
    Topic string
    Request *Var
    Response *Var
    ServiceGroup string

    ins []*Var
    outs []*Var
    }

    type Var struct {
    Pkg string
    Type string
    Ptr bool
    }

    func reflectInterface(iface *types.Interface) *Interface {
    var x Interface

    // Method count.
    nm := iface.NumMethods()

    x.Methods = make([]*Method, nm)

    for i := 0; i < nm; i++ {
    m := iface.Method(i)
    x.Methods[i] = reflectMethod(m)
    }

    return &x
    }

    func reflectMethod(m *types.Func) *Method {
    sig := m.Type().(*types.Signature)
    params := sig.Params()
    results := sig.Results()

    if params.Len() != 2 {
    log.Fatalf("expected 2 params, got %d", params.Len())
    }

    if results.Len() != 2 {
    log.Fatalf("expected 2 results, got %d", results.Len())
    }

    x := Method{
    Name: m.Name(),
    }

    x.Request = reflectVar(params.At(1))
    x.Response = reflectVar(results.At(0))

    return &x
    }

    func reflectVar(v *types.Var) *Var {
    var x Var

    t := v.Type()

    switch u := t.(type) {
    case *types.Named:
    o := u.Obj()
    x.Type = o.Name()
    p := o.Pkg()
    if p != nil {
    x.Pkg = p.Name()
    }
    case *types.Pointer:
    x.Ptr = true
    o := u.Elem().(*types.Named).Obj()
    x.Type = o.Name()
    p := o.Pkg()
    if p != nil {
    x.Pkg = p.Name()
    }
    }

    return &x
    }
    194 changes: 194 additions & 0 deletions tmpl.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,194 @@
    package main

    var fileTmpl = `// Generated by nats-rpc. DO NOT EDIT.
    package {{ .Pkg }}
    import (
    "context"
    "os"
    "os/signal"
    "syscall"
    "github.com/golang/protobuf/proto"
    "github.research.chop.edu/libi/transport"
    )
    var (
    traceIdKey = struct{}{}
    )
    type Client interface {
    {{ .Name }}
    }
    type client struct {
    tp transport.Transport
    }
    {{ range .Methods }}
    func (c *client) {{ .Name }}(ctx context.Context, req *{{ .Request.Type }}) (*{{ .Response.Type }}, error) {
    var rep {{ .Response.Type }}
    _, err := c.tp.Request({{ .Topic }}, req, &rep)
    if err != nil {
    return nil, err
    }
    return &rep, nil
    }
    {{ end }}
    func NewClient(tp transport.Transport) Client {
    return &client{tp}
    }
    func Serve(ctx context.Context, tp transport.Transport, svc Service) error {
    ctx, cancel := context.WithCancel(ctx)
    defer func() {
    cancel()
    }()
    var err error
    {{ range .Methods }}
    _, err = tp.Subscribe({{ .Topic }}, func(msg *transport.Message) (proto.Message, error) {
    ctx := context.WithValue(ctx, traceIdKey, msg.Id)
    var req {{ .Request.Type }}
    if err := msg.Decode(&req); err != nil {
    return nil, err
    }
    return svc.{{ .Name }}(ctx, &req)
    }, transport.SubscribeQueue({{ .ServiceGroup }}))
    if err != nil {
    return err
    }
    {{ end }}
    sigchan := make(chan os.Signal)
    signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM)
    <-sigchan
    return nil
    }
    `

    var cliTmpl = `// Generated by nats-rpc. DO NOT EDIT.
    package main
    import (
    "bytes"
    "context"
    "flag"
    "fmt"
    "os"
    "github.research.chop.edu/libi/log"
    "github.research.chop.edu/libi/{{ .Pkg }}"
    "github.research.chop.edu/libi/transport"
    "go.uber.org/zap"
    "github.com/golang/protobuf/proto"
    "github.com/golang/protobuf/jsonpb"
    "github.com/nats-io/go-nats"
    )
    const (
    clientType = "{{ .Pkg }}-cli"
    )
    var (
    buildVersion string
    traceIdKey = struct{}{}
    jsonMarshaler = &jsonpb.Marshaler{
    EmitDefaults: true,
    }
    jsonUnmarshaler = &jsonpb.Unmarshaler{}
    )
    func main() {
    var (
    natsAddr string
    printVersion bool
    )
    flag.StringVar(&natsAddr, "nats.addr", "nats://127.0.0.1:4222", "NATS address.")
    flag.BoolVar(&printVersion, "version", false, "Print version.")
    flag.Parse()
    if printVersion {
    fmt.Fprintln(os.Stdout, buildVersion)
    return
    }
    // Get method.
    args := flag.Args()
    if len(args) == 0 {
    log.Fatalf("method name required")
    }
    meth := args[0]
    // Initialize base logger.
    logger, err := log.New()
    if err != nil {
    log.Fatal(err)
    }
    logger = logger.With(
    zap.String("client.type", clientType),
    zap.String("client.version", buildVersion),
    )
    // Initialize the transport layer.
    tp, err := transport.Connect(&nats.Options{
    Url: natsAddr,
    })
    if err != nil {
    log.Fatal(err)
    }
    defer tp.Close()
    tp.SetLogger(logger)
    inp := "{}"
    if len(args) > 1 {
    inp = args[1]
    }
    inpr := bytes.NewBufferString(inp)
    client := {{ .Pkg }}.NewClient(tp)
    var rep proto.Message
    ctx := context.Background()
    switch meth { {{ range .Methods }}
    case "{{ .Name }}":
    var req {{ .Pkg }}.{{ .Request.Type }}
    if err := jsonUnmarshaler.Unmarshal(inpr, &req); err != nil {
    log.Fatalf("json: %s", err)
    }
    rep, err = client.{{ .Name }}(ctx, &req)
    {{ end }}
    default:
    log.Fatalf("unknown method %s", meth)
    }
    if err != nil {
    log.Fatalf("rpc error: %s", err)
    }
    if err := jsonMarshaler.Marshal(os.Stdout, rep); err != nil {
    log.Fatalf("error encoding response: %s", err)
    }
    fmt.Fprint(os.Stdout, "\n")
    }
    `