Goのモジュールを個別に更新する

特に問題なければ go get -u ./... で全て更新してしまうのが楽ですが、更新できないモジュールがある場合は個別に更新する必要があります。

go list -m -u all で全モジュールとその更新有無を確認できるため、 -f でgo getコマンドを出力するようにし、必要なものだけ実行すると楽です。

例えばgoplsはそのまま実行すると以下のようになりますが、

$ go list -m -u all
golang.org/x/tools/gopls
github.com/BurntSushi/toml v0.4.1
github.com/davecgh/go-spew v1.1.1
github.com/google/go-cmp v0.5.6
github.com/google/safehtml v0.0.2
github.com/jba/templatecheck v0.6.0
github.com/kr/pretty v0.1.0 [v0.3.0]
github.com/kr/pty v1.1.1 [v1.1.8]
github.com/kr/text v0.1.0 [v0.2.0]
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e
github.com/pmezard/go-difflib v1.0.0
github.com/rogpeppe/go-internal v1.8.0
github.com/sanity-io/litter v1.5.1
github.com/sergi/go-diff v1.1.0 [v1.2.0]
github.com/stretchr/objx v0.1.0 [v0.3.0]
github.com/stretchr/testify v1.4.0 [v1.7.0]
github.com/yuin/goldmark v1.4.1 [v1.4.4]
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 [v0.0.0-20211108221036-ceb1ce70b4fa]
golang.org/x/mod v0.5.1
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f [v0.0.0-20211112202133-69e39bad7dc2]
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 [v0.0.0-20211113001501-0c823b97ae02]
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 [v0.0.0-20210927222741-03fcf44c2211]
golang.org/x/text v0.3.7
golang.org/x/tools v0.1.7 => ../
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 [v1.0.0-20201130134442-10cb98267c6c]
gopkg.in/errgo.v2 v2.1.0
gopkg.in/yaml.v2 v2.2.4 [v2.4.0]
honnef.co/go/tools v0.2.1 [v0.2.2]
mvdan.cc/gofumpt v0.1.1 [v0.2.0]
mvdan.cc/xurls/v2 v2.3.0

更新があり、直接使っている(not .Indirect)モジュールのgo getコマンドを生成すると以下のようになります。

$ go list -m -u -f '{{if (and .Update (not .Indirect))}}go get {{.Path}}@{{.Update.Version}}{{end}}' all
go get github.com/sergi/go-diff@v1.2.0
go get golang.org/x/sys@v0.0.0-20211113001501-0c823b97ae02
go get honnef.co/go/tools@v0.2.2
go get mvdan.cc/gofumpt@v0.2.0

aws-sdk-go-v2をモックせずにテストする

テストでaws-sdk-go-v2を使う場合はドキュメントにある通り、Clientのモックを用意するのが一般的な手法かと思います。
ただテストのためだけにinterfaceを書きたくないので、aws-sdk-go-v2が提供するClientをそのまま使える形にしたいです。

幸いaws-sdk-go-v2はClientをカスタマイズするためのオプションがあるため、大別して以下の2つの方法で実現可能です。

1つ目はAPIリクエストの送信先を変更する方法です。
こちらはWithEndpointResolverWithHTTPClientを用いることで、リクエストをhttptestで立ち上げたサーバーなど、任意の宛先に送信できます。

2つ目はClientの処理に任意の処理を割り込ませる方法です。
各Clientは下図のStackが実装されており、WithAPIOptionsで任意の処理を追加できるようになっています。 middleware
(詳細はイメージのリンク先へ)

通常はStackを順番に処理していくようになっていますが、途中で次を呼ばずに打ち切ってしまうこともできます。

例えばs3のGetObjectは以下のように呼ぶことでAWSにアクセスせずに"ok"を返せます。

input := &s3.GetObjectInput{
  Bucket: aws.String("bucket"),
  Key:    aws.String("key"),
}
output, err := client.GetObject(ctx, input, s3.WithAPIOptions(func(stack *middleware.Stack) error {
  return stack.Finalize.Add(
    middleware.FinalizeMiddlewareFunc("test",
      func(context.Context, middleware.FinalizeInput, middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
        return middleware.FinalizeOutput{
          Result: &s3.GetObjectOutput{
            Body: io.NopCloser(strings.NewReader("ok")),
          },
        }, middleware.Metadata{}, nil
      },
    ),
    middleware.Before,
  )
}))

※ s3はFinalizeにリトライ処理があるため、それが呼ばれる前に処理を打ち切ることで数秒のロスを回避できる

しかし、実際のプロダクションコードだと個別のメソッドにオプションを渡すのは難しい形になっているかもしれません。
その際はClientをDIできるようにしておき、config.WithAPIOptionsを使ってClient側にオプションを設定します。

また、WithAPIOptionsはコードがそこそこ大きいのでfunc(*middleware.Stack) errorを返す関数を作成し、応答を渡せるようにしておくと使いやすいです。

以下がテストコードのサンプルです。

package main

