goplsのdaemonモードを使う

goplsgopls -listen=<addr>で実行するとdaemonモードで起動し、指定した<addr>TCP接続できるようになる。

github.com

クライアントはgoplsを使っても良いし、独自に実装することも可能。
その場合、TCP上で以下のような形式のJSON-RPCを送受信すれば良い。
(改行は\r\n)

Content-Length: <JSON部分のbyte数>

{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}
  • レスポンス
Content-Length: <JSON部分のbyte数>

{"jsonrpc":"2.0","result":{},"id":1}

接続ごとに 'initialize''initialized' を送信して初期化したら、あとは 'textDocument/references''callHierarchy/incomingCalls' など、呼びたいメソッドを呼べばOK。

毎回初期化しなくて済むのと、レスポンスをJSONとして扱えるので複雑なことがやりやすくなるはず。

クライアントのサンプル実装はこちら。

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net"
    "sync/atomic"
)

type Client struct {
    id   int64
    conn net.Conn
}

func Connect(addr string, initializedParams map[string]interface{}) (*Client, error) {
    conn, err := net.Dial("tcp", addr)
    if err != nil {
        return nil, err
    }
    client := &Client{conn: conn}

    if _, err := client.Call("initialize", initializedParams); err != nil {
        return nil, err
    }
    if _, err := client.Call("initialized", map[string]interface{}{}); err != nil {
        return nil, err
    }

    return client, nil
}

type response struct {
    ID     int64           `json:"id"`
    Result json.RawMessage `json:"result"`
}

func (c *Client) Call(method string, params interface{}) (*json.RawMessage, error) {
    id := atomic.AddInt64(&c.id, 1)

    data, err := json.Marshal(map[string]interface{}{
        "jsonrpc": "2.0",
        "method":  method,
        "params":  params,
        "id":      id,
    })
    if err != nil {
        return nil, err
    }

    if _, err := fmt.Fprintf(c.conn, "Content-Length: %v\r\n\r\n%s", len(data), data); err != nil {
        return nil, err
    }

    for {
        // Content-Lengthまで読む
        buf := make([]byte, 40)
        n, err := c.conn.Read(buf)
        if err != nil {
            return nil, err
        }

        r := bytes.NewBuffer(buf[:n])
        var length int
        if _, err := fmt.Fscanf(r, "Content-Length: %d\r\n\r\n", &length); err != nil {
            continue
        }

        // bufに入りきらなかったBodyを読む
        body := make([]byte, length)
        idx := copy(body, r.Bytes())
        if _, err = c.conn.Read(body[idx:]); err != nil {
            return nil, err
        }

        var res response
        if err := json.Unmarshal(body, &res); err != nil {
            return nil, err
        }
        if res.ID != id {
            // 送信したリクエストに対するレスポンス以外は無視
            // (goplsからの通知を含む)
            continue
        }
        return &res.Result, nil
    }
}

func (c *Client) Shutdown() error {
    _, err := c.Call("shutdown", nil)
    return err
}

xorm更新用のテストを静的解析で生成した時のメモ

前に書いた↓を使ってxormのバージョンを上げようと思っていたけど、生成されるSQLがおかしくなることがあると聞いたのでさらにテストを拡充することにした。

daisuzu.hatenablog.com

といってもまだ完成していないので、ここまでやったことを備忘録*1として残しておいて続きは連休明けにやる予定。

方針としてはxormがDBのドライバーを呼んだ際のクエリを記録・比較できるようにし、それを実行するテストを自動生成するというもの。

なぜそうしたかというと、理由は主に次の2点。

  • できるだけ短時間でテストを追加したい
    • 対象となるテストが膨大なので一つずつ書いていられない
      • 不要なテストを精査することも厳しい
  • 実際のDBに接続したくない
    • 他のテストの影響を受けたくないし、与えたくない
    • CIで時間がかかるようになってしまうのも困る

ということで以下にだらだらと書いていく。

1. ダミードライバーを作る

DBはMySQLを使っていて、ドライバーは固定されていたのでまずはこれを変えられるようにする。

package database

func NewORM() ORM { // ORMはinterface
    // 略
    e, err := xorm.NewEngine("mysql", dsn)
    // 略
}

既存への影響を最小限にしたかったのでビルドタグで切り替えることにした。

driver.go

// +build !dummy
//go:build !dummy

package database

const driverName = "mysql"

driver_dummy.go

// +build dummy
//go:build dummy

package database