import (
    "bytes"
    "context"
    "errors"
    "io"
    "strings"
    "testing"

    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/service/s3"
    "github.com/aws/smithy-go/middleware"
)

type resp struct {
    body string
    err  error
}

func middlewareForGetObject(r resp) func(*middleware.Stack) error {
    return func(stack *middleware.Stack) error {
        return stack.Finalize.Add(
            middleware.FinalizeMiddlewareFunc(
                "test",
                func(context.Context, middleware.FinalizeInput, middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
                    return middleware.FinalizeOutput{
                        Result: &s3.GetObjectOutput{
                            Body: io.NopCloser(strings.NewReader(r.body)),
                        },
                    }, middleware.Metadata{}, r.err
                },
            ),
            middleware.Before,
        )
    }
}

func Test_GetObject(t *testing.T) {
    type args struct {
        bucket string
        key    string
    }
    tests := []struct {
        name    string
        args    args
        resp    resp
        want    []byte
        wantErr bool
    }{
        {
            name: "success",
            args: args{bucket: "Bucket", key: "Key"},
            resp: resp{body: "ok"},
            want: []byte("ok"),
        },
        {
            name:    "failure",
            args:    args{bucket: "Bucket", key: "Key"},
            resp:    resp{err: errors.New("object not found")},
            wantErr: true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            ctx := context.TODO()

            cfg, err := config.LoadDefaultConfig(ctx,
                config.WithRegion("ap-northeast-1"),
                config.WithAPIOptions([]func(*middleware.Stack) error{middlewareForGetObject(tt.resp)}),
            )
            if err != nil {
                t.Fatal(err)
            }
            client := s3.NewFromConfig(cfg)

            out, err := client.GetObject(ctx, &s3.GetObjectInput{Bucket: &tt.args.bucket, Key: &tt.args.key})
            if (err != nil) != tt.wantErr {
                t.Errorf("GetObject() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if tt.wantErr {
                return
            }

            got, _ := io.ReadAll(out.Body)
            if !bytes.Equal(tt.want, got) {
                t.Errorf("GetObject() = %q, want %q", got, tt.want)
            }
        })
    }
}

xorm.EngineGroupで複数のDBをまとめる

データベースをWriterとReaderで分けている時、xorm.ioのEngineGroupを使うことでどちらのDBを使うのかを任せられるようになります。

使い方は通常のNewEngineと似ていて、NewEngineGroupの第2引数に接続先をスライスで渡します。
接続先のスライスは1番目がWriter、2番目以降がReaderになります。

xorm.NewEngineGroup(driverName, []string{dsnWriter, dsnReader})

もしくは既存のEngineを直接渡すこともできます。
この時は第1引数がWriterで第2引数がReaderのスライスです。

xorm.NewEngineGroup(engineWriter, []*xorm.Engine{engineReader})

これで自動的に参照クエリはReader、更新クエリはWriterが使われるようになります。
また、トランザクションの中では参照クエリでもWriterが使われます。

  • Readerが使われる
func get(eg *xorm.EngineGroup, bean interface{}) {
    s := eg.NewSession()
    s.Get(bean)
}
  • Writerが使われる
func insert(eg *xorm.EngineGroup, bean interface{}) {
    s := eg.NewSession()
    s.Insert(bean)
}
func getTx(eg *xorm.EngineGroup, bean interface{}) {
    s := eg.NewSession()
    s.Begin()
    s.Get(bean)
    s.Commit()
}

この動作は以下のコードを使用して確認しました。
(実際のDBに接続していないのでエラーは無視しています)

play.golang.org

ただGoだとxorm以外に同じようなことができるORMを知らないんですが、やっぱりアプリケーションの中でやるよりも外でやった方が良いから他では実装されないんですかね。

goplsと静的解析を活用して変更の影響範囲を調べたい

1000パッケージ弱あるような巨大なリポジトリだと、関数1つの修正でどこまで影響があるのかを調べるのが結構大変*1だったりする。

Vimプラグインを作ったり、goplsを魔改造してみたりしてみたものの、使う人や環境を選ぶし、実行速度もイマイチだったのでもっと使い勝手の良いものが欲しかった。

そこでPull Requestに対して自動的にチェックしてくれると便利そうだったので以下のようなツールを考えてみた。

  1. git diffの結果から、
  2. 変更のあったシンボルの位置を特定し、
  3. (決められたところまで)呼び出し元を辿る

うまく実装できれば公開するかもしれない。

3は前回のblogに書いた方法でLSPのcallHierarchy/incomingCallsを繰り返していけば良いので処理的には比較的簡単。

daisuzu.hatenablog.com

2のシンボルの位置はもっと簡単な方法がありそうな気もしたけど、ast.FileDeclsを使って変更行が範囲内かどうかを調べれば良さそう。

package a

func sum(a, b int) int {
    total := a + b
    return total
}

このsum関数の場合は以下のようになっているので、

Decls: []ast.Decl (len = 1) {
.  0: *ast.FuncDecl {
.  .  Name: *ast.Ident {
.  .  .  NamePos: a.go:3:6
.  .  .  Name: "sum"
.  .  .  Obj: *ast.Object {
.  .  .  .  Kind: func
.  .  .  .  Name: "sum"
.  .  .  .  Decl: *(obj @ 7)
.  .  .  }
.  .  }
.  .  Type: *ast.FuncType {
.  .  .  Func: a.go:3:1
.  .  .  Params: *ast.FieldList {
.  .  .  .  Opening: a.go:3:9
.  .  .  .  List: []*ast.Field (len = 1) {
.  .  .  .  .  0: *ast.Field {
.  .  .  .  .  .  Names: []*ast.Ident (len = 2) {
.  .  .  .  .  .  .  0: *ast.Ident {
.  .  .  .  .  .  .  .  NamePos: a.go:3:10
.  .  .  .  .  .  .  .  Name: "a"
.  .  .  .  .  .  .  .  Obj: *ast.Object {
.  .  .  .  .  .  .  .  .  Kind: var
.  .  .  .  .  .  .  .  .  Name: "a"
.  .  .  .  .  .  .  .  .  Decl: *(obj @ 22)
.  .  .  .  .  .  .  .  }
.  .  .  .  .  .  .  }
.  .  .  .  .  .  .  1: *ast.Ident {
.  .  .  .  .  .  .  .  NamePos: a.go:3:13
.  .  .  .  .  .  .  .  Name: "b"
.  .  .  .  .  .  .  .  Obj: *ast.Object {
.  .  .  .  .  .  .  .  .  Kind: var
.  .  .  .  .  .  .  .  .  Name: "b"
.  .  .  .  .  .  .  .  .  Decl: *(obj @ 22)
.  .  .  .  .  .  .  .  }
.  .  .  .  .  .  .  }
.  .  .  .  .  .  }
.  .  .  .  .  .  Type: *ast.Ident {
.  .  .  .  .  .  .  NamePos: a.go:3:15
.  .  .  .  .  .  .  Name: "int"
.  .  .  .  .  .  }
.  .  .  .  .  }
.  .  .  .  }
.  .  .  .  Closing: a.go:3:18
.  .  .  }
.  .  .  Results: *ast.FieldList {
.  .  .  .  Opening: -
.  .  .  .  List: []*ast.Field (len = 1) {
.  .  .  .  .  0: *ast.Field {
.  .  .  .  .  .  Type: *ast.Ident {
.  .  .  .  .  .  .  NamePos: a.go:3:20
.  .  .  .  .  .  .  Name: "int"
.  .  .  .  .  .  }
.  .  .  .  .  }
.  .  .  .  }
.  .  .  .  Closing: -
.  .  .  }
.  .  }
.  .  Body: *ast.BlockStmt {
.  .  .  Lbrace: a.go:3:24
.  .  .  List: []ast.Stmt (len = 2) {
.  .  .  .  0: *ast.AssignStmt {
.  .  .  .  .  Lhs: []ast.Expr (len = 1) {
.  .  .  .  .  .  0: *ast.Ident {
.  .  .  .  .  .  .  NamePos: a.go:4:2
.  .  .  .  .  .  .  Name: "total"
.  .  .  .  .  .  .  Obj: *ast.Object {
.  .  .  .  .  .  .  .  Kind: var
.  .  .  .  .  .  .  .  Name: "total"
.  .  .  .  .  .  .  .  Decl: *(obj @ 67)
.  .  .  .  .  .  .  }
.  .  .  .  .  .  }
.  .  .  .  .  }
.  .  .  .  .  TokPos: a.go:4:8
.  .  .  .  .  Tok: :=
.  .  .  .  .  Rhs: []ast.Expr (len = 1) {
.  .  .  .  .  .  0: *ast.BinaryExpr {
.  .  .  .  .  .  .  X: *ast.Ident {
.  .  .  .  .  .  .  .  NamePos: a.go:4:11
.  .  .  .  .  .  .  .  Name: "a"
.  .  .  .  .  .  .  .  Obj: *(obj @ 27)
.  .  .  .  .  .  .  }
.  .  .  .  .  .  .  OpPos: a.go:4:13
.  .  .  .  .  .  .  Op: +
.  .  .  .  .  .  .  Y: *ast.Ident {
.  .  .  .  .  .  .  .  NamePos: a.go:4:15
.  .  .  .  .  .  .  .  Name: "b"
.  .  .  .  .  .  .  .  Obj: *(obj @ 36)
.  .  .  .  .  .  .  }
.  .  .  .  .  .  }
.  .  .  .  .  }
.  .  .  .  }
.  .  .  .  1: *ast.ReturnStmt {
.  .  .  .  .  Return: a.go:5:2
.  .  .  .  .  Results: []ast.Expr (len = 1) {
.  .  .  .  .  .  0: *ast.Ident {
.  .  .  .  .  .  .  NamePos: a.go:5:9
.  .  .  .  .  .  .  Name: "total"
.  .  .  .  .  .  .  Obj: *(obj @ 72)
.  .  .  .  .  .  }
.  .  .  .  .  }
.  .  .  .  }
.  .  .  }
.  .  .  Rbrace: a.go:6:1
.  .  }
.  }
}