const driverName = "dummy"

ダミーの方は代わりとなるドライバーも実装しておく。
テストはrepositoryレイヤのメソッド単位にするつもりなのでPrepareExecorQueryの引数を記録したらその時点で処理は終わらせてしまう。
(できれば返ってくるエラーが一致するかもチェックしたいが難しそう)

func init() {
    sql.Register(driverName, dummyDriver{})
    core.RegisterDriver(driverName, dummyDriver{})
}

// databse/sql用
type dummyDriver struct{}

// https://pkg.go.dev/database/sql/driver#Driver
func (dummyDriver) Open(dsn string) (driver.Conn, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#DriverContext
func (dummyDriver) OpenConnector(dsn string) (driver.Connector, error) {
    return connector{}, nil
}

// https://pkg.go.dev/database/sql/driver#Connector
type connector struct{}

func (connector) Connect(ctx context.Context) (driver.Conn, error) {
    return conn{}, nil
}

func (connector) Driver() driver.Driver {
    return dummyDriver{}
}

// https://pkg.go.dev/database/sql/driver#Conn
type conn struct{}

func (conn) Prepare(query string) (driver.Stmt, error) {
    return &stmt{
        numInput: strings.Count(query, "?"),
        q:        dbtest.NewQueryLog(query),
    }, nil
}

func (conn) Close() error {
    return nil
}

func (conn) Begin() (driver.Tx, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#Stmt
type stmt struct {
    numInput int
    q        *QueryLog
}

func (s *stmt) Close() error {
    return nil
}

func (s *stmt) NumInput() int {
    return s.numInput
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}

// https://pkg.go.dev/github.com/go-xorm/core#Driver
func (dummyDriver) Parse(string, string) (*core.Uri, error) {
    return &core.Uri{DbType: core.DbType("mysql")}, nil
}

2. テストコードを考える

以下のようなテンプレートを考えた。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        _, ... = repo.{{.MethodName}}(ctx, ...)
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

これに合わせてSQLの記録と比較をするためのdbtestパッケージを作る。

package dbtest

import (
    "database/sql/driver"
    "encoding/json"
    "os"
    "path/filepath"
    "reflect"
    "testing"

    "github.com/google/go-cmp/cmp"
)

var currentTest string

// サブテストの最初に呼ぶ
func RegisterTestName(t *testing.T) {
    currentTest = t.Name()
}

type QueryLog struct {
    Query string
    Args  []driver.Value
}

var queryLogs = map[string]*QueryLog{}

// サブテストごとにクエリを記録する
func NewQueryLog(query string) *QueryLog {
    q := &QueryLog{Query: query}
    queryLogs[currentTest] = q
    return q
}

func (q *QueryLog) SetArgs(args []driver.Value) {
    q.Args = args
}

// 記録したクエリをgoldenファイルと比較する
func CheckSQL(t *testing.T, opts ...cmp.Option) {
    t.Helper()

    got := queryLogs[t.Name()]

    if *updateGolden {
        writeGolden(t, got)
        return
    }

    // gotはint64、goldenはfloat64になるので常にfloat64で比較する
    // https://daisuzu.hatenablog.com/entry/2021/01/08/145459
    opts = append(opts, cmp.FilterValues(func(x, y interface{}) bool {
        return isNumber(x) && isNumber(y)
    }, cmp.Comparer(func(x, y interface{}) bool {
        return cmp.Equal(toFloat64(x), toFloat64(y))
    })))

    if diff := cmp.Diff(readGolden(t), got, opts...); diff != "" {
        t.Errorf("SQL mismatch (-want +got):\n%s", diff)
    }
}

3. 静的解析ツールを作る

あとはテンプレートに必要な情報を集めればOK。

repositoryレイヤは以下のようになっているため、

type XXXRepository struct {
    repo.RootRepository
}

func NewXXXRepository(ctx context.Context) XXXRepository {
    // 略
}

func (r XXXRepository) Method()
  1. フィールドにRootRepositoryがある型を探す
  2. 1.の型を返す関数(コンストラクタ)を探す
  3. 1.の型のメソッドを探す

の順に解析していけば必要な情報が揃えられる。

3-1. 対象のrepositoryを探す

型や関数が定義されているファイルや行の位置によってうまく解析できないと困るので1〜3は個別に収集していくことにした。

まずは1のrepository本体。
型の名前は後続のAnalyzerで使い、ファイル名をテストファイルのprefixにする。

var repoCollector = &analysis.Analyzer{
    Name:       "repocollector",
    Doc:        "collect repository definition",
    Run:        collectRepository,
    ResultType: reflect.TypeOf(new(Repository)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer},
}

type Repository struct {
    types map[string]string
}

func collectRepository(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

    nodeFilter := []ast.Node{
        (*ast.TypeSpec)(nil),
    }

    result := &Repository{types: make(map[string]string)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        ts := n.(*ast.TypeSpec)
        st, ok := ts.Type.(*ast.StructType)
        if !ok {
            return
        }
        if !hasRootRepository(st.Fields.List) {
            return
        }
        result.types[ts.Name.Name] = filepath.Base(pass.Fset.File(n.Pos()).Name())
    })

    return result, nil
}

3-2. repositoryのコンストラクタを探す

次はコンストラクタ(Newから始まる関数)。
関数名とParamsと、1つのものしか無かったが念のためResultsも全て保持しておく。

var newCollector = &analysis.Analyzer{
    Name:       "newcollector",
    Doc:        "collect constructor",
    Run:        collectConstructor,
    ResultType: reflect.TypeOf(new(Constructor)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer, repoCollector},
}

type Constructor struct {
    funcs map[string]*Func
}

type Func struct {
    name    string
    params  []*ast.Field
    results []*ast.Field
}

func collectConstructor(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    result := &Constructor{funcs: make(map[string]*Func)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Type.Results == nil {
            // 何も返さない関数はコンストラクタではない
            return
        }

        typeName, ok := repo.isConstructor(fd.Type.Results.List)
        if !ok {
            // collectRepository()で収集した型のみが対象
            return
        }

        result.funcs[typeName] = &Func{
            name:    fd.Name.Name,
            params:  fd.Type.Params.List,
            results: fd.Type.Results.List,
        }
    })

    return result, nil
}