token.FileSetPosition()Pos()End()を渡せば行番号を取得できる。

1のgit diffは良い方法が思いつかなかったので標準出力をパースしてみる。
余計な情報は減らしておきたいので--diff-filter=Mで変更のあったファイルのみを対象にし、-U0で変わった行だけを出力する。

diff --git a/a.go b/a.go
index 2d1b2ea..4cceff6 100644
--- a/a.go
+++ b/a.go
@@ -4,2 +4 @@ func sum(a, b int) int {
-       total := a + b
-       return total
+       return a + b

最低限必要なのは、

  • bのファイル名(a.go)
  • @@の行の+の後ろにある数字(4)
  • @@以降で+から始まる行がいくつあるか(1つ)

の3つ。
ファイルをparser.ParseFileで開いたら行番号を起点として変更行の数だけシンボルの位置(このケースだとa.go:3:6のみ)を探していく。

ただ、これだけだと関数の位置を移動しただけでも影響があることになってしまうのでこういった部分は除外したい。
そして完全に新規で追加されたシンボルはきっとどこかで使われているはずなのでこれも除外したい。

また、構造体やgoroutineの呼び出しはcallHierarchyが使えないのでtextDocument/referencesや2を使って関数の位置を調べる必要がある。

あとは結果をどう見せるのかも悩ましいところ。
GitHubでリンクになっているのが良いかもしれないし、さらに他のツールと連携することを考えるとgo vetのような形式やJSONの方が扱いやすいかもしれない。

*1:アーキテクチャがおかしいのかもしれないけど、実際そうなってしまっているので。。。

goplsのdaemonモードを使う

goplsgopls -listen=<addr>で実行するとdaemonモードで起動し、指定した<addr>TCP接続できるようになる。

github.com

クライアントはgoplsを使っても良いし、独自に実装することも可能。
その場合、TCP上で以下のような形式のJSON-RPCを送受信すれば良い。
(改行は\r\n)

Content-Length: <JSON部分のbyte数>

{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}
  • レスポンス
Content-Length: <JSON部分のbyte数>

{"jsonrpc":"2.0","result":{},"id":1}

接続ごとに 'initialize''initialized' を送信して初期化したら、あとは 'textDocument/references''callHierarchy/incomingCalls' など、呼びたいメソッドを呼べばOK。

毎回初期化しなくて済むのと、レスポンスをJSONとして扱えるので複雑なことがやりやすくなるはず。

クライアントのサンプル実装はこちら。

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net"
    "sync/atomic"
)

type Client struct {
    id   int64
    conn net.Conn
}

func Connect(addr string, initializedParams map[string]interface{}) (*Client, error) {
    conn, err := net.Dial("tcp", addr)
    if err != nil {
        return nil, err
    }
    client := &Client{conn: conn}

    if _, err := client.Call("initialize", initializedParams); err != nil {
        return nil, err
    }
    if _, err := client.Call("initialized", map[string]interface{}{}); err != nil {
        return nil, err
    }

    return client, nil
}

type response struct {
    ID     int64           `json:"id"`
    Result json.RawMessage `json:"result"`
}

func (c *Client) Call(method string, params interface{}) (*json.RawMessage, error) {
    id := atomic.AddInt64(&c.id, 1)

    data, err := json.Marshal(map[string]interface{}{
        "jsonrpc": "2.0",
        "method":  method,
        "params":  params,
        "id":      id,
    })
    if err != nil {
        return nil, err
    }

    if _, err := fmt.Fprintf(c.conn, "Content-Length: %v\r\n\r\n%s", len(data), data); err != nil {
        return nil, err
    }

    for {
        // Content-Lengthまで読む
        buf := make([]byte, 40)
        n, err := c.conn.Read(buf)
        if err != nil {
            return nil, err
        }

        r := bytes.NewBuffer(buf[:n])
        var length int
        if _, err := fmt.Fscanf(r, "Content-Length: %d\r\n\r\n", &length); err != nil {
            continue
        }

        // bufに入りきらなかったBodyを読む
        body := make([]byte, length)
        idx := copy(body, r.Bytes())
        if _, err = c.conn.Read(body[idx:]); err != nil {
            return nil, err
        }

        var res response
        if err := json.Unmarshal(body, &res); err != nil {
            return nil, err
        }
        if res.ID != id {
            // 送信したリクエストに対するレスポンス以外は無視
            // (goplsからの通知を含む)
            continue
        }
        return &res.Result, nil
    }
}

func (c *Client) Shutdown() error {
    _, err := c.Call("shutdown", nil)
    return err
}

xorm更新用のテストを静的解析で生成した時のメモ

前に書いた↓を使ってxormのバージョンを上げようと思っていたけど、生成されるSQLがおかしくなることがあると聞いたのでさらにテストを拡充することにした。

daisuzu.hatenablog.com

といってもまだ完成していないので、ここまでやったことを備忘録*1として残しておいて続きは連休明けにやる予定。

方針としてはxormがDBのドライバーを呼んだ際のクエリを記録・比較できるようにし、それを実行するテストを自動生成するというもの。

なぜそうしたかというと、理由は主に次の2点。

  • できるだけ短時間でテストを追加したい
    • 対象となるテストが膨大なので一つずつ書いていられない
      • 不要なテストを精査することも厳しい
  • 実際のDBに接続したくない
    • 他のテストの影響を受けたくないし、与えたくない
    • CIで時間がかかるようになってしまうのも困る

ということで以下にだらだらと書いていく。

1. ダミードライバーを作る

DBはMySQLを使っていて、ドライバーは固定されていたのでまずはこれを変えられるようにする。

package database

func NewORM() ORM { // ORMはinterface
    // 略
    e, err := xorm.NewEngine("mysql", dsn)
    // 略
}

既存への影響を最小限にしたかったのでビルドタグで切り替えることにした。

driver.go

// +build !dummy
//go:build !dummy

package database

const driverName = "mysql"

driver_dummy.go

// +build dummy
//go:build dummy

package database

const driverName = "dummy"

ダミーの方は代わりとなるドライバーも実装しておく。
テストはrepositoryレイヤのメソッド単位にするつもりなのでPrepareExecorQueryの引数を記録したらその時点で処理は終わらせてしまう。
(できれば返ってくるエラーが一致するかもチェックしたいが難しそう)

func init() {
    sql.Register(driverName, dummyDriver{})
    core.RegisterDriver(driverName, dummyDriver{})
}

// databse/sql用
type dummyDriver struct{}

// https://pkg.go.dev/database/sql/driver#Driver
func (dummyDriver) Open(dsn string) (driver.Conn, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#DriverContext
func (dummyDriver) OpenConnector(dsn string) (driver.Connector, error) {
    return connector{}, nil
}

// https://pkg.go.dev/database/sql/driver#Connector
type connector struct{}

func (connector) Connect(ctx context.Context) (driver.Conn, error) {
    return conn{}, nil
}

func (connector) Driver() driver.Driver {
    return dummyDriver{}
}

// https://pkg.go.dev/database/sql/driver#Conn
type conn struct{}

func (conn) Prepare(query string) (driver.Stmt, error) {
    return &stmt{
        numInput: strings.Count(query, "?"),
        q:        dbtest.NewQueryLog(query),
    }, nil
}

func (conn) Close() error {
    return nil
}

func (conn) Begin() (driver.Tx, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#Stmt
type stmt struct {
    numInput int
    q        *QueryLog
}

func (s *stmt) Close() error {
    return nil
}

func (s *stmt) NumInput() int {
    return s.numInput
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}

// https://pkg.go.dev/github.com/go-xorm/core#Driver
func (dummyDriver) Parse(string, string) (*core.Uri, error) {
    return &core.Uri{DbType: core.DbType("mysql")}, nil
}

2. テストコードを考える

以下のようなテンプレートを考えた。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        _, ... = repo.{{.MethodName}}(ctx, ...)
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

これに合わせてSQLの記録と比較をするためのdbtestパッケージを作る。

package dbtest

import (
    "database/sql/driver"
    "encoding/json"
    "os"
    "path/filepath"
    "reflect"
    "testing"

    "github.com/google/go-cmp/cmp"
)

var currentTest string

// サブテストの最初に呼ぶ
func RegisterTestName(t *testing.T) {
    currentTest = t.Name()
}

type QueryLog struct {
    Query string
    Args  []driver.Value
}

var queryLogs = map[string]*QueryLog{}

// サブテストごとにクエリを記録する
func NewQueryLog(query string) *QueryLog {
    q := &QueryLog{Query: query}
    queryLogs[currentTest] = q
    return q
}

func (q *QueryLog) SetArgs(args []driver.Value) {
    q.Args = args
}

// 記録したクエリをgoldenファイルと比較する
func CheckSQL(t *testing.T, opts ...cmp.Option) {
    t.Helper()

    got := queryLogs[t.Name()]

    if *updateGolden {
        writeGolden(t, got)
        return
    }

    // gotはint64、goldenはfloat64になるので常にfloat64で比較する
    // https://daisuzu.hatenablog.com/entry/2021/01/08/145459
    opts = append(opts, cmp.FilterValues(func(x, y interface{}) bool {
        return isNumber(x) && isNumber(y)
    }, cmp.Comparer(func(x, y interface{}) bool {
        return cmp.Equal(toFloat64(x), toFloat64(y))
    })))

    if diff := cmp.Diff(readGolden(t), got, opts...); diff != "" {
        t.Errorf("SQL mismatch (-want +got):\n%s", diff)
    }
}

3. 静的解析ツールを作る

あとはテンプレートに必要な情報を集めればOK。

repositoryレイヤは以下のようになっているため、

type XXXRepository struct {
    repo.RootRepository
}

func NewXXXRepository(ctx context.Context) XXXRepository {
    // 略
}

func (r XXXRepository) Method()
  1. フィールドにRootRepositoryがある型を探す
  2. 1.の型を返す関数(コンストラクタ)を探す
  3. 1.の型のメソッドを探す

の順に解析していけば必要な情報が揃えられる。

3-1. 対象のrepositoryを探す

型や関数が定義されているファイルや行の位置によってうまく解析できないと困るので1〜3は個別に収集していくことにした。

まずは1のrepository本体。
型の名前は後続のAnalyzerで使い、ファイル名をテストファイルのprefixにする。

var repoCollector = &analysis.Analyzer{
    Name:       "repocollector",
    Doc:        "collect repository definition",
    Run:        collectRepository,
    ResultType: reflect.TypeOf(new(Repository)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer},
}

type Repository struct {
    types map[string]string
}

func collectRepository(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

    nodeFilter := []ast.Node{
        (*ast.TypeSpec)(nil),
    }

    result := &Repository{types: make(map[string]string)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        ts := n.(*ast.TypeSpec)
        st, ok := ts.Type.(*ast.StructType)
        if !ok {
            return
        }
        if !hasRootRepository(st.Fields.List) {
            return
        }
        result.types[ts.Name.Name] = filepath.Base(pass.Fset.File(n.Pos()).Name())
    })

    return result, nil
}

3-2. repositoryのコンストラクタを探す

次はコンストラクタ(Newから始まる関数)。
関数名とParamsと、1つのものしか無かったが念のためResultsも全て保持しておく。

var newCollector = &analysis.Analyzer{
    Name:       "newcollector",
    Doc:        "collect constructor",
    Run:        collectConstructor,
    ResultType: reflect.TypeOf(new(Constructor)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer, repoCollector},
}

type Constructor struct {
    funcs map[string]*Func
}

type Func struct {
    name    string
    params  []*ast.Field
    results []*ast.Field
}

func collectConstructor(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    result := &Constructor{funcs: make(map[string]*Func)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Type.Results == nil {
            // 何も返さない関数はコンストラクタではない
            return
        }

        typeName, ok := repo.isConstructor(fd.Type.Results.List)
        if !ok {
            // collectRepository()で収集した型のみが対象
            return
        }

        result.funcs[typeName] = &Func{
            name:    fd.Name.Name,
            params:  fd.Type.Params.List,
            results: fd.Type.Results.List,
        }
    })

    return result, nil
}