3-3. メソッドを探す

最後はメソッド。
このAnalyzerでテストコードの生成も行う。

var analyzer = &analysis.Analyzer{
    Name:     "gensqltest",
    Doc:      "generate tests",
    Run:      run,
    Requires: []*analysis.Analyzer{inspect.Analyzer, repoCollector, newCollector},
}

type Cases []Case

type Case struct {
    PkgPath     string
    RepoName    string
    Constructor string
    Subtests    []Method
}

type Method struct {
    Name    string
    pkg     string
    params  []*ast.Field
    results []*ast.Field
}

func run(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)
    constructor := pass.ResultOf[newCollector].(*Constructor)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    var order []string // 生成するRepositoryの順番を固定するために使う
    tests := make(map[string][]Method)
    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Recv == nil {
            // メソッドじゃない関数は対象外
            return
        }

        if !token.IsExported(fd.Name.Name) {
            // 非公開メソッドは対象外
            return
        }

        typeName := strings.TrimPrefix(types.ExprString(fd.Recv.List[0].Type), "*")
        if ok := repo.isMethod(typeName); !ok {
            // 一致するレシーバがないものは対象外
            return
        }

        if _, ok := tests[typeName]; !ok {
            // メソッドが見つかった順にテスト関数を出力する
            order = append(order, typeName)
        }

        m := Method{Name: fd.Name.Name, pkg: pass.Pkg.Name()}
        if fd.Type.Params != nil {
            m.params = fd.Type.Params.List
        }
        if fd.Type.Results != nil {
            m.results = fd.Type.Results.List
        }
        tests[typeName] = append(tests[typeName], m)
    })

    if len(order) == 0 {
        return nil, nil
    }

    testFiles := make(map[string]Cases)
    for _, v := range order {
        code := constructor.genConstructorCode(pass.Pkg.Name(), v)
        if code == "" {
            continue
        }
        testFiles[repo.testFileName(v)] = append(testFiles[repo.testFileName(v)], Case{
            PkgPath:     pass.Pkg.Path(), // importに追加する
            RepoName:    v,
            Constructor: code,
            Subtests:    tests[v],
        })
    }

    t := template.Must(template.New("test").Parse(tpl))
    for k, v := range testFiles {
        var out bytes.Buffer
        if err := t.Execute(&out, v); err != nil {
            log.Println(err)
            continue
        }

        // goimportsをかける
        b, err := imports.Process(k, out.Bytes(), nil)
        if err != nil {
            log.Println(err)
            continue
        }

        if err := os.WriteFile(k, b, 0600); err != nil {
            log.Println(err)
        }
    }

    return nil, nil
}

repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

コンストラクタはイレギュラーがあった時に対応しやすいのでGo側で生成することにした。
基本的にはrepo := xxx.NewXXXRepository(ctx)repo := xxx.NewXXXRepository(orm)のどちらかになる。

_, ... = repo.{{.MethodName}}(ctx, ...)

また、こちらも同様にGo側で行全体を生成することにした。
(なんだかんだでかなり泥臭いコードになってしまったけど...)

func (m Method) Call() string {
    var b strings.Builder

    // 左辺を作る
    if len(m.results) > 0 {
        ret := strings.Repeat("_,", len(m.results))
        b.WriteString(ret[:len(ret)-1] + " = ")
    }

    // 右辺を作る(引数は適当な値を詰める)
    b.WriteString("repo." + m.Name + "(")
    args := make([]string, 0, len(m.params))
    for i, v := range m.params {
        switch t := v.Type.(type) {
        case *ast.Ident:
            switch t.Name {
            case "int", "int64", "float", "float64":
                for j := range v.Names {
                    // `(a, b int)` のようなケースへの対応
                    args = append(args, strconv.Itoa(i+j))
                }
            case "string":
                for _, vv := range v.Names {
                    args = append(args, strconv.Quote(vv.Name))
                }
            default:
                typ := types.ExprString(v.Type)
                if token.IsExported(typ) {
                    typ = m.pkg + "." + typ + "{}"
                }
                for range v.Names {
                    args = append(args, typ)
                }
            }
        case *ast.SelectorExpr:
            if t.Sel.Name == "Context" {
                args = append(args, "ctx")
            } else {
                for range v.Names {
                    args = append(args, types.ExprString(t)+"{}")
                }
            }
        case *ast.StarExpr:
            // 略
        case *ast.InterfaceType:
            // 略
        case *ast.MapType:
            // 略
        case *ast.ArrayType:
            // 略
        case *ast.Ellipsis:
            // 略
        }
    }
    b.WriteString(strings.Join(args, ",") + ")")

    return b.String()
}

なのでテンプレートは以下のようになった。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    {{.Constructor}}

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        {{.Call}}
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

4. TODO

これで1000件弱のサブテストを生成してみたところ、ビルドできなかったり実行時にpanicするのが10件ほどあった。
一旦はコメントアウト状態でコードを生成したりt.Skipを差し込むようにしているが、手動で直すなり分岐を追加してちゃんと通るコードにしないといけない。
(というのもあってコードはだいぶ省略している)

それから比較でFAILするのも10数件あったのでcmp.Optionを使うか、その他の方法でPASSするようにしないといけない。

そして新たに追加されるrepositoryやメソッドをどうするかもまだ考えていない。

*1:+チームメンバーに自分の思考を共有できればと思ってたり

みんなで書くGoのエンドポイントテスト

Webアプリケーションサーバーに何か大きな変更をしたいけど、既存のテストだと心許なかったので各エンドポイントにHandlerからのテストを追加することにした。

ただ全部のテストを自分1人で作っていくのはボリューム的に現実的ではなかったので、どうしたらチーム全員が書きやすいテストになるか考えて色々と整備してみた。

テストの書き方がある程度決まっている

エンドポイントごとにスタイルがバラバラだと都度どう書くか考えなければいけなくなってしまうため、基本的にはリクエストとレスポンスだけテーブルに指定するスタイルが良さそうだと考えた。

簡略化すると以下のような形式。

func TestFoo_Get(t *testing.T) {
    tests := []struct {
        name string
        // ヘッダやクエリパラメータなど
        // 期待するレスポンス
    }{
        // 実際のテストケース
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            r := httptest.NewRequest("GET", "/api/foo", nil)
            RunTest(t, r, tt.want)
        })
    }
}

しかし、ヘッダやクエリパラメータの有無によってはそれに応じた処理を書かないといけないので、それを吸収するための関数を用意することにした。

type RequestOption func(*http.Request)

func WithQuery(key, value string) RequestOption {
    return func(r *http.Request) {
        q := r.URL.Query()
        q.Set(key, value)
        r.URL.RawQuery = q.Encode()
    }
}

func WithHeader(key, value string) RequestOption {
    return func(r *http.Request) {
        r.Header.Set(key, value)
    }
}

func NewRequest(method, endpoint string, body io.Reader, options ...RequestOption) *http.Request {
    r := httptest.NewRequest(method, endpoint, body)
    for _, opt := range options {
        opt(r)
    }
    return r
}

また、POSTやPUTでJSONを送る場合は以下の関数でボディを作れるようにした。

func JSONBody(t *testing.T, m map[string]interface{}) io.Reader {
    t.Helper()

    body := new(bytes.Buffer)
    if err := json.NewEncoder(body).Encode(&m); err != nil {
        t.Fatal(err)
    }
    return body
}

期待する結果(want)を全て書かなくても良い

レスポンスはエンドポイントによってはかなり大きくなることもあり、毎回全体を書くのは大変そうだったので避けたかった。
そしてレスポンスが変わるたびに毎回手動で全て直さないといけないのも面倒なのでgoldenファイル化することにした。

値が固定されないところもあるので、そこはレスポンスを柔軟に書き換えられるようにしている。*1
例えばJSONが返ってくるエンドポイントであれば以下のような関数。*2

type ResponseFilter func(t *testing.T, r *http.Response)

func ModJSONFields(overwrite map[string]interface{}) ResponseFilter {
    return func(t *testing.T, r *http.Response) {
        t.Helper()

        var tmp map[string]interface{}
        if err := json.NewDecoder(r.Body).Decode(&tmp); err != nil {
            t.Fatal(err)
        }

        rewriteMap(t, tmp, overwrite)

        body := new(bytes.Buffer)
        if err := json.NewEncoder(body).Encode(&tmp); err != nil {
            t.Fatal(err)
        }
        r.Body = io.NopCloser(body)
    }
}

これでRunTestは以下のようになる。

var (
    handler http.Handler

    updateGolden = flag.Bool("golden", false, "Update golden files")
)

func RunTest(t *testing.T, r *http.Request, want int, filters ...ResponseFilter) {
    t.Helper()

    w := httptest.NewRecorder()
    handler.ServeHTTP(w, r)

    got := w.Result()
    if got.StatusCode != want {
        t.Errorf("HTTP StatusCode = %d, want %d", got.StatusCode, want)
    }

    for _, f := range filters {
        f(t, got)
    }

    dump, err := httputil.DumpResponse(got, true)
    if err != nil {
        t.Fatal(err)
    }

    if *updateGolden {
        writeGolden(t, dump)
    } else {
        golden := readGolden(t)
        if diff := cmp.Diff(golden, dump); diff != "" {
            t.Errorf("HTTP Response mismatch (-want +got):\n%s", diff)
        }
    }
}

httptest.Serverを使わなかったのはモックが無いとどうにもならなくなった時に最悪contextに何か詰めてどうにかしようと思ったからなんだけど、今のところその必要はなさそう。

テストの前後で必要な処理がわかる

これだけで良ければとても楽なんだけど、一番大変なのは必要なリソースの準備なはず。
今回対象としたWebアプリはxorm経由でMySQLを使っているため、テスト実行時に出力されるxormのログを分析するツールを用意した。

go test -v -run TestFoo_Get | go run $PATH_TO_TOOL のように使うことで、サブテストごとにアクセスのあったテーブルを表示したり、setup.sqlcleanup.sqlを生成できる。
まだそのまま使えるSQLにはならないので手動で直さないといけないけど、何も無いよりはだいぶマシかな。

func SetupDB(t *testing.T) {
    t.Helper()

    execSQL(t, "setup.sql")
}

func CleanupDB(t *testing.T) {
    t.Helper()

    execSQL(t, "cleanup.sql")
}

var db *sql.DB

func execSQL(t *testing.T, sqlfile string) {
    t.Helper()

    filename := filepath.Join("testdata", t.Name(), sqlfile)
    file, err := os.ReadFile(filename)
    if os.IsNotExist(err) {
        return
    }
    if err != nil {
        t.Fatal(err)
    }

    if _, err := db.Exec(string(file)); err != nil {
        log.Fatal(err)
    }
}

具体例

ここまできたらあとは

  1. テスト関数を作る
  2. go test -v -run TestFoo_Get | go run $PATH_TO_TOOL する
  3. setup.sqlcleanup.sql を修正する
  4. go test -v -run TestFoo_Get -golden する
  5. goldenファイルの中身を確認する
  6. go test -v -run TestFoo_Get でPASSすることを確認する

の流れで以下を量産していくだけ。