3-3. メソッドを探す

最後はメソッド。
このAnalyzerでテストコードの生成も行う。

var analyzer = &analysis.Analyzer{
    Name:     "gensqltest",
    Doc:      "generate tests",
    Run:      run,
    Requires: []*analysis.Analyzer{inspect.Analyzer, repoCollector, newCollector},
}

type Cases []Case

type Case struct {
    PkgPath     string
    RepoName    string
    Constructor string
    Subtests    []Method
}

type Method struct {
    Name    string
    pkg     string
    params  []*ast.Field
    results []*ast.Field
}

func run(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)
    constructor := pass.ResultOf[newCollector].(*Constructor)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    var order []string // 生成するRepositoryの順番を固定するために使う
    tests := make(map[string][]Method)
    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Recv == nil {
            // メソッドじゃない関数は対象外
            return
        }

        if !token.IsExported(fd.Name.Name) {
            // 非公開メソッドは対象外
            return
        }

        typeName := strings.TrimPrefix(types.ExprString(fd.Recv.List[0].Type), "*")
        if ok := repo.isMethod(typeName); !ok {
            // 一致するレシーバがないものは対象外
            return
        }

        if _, ok := tests[typeName]; !ok {
            // メソッドが見つかった順にテスト関数を出力する
            order = append(order, typeName)
        }

        m := Method{Name: fd.Name.Name, pkg: pass.Pkg.Name()}
        if fd.Type.Params != nil {
            m.params = fd.Type.Params.List
        }
        if fd.Type.Results != nil {
            m.results = fd.Type.Results.List
        }
        tests[typeName] = append(tests[typeName], m)
    })

    if len(order) == 0 {
        return nil, nil
    }

    testFiles := make(map[string]Cases)
    for _, v := range order {
        code := constructor.genConstructorCode(pass.Pkg.Name(), v)
        if code == "" {
            continue
        }
        testFiles[repo.testFileName(v)] = append(testFiles[repo.testFileName(v)], Case{
            PkgPath:     pass.Pkg.Path(), // importに追加する
            RepoName:    v,
            Constructor: code,
            Subtests:    tests[v],
        })
    }

    t := template.Must(template.New("test").Parse(tpl))
    for k, v := range testFiles {
        var out bytes.Buffer
        if err := t.Execute(&out, v); err != nil {
            log.Println(err)
            continue
        }

        // goimportsをかける
        b, err := imports.Process(k, out.Bytes(), nil)
        if err != nil {
            log.Println(err)
            continue
        }

        if err := os.WriteFile(k, b, 0600); err != nil {
            log.Println(err)
        }
    }

    return nil, nil
}

repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

コンストラクタはイレギュラーがあった時に対応しやすいのでGo側で生成することにした。
基本的にはrepo := xxx.NewXXXRepository(ctx)repo := xxx.NewXXXRepository(orm)のどちらかになる。