func TestFoo_Get(t *testing.T) {
    SetupDB(t) // TestFoo_Get/setup.sqlがあれば実行する
    t.Cleanup(func() {
        CleanupDB(t) // TestFoo_Get/cleanup.sqlがあれば実行する
    })

    tests := []struct {
        name string
        opts []RequestOption
        want int
    }{
        {
            name: "found",
            opts: []RequestOption{WithQuery("limit", "10")},
            want: http.StatusOK,
        },
        {
            name: "invalid limit",
            opts: []RequestOption{WithQuery("limit", "abc")},
            want: http.StatusBadRequest,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            SetupDB(t) // TestFoo_Get/サブテスト名/setup.sqlがあれば実行する
            t.Cleanup(func() {
                CleanupDB(t) // TestFoo_Get/サブテスト名/cleanup.sqlがあれば実行する
            })

            r := NewRequest("GET", "/api/foo", nil, tt.opts...)
            RunTest(t, r, tt.want,
                ModJSONFields(map[string]interface{}{
                    "created_at": "2006-01-02 15:04:05",
                }),
            )
        })
    }
}

とりあえず次の改修には十分なテストが追加できたので安心して変更できそう。

それでも今後のことを考えるともう少しテストを増やしておきたいので既存のエンドポイントにテスト追加を促すlinterでも作りたいところ。
なお、新規エンドポイント追加時にテストが無かったら警告するlinterは導入済み。

*1:go-cmpのオプションは難しそうだったのでやらなかったのと、この形式ならJSONを整形する関数なんかも簡単に作れる

*2:他にはCookieやHTMLなんかを加工したり

GoのWebアプリで見かけたツラいコード

構造体のフィールドにContextを持たせる

ほとんどの場合、各メソッドの引数にいちいちctxを渡すのが面倒だという理由だけで以下のようにしている印象がある。

type S struct {
    ctx context.Context
}

func (s *S) A() {
    // s.ctxを使う
}

func (s *S) B() {
    // s.ctxを使う
}

func (s *S) C() {
    // s.ctxを使う
}

ちょっとくらいタイプ数を減らすよりも素直に引数で渡すようにした方がシンプルだし、将来的に変にContextを共有するようなコードになってしまうのも防げる。

func (s *S) A() {
    value := getValue(s.ctx)
    s.ctx = context.WithValue(s.ctx, "key", value)
}

func (s *S) B() {
    // 事前にA()を呼んでおく必要がある?
    value := s.ctx.Value("key")
}

https://blog.golang.org/context-and-structs にも書いてあるように、どうしてもそうしなければいけない理由がない限りはやらない方が良い。

Contextの中に参照を入れておいて任意の場所で更新する

GoのContextはcontext.WithValueで新しいものを作って呼び出し先に渡す形になっているため、基本的に呼び出し元は呼び出し先の影響を受けることがない。
ただ、予めContextの中に参照を入れておくと呼び出し元に影響を与えることができてしまう。

func f(ctx context.Context) {
    // この段階では ctx.Value("key").(*V).value が空
    do(ctx)
    // ctx.Value("key").(*V).value の値が変化
}

構造体でContextを共有するのと組み合わせると非常に危険。

関連が把握しきれなくなると直すに直せなくなってしまうので、こういうコードはなるべく書かない。
もしくはせめて影響範囲を限定できるようにしておきたい。

似ている処理を匿名の構造体でまとめる

https://golang.org/doc/effective_go#embedding のようにすることでGoでも継承のようなことができる。
しかし、何でもかんでもこれを適用してしまうと扱いにくいコードになってしまう可能性がある。

type Common struct {
    req *http.Request
    rw  http.ResponseWriter

    id   int64
    data map[string]interface{}
}

func (c *Common) Prepare() {
    c.id = idFromPath(c.req.URL)
    c.data = decodeBody(c.req.Body) // GETリクエストの場合は不要
}

func (c *Common) ID() int64                    { return c.id }
func (c *Common) Data() map[string]interface{} { return c.data }

type Handler struct {
    Common
}

func (h *Handler) Get() {
    // Trace系の処理をしたりとか

    h.Prepare()

    res := getResource(h.ID())
    // 続く...
}

func (h *Handler) Put() {
    // Trace系の処理をしたりとか

    // Prepare忘れ!

    res := putResource(h.ID(), h.Data())
    // 続く...
}

この程度ならまだわかりやすいが、継承が多段になったり、レイヤやパッケージや処理がさらに細分化されていくことで、だんだんとわかりにくいコードになっていってしまう。

そのため、共通処理は親クラスのメソッドではなく関数を使うようにするなど、なるべく暗黙的な要素を排除しておいた方が後になって困ることが少ない。

特に状態が変化する構造体を埋め込む際は要注意。

vim-lspのCallHierarchyをツリーっぽく表示する

リファクタリングしたりコードを調べたりする時、呼び出し元を探すのにLspReferencesLspCallHierarchyIncomingを使っていた。
ただ、どちらも1階層分しか表示してくれず、呼び出し元が遠いと影響範囲が把握しにくかったのでquickfixに結果をマージして表示するコマンドを作ってみた。

command! AppendCallTree call s:append_tree(':LspCallHierarchyIncoming')
command! AppendRefTree call s:append_tree(':LspReferences')

augroup AppendTree
    autocmd!
augroup END

function! s:append_tree(cmd) abort
    autocmd AppendTree BufWinEnter quickfix let s:lsp_done = 1

    copen                            " quickfixに移動し、
    let l:pos = line('.')            " 現在の行番号と、
    let l:parent_tree = getqflist()  " 内容を取得し、
    call setqflist([])               " いったん空する
    let l:level = count(l:parent_tree[l:pos-1].text, '⬅️  ')
    wincmd p

    " 元のバッファで指定したコマンドを実行し、
    let s:lsp_done = 0
    execute a:cmd

    " 完了するかある程度時間が経過するまで待つ
    let l:cnt = 0
    while !s:lsp_done && l:cnt < 100
        sleep 10m
        let l:cnt += 1
    endwhile

    let l:child = getqflist()
    if len(l:child) != 0
        " 新たに取得した分は先頭に⬅️を付けて元の位置の下に挿入する
        call extend(l:parent_tree, map(l:child, 'extend(v:val, {"text": repeat("⬅️  ", l:level+1) . v:val.text})'), l:pos)
    endif

    " 結果(取得できなかった場合は元の内容)をquickfixに表示し、
    " 次の場所にジャンプする
    call setqflist(l:parent_tree)
    execute 'cc ' . string(l:pos + 1)

    autocmd! AppendTree
endfunction

AppendCallTree実行後は@:などで繰り返せるので調査が楽になった。

f:id:daisuzu:20210312165925g:plain
例: goplsのCallHierarchy

  • LSPを直接呼ぶのは面倒なのでコマンドを実行する形式にした
    • 特にCallHierarchy...
  • 専用バッファよりquickfixの方が何かと扱いやすいのでやらなかった
    • 何も考えずにジャンプできるし
    • フィルタも簡単だし

go/analysisのSuggestedFixでコードを修正する

Goの既存コードを修正するツールを作る時、

  • 既存コードをどう書き換えて
  • 出力して
  • テストするか

を考えなければいけないのが少し面倒だと思っていました。
が、golang.org/x/tools/go/analysisSuggestedFixを使えばすごく簡単にできてしまいます。

golang.org/x/tools/go/analysisstaticcheckgolangci-lintなどの静的解析ツールでよく使われているパッケージです。

例えば以下のような、関数の引数にcontext.Contextがあるかどうかチェックするツールがあったとして、

func run(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        decl := n.(*ast.FuncDecl)
        if decl.Type.Params.NumFields() > 0 {
            // NOTE: 第1引数のみを文字列でチェックしているので厳密ではない
            if types.ExprString(decl.Type.Params.List[0].Type) == "context.Context" {
                return
            }
        }

        pass.Reportf(decl.Pos(), "missing ctx in parameter")
    })

    return nil, nil
}

これを、もしチェックに引っ掛かったら引数にcontext.Contextを追加できるように変更してみます。

まずはpass.Reportfpass.Reportに変更し、直接Diagnosticを渡せる形にします。

pass.Report(analysis.Diagnostic{
    Pos:     decl.Pos(),
    Message: "missing context in parameter",
})

そしてSuggestedFixesとしてコードを変更する場所(PosからEnd)と書き換え後のコード(NewText)を渡します。

pass.Report(analysis.Diagnostic{
    Pos:     decl.Pos(),
    Message: "missing context in parameter",
    SuggestedFixes: []analysis.SuggestedFix{{
        Message: "add ctx to parameter",
        TextEdits: []analysis.TextEdit{{
            Pos:     decl.Pos(),
            End:     decl.Type.Params.Closing + 1,
            NewText: b,
        }},
    }},
})

書き換え後のコードは標準パッケージのformat.Nodeを使って作ります。