_, ... = repo.{{.MethodName}}(ctx, ...)

また、こちらも同様にGo側で行全体を生成することにした。
(なんだかんだでかなり泥臭いコードになってしまったけど...)

func (m Method) Call() string {
    var b strings.Builder

    // 左辺を作る
    if len(m.results) > 0 {
        ret := strings.Repeat("_,", len(m.results))
        b.WriteString(ret[:len(ret)-1] + " = ")
    }

    // 右辺を作る(引数は適当な値を詰める)
    b.WriteString("repo." + m.Name + "(")
    args := make([]string, 0, len(m.params))
    for i, v := range m.params {
        switch t := v.Type.(type) {
        case *ast.Ident:
            switch t.Name {
            case "int", "int64", "float", "float64":
                for j := range v.Names {
                    // `(a, b int)` のようなケースへの対応
                    args = append(args, strconv.Itoa(i+j))
                }
            case "string":
                for _, vv := range v.Names {
                    args = append(args, strconv.Quote(vv.Name))
                }
            default:
                typ := types.ExprString(v.Type)
                if token.IsExported(typ) {
                    typ = m.pkg + "." + typ + "{}"
                }
                for range v.Names {
                    args = append(args, typ)
                }
            }
        case *ast.SelectorExpr:
            if t.Sel.Name == "Context" {
                args = append(args, "ctx")
            } else {
                for range v.Names {
                    args = append(args, types.ExprString(t)+"{}")
                }
            }
        case *ast.StarExpr:
            // 略
        case *ast.InterfaceType:
            // 略
        case *ast.MapType:
            // 略
        case *ast.ArrayType:
            // 略
        case *ast.Ellipsis:
            // 略
        }
    }
    b.WriteString(strings.Join(args, ",") + ")")

    return b.String()
}

なのでテンプレートは以下のようになった。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    {{.Constructor}}

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        {{.Call}}
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

4. TODO

これで1000件弱のサブテストを生成してみたところ、ビルドできなかったり実行時にpanicするのが10件ほどあった。
一旦はコメントアウト状態でコードを生成したりt.Skipを差し込むようにしているが、手動で直すなり分岐を追加してちゃんと通るコードにしないといけない。
(というのもあってコードはだいぶ省略している)

それから比較でFAILするのも10数件あったのでcmp.Optionを使うか、その他の方法でPASSするようにしないといけない。

そして新たに追加されるrepositoryやメソッドをどうするかもまだ考えていない。

*1:+チームメンバーに自分の思考を共有できればと思ってたり

みんなで書くGoのエンドポイントテスト

Webアプリケーションサーバーに何か大きな変更をしたいけど、既存のテストだと心許なかったので各エンドポイントにHandlerからのテストを追加することにした。

ただ全部のテストを自分1人で作っていくのはボリューム的に現実的ではなかったので、どうしたらチーム全員が書きやすいテストになるか考えて色々と整備してみた。

テストの書き方がある程度決まっている

エンドポイントごとにスタイルがバラバラだと都度どう書くか考えなければいけなくなってしまうため、基本的にはリクエストとレスポンスだけテーブルに指定するスタイルが良さそうだと考えた。

簡略化すると以下のような形式。

func TestFoo_Get(t *testing.T) {
    tests := []struct {
        name string
        // ヘッダやクエリパラメータなど
        // 期待するレスポンス
    }{
        // 実際のテストケース
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            r := httptest.NewRequest("GET", "/api/foo", nil)
            RunTest(t, r, tt.want)
        })
    }
}

しかし、ヘッダやクエリパラメータの有無によってはそれに応じた処理を書かないといけないので、それを吸収するための関数を用意することにした。

type RequestOption func(*http.Request)

func WithQuery(key, value string) RequestOption {
    return func(r *http.Request) {
        q := r.URL.Query()
        q.Set(key, value)
        r.URL.RawQuery = q.Encode()
    }
}

func WithHeader(key, value string) RequestOption {
    return func(r *http.Request) {
        r.Header.Set(key, value)
    }
}

func NewRequest(method, endpoint string, body io.Reader, options ...RequestOption) *http.Request {
    r := httptest.NewRequest(method, endpoint, body)
    for _, opt := range options {
        opt(r)
    }
    return r
}

また、POSTやPUTでJSONを送る場合は以下の関数でボディを作れるようにした。

func JSONBody(t *testing.T, m map[string]interface{}) io.Reader {
    t.Helper()

    body := new(bytes.Buffer)
    if err := json.NewEncoder(body).Encode(&m); err != nil {
        t.Fatal(err)
    }
    return body
}

期待する結果(want)を全て書かなくても良い

レスポンスはエンドポイントによってはかなり大きくなることもあり、毎回全体を書くのは大変そうだったので避けたかった。
そしてレスポンスが変わるたびに毎回手動で全て直さないといけないのも面倒なのでgoldenファイル化することにした。

値が固定されないところもあるので、そこはレスポンスを柔軟に書き換えられるようにしている。*1
例えばJSONが返ってくるエンドポイントであれば以下のような関数。*2