func newText(pass *analysis.Pass, decl *ast.FuncDecl) ([]byte, error) {
    // Godoc、戻り値、関数の中身は使わずにコードを整形する
    f := &ast.FuncDecl{
        Recv: decl.Recv,
        Name: decl.Name,
        Type: &ast.FuncType{
            Params: &ast.FieldList{
                List: append([]*ast.Field{{
                    Names: []*ast.Ident{{Name: "ctx"}},
                    Type: &ast.SelectorExpr{
                        X:   &ast.Ident{Name: "context"},
                        Sel: &ast.Ident{Name: "Context"},
                    },
                }}, decl.Type.Params.List...),
            },
        },
    }

    var buf bytes.Buffer
    if err := format.Node(&buf, pass.Fset, f); err != nil {
        return nil, err
    }
    return buf.Bytes(), nil
}

この書き換えを実際に適用するにはコマンドラインツールとして実行する時に-fixフラグを付けるようにすればOKです。
なお、-fixフラグはunitcheckerだと渡せないため、main.gosinglecheckermulticheckerを使う必要があります。

もしくは、goplsAnalyzerとして組み込むことでエディタと連携して使うことも可能です。
多少作り込みが甘くても、リファクタリングする時だけ以下に追加し、go installして使ってみても良いかもしれません。 https://github.com/golang/tools/blob/gopls/v0.6.4/internal/lsp/source/options.go#L1108-L1150

vim + vim-lspは該当箇所で:LspCodeActionを実行すると呼び出せます。

f:id:daisuzu:20210128120803g:plain
vim-lspのLspCodeAction

テストについてはanalysistest.Runanalysistest.RunWithSuggestedFixesに変更すればgoldenファイルと比較してくれるようになります。

go-cmpでmap[string]interface{}のJSONを比較する

GoでJSONを扱う際、型を定義せずに map[string]interface{} を使いたくなることがあります。

var (
    a = map[string]interface{}{
        "data": map[string]interface{}{
            "value": int64(1),
        },
    }
    b = map[string]interface{}{
        "data": map[string]interface{}{
            "value": float64(1),
        },
    }
)

ちょっとした用途であれば特に問題ないかもしれませんが、テストで使おうとするとたまに数値のフィールドがfloat64とint64で比較できずに困ってしまいます。
(goldenファイルを読み込んだ場合など)

func TestReflect(t *testing.T) {
    if !reflect.DeepEqual(a, b) {
        t.Errorf("%v != %v", a, b)
    }
}

こちらはint64が含まれている方をjson.Marshalし、再度json.Unmarshalすることでfloat64にすることで回避できます。

func TestReflect2(t *testing.T) {
    tmp, err := json.Marshal(a)
    if err != nil {
        t.Fatal(err)
    }
    var got map[string]interface{}
    if err := json.Unmarshal(tmp, &got); err != nil {
        t.Fatal(err)
    }
    if !reflect.DeepEqual(got, b) {
        t.Errorf("%v != %v", got, b)
    }
}

ただ、なんだか無駄な変換をしているようでモヤモヤします。

モヤモヤするのであればきちんと型を定義するべきだとは思いますが、どうしてもstructを作りたくないことがあるかもしれません。
そんな時はgithub.com/google/go-cmp/cmpFilterValuesを使用すると数値をfloat64として比較できます。

func TestCmpWithOpt(t *testing.T) {
    opt := cmp.FilterValues(func(x, y interface{}) bool {
        return isNumber(x) && isNumber(y)
    }, cmp.Comparer(func(x, y interface{}) bool {
        return cmp.Equal(toFloat64(x), toFloat64(y))
    }))
    if !cmp.Equal(a, b, opt) {
        t.Errorf("%v != %v", a, b)
    }
}

func isNumber(v interface{}) bool {
    k := reflect.ValueOf(v).Kind()
    return k == reflect.Int64 || k == reflect.Float64
}

func toFloat64(v interface{}) float64 {
    rv := reflect.ValueOf(v)
    if rv.Kind() == reflect.Int64 {
        return float64(rv.Int())
    }
    return rv.Float()
}

FilterValuesの第1引数には第2引数(opt)を評価する条件となる関数を指定します。
mapのフィールドは全てinterface{}なのでxとyの型はinterface{}にする必要があります。

第2引数では実際に比較する関数を指定します。
このタイミングで数値をfloat64に変換して比較します。
なお、cmp.Comparerのみだとcannot use an unfiltered optionでpanicしてしまいます。

全体のコードはこちらです。