type ResponseFilter func(t *testing.T, r *http.Response)

func ModJSONFields(overwrite map[string]interface{}) ResponseFilter {
    return func(t *testing.T, r *http.Response) {
        t.Helper()

        var tmp map[string]interface{}
        if err := json.NewDecoder(r.Body).Decode(&tmp); err != nil {
            t.Fatal(err)
        }

        rewriteMap(t, tmp, overwrite)

        body := new(bytes.Buffer)
        if err := json.NewEncoder(body).Encode(&tmp); err != nil {
            t.Fatal(err)
        }
        r.Body = io.NopCloser(body)
    }
}

これでRunTestは以下のようになる。

var (
    handler http.Handler

    updateGolden = flag.Bool("golden", false, "Update golden files")
)

func RunTest(t *testing.T, r *http.Request, want int, filters ...ResponseFilter) {
    t.Helper()

    w := httptest.NewRecorder()
    handler.ServeHTTP(w, r)

    got := w.Result()
    if got.StatusCode != want {
        t.Errorf("HTTP StatusCode = %d, want %d", got.StatusCode, want)
    }

    for _, f := range filters {
        f(t, got)
    }

    dump, err := httputil.DumpResponse(got, true)
    if err != nil {
        t.Fatal(err)
    }

    if *updateGolden {
        writeGolden(t, dump)
    } else {
        golden := readGolden(t)
        if diff := cmp.Diff(golden, dump); diff != "" {
            t.Errorf("HTTP Response mismatch (-want +got):\n%s", diff)
        }
    }
}

httptest.Serverを使わなかったのはモックが無いとどうにもならなくなった時に最悪contextに何か詰めてどうにかしようと思ったからなんだけど、今のところその必要はなさそう。

テストの前後で必要な処理がわかる

これだけで良ければとても楽なんだけど、一番大変なのは必要なリソースの準備なはず。
今回対象としたWebアプリはxorm経由でMySQLを使っているため、テスト実行時に出力されるxormのログを分析するツールを用意した。

go test -v -run TestFoo_Get | go run $PATH_TO_TOOL のように使うことで、サブテストごとにアクセスのあったテーブルを表示したり、setup.sqlcleanup.sqlを生成できる。
まだそのまま使えるSQLにはならないので手動で直さないといけないけど、何も無いよりはだいぶマシかな。

func SetupDB(t *testing.T) {
    t.Helper()

    execSQL(t, "setup.sql")
}

func CleanupDB(t *testing.T) {
    t.Helper()

    execSQL(t, "cleanup.sql")
}

var db *sql.DB

func execSQL(t *testing.T, sqlfile string) {
    t.Helper()

    filename := filepath.Join("testdata", t.Name(), sqlfile)
    file, err := os.ReadFile(filename)
    if os.IsNotExist(err) {
        return
    }
    if err != nil {
        t.Fatal(err)
    }

    if _, err := db.Exec(string(file)); err != nil {
        log.Fatal(err)
    }
}

具体例

ここまできたらあとは

  1. テスト関数を作る
  2. go test -v -run TestFoo_Get | go run $PATH_TO_TOOL する
  3. setup.sqlcleanup.sql を修正する
  4. go test -v -run TestFoo_Get -golden する
  5. goldenファイルの中身を確認する
  6. go test -v -run TestFoo_Get でPASSすることを確認する

の流れで以下を量産していくだけ。

func TestFoo_Get(t *testing.T) {
    SetupDB(t) // TestFoo_Get/setup.sqlがあれば実行する
    t.Cleanup(func() {
        CleanupDB(t) // TestFoo_Get/cleanup.sqlがあれば実行する
    })

    tests := []struct {
        name string
        opts []RequestOption
        want int
    }{
        {
            name: "found",
            opts: []RequestOption{WithQuery("limit", "10")},
            want: http.StatusOK,
        },
        {
            name: "invalid limit",
            opts: []RequestOption{WithQuery("limit", "abc")},
            want: http.StatusBadRequest,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            SetupDB(t) // TestFoo_Get/サブテスト名/setup.sqlがあれば実行する
            t.Cleanup(func() {
                CleanupDB(t) // TestFoo_Get/サブテスト名/cleanup.sqlがあれば実行する
            })

            r := NewRequest("GET", "/api/foo", nil, tt.opts...)
            RunTest(t, r, tt.want,
                ModJSONFields(map[string]interface{}{
                    "created_at": "2006-01-02 15:04:05",
                }),
            )
        })
    }
}

とりあえず次の改修には十分なテストが追加できたので安心して変更できそう。

それでも今後のことを考えるともう少しテストを増やしておきたいので既存のエンドポイントにテスト追加を促すlinterでも作りたいところ。
なお、新規エンドポイント追加時にテストが無かったら警告するlinterは導入済み。

*1:go-cmpのオプションは難しそうだったのでやらなかったのと、この形式ならJSONを整形する関数なんかも簡単に作れる

*2:他にはCookieやHTMLなんかを加工したり