xorm更新用のテストを静的解析で生成した時のメモ

前に書いた↓を使ってxormのバージョンを上げようと思っていたけど、生成されるSQLがおかしくなることがあると聞いたのでさらにテストを拡充することにした。

daisuzu.hatenablog.com

といってもまだ完成していないので、ここまでやったことを備忘録*1として残しておいて続きは連休明けにやる予定。

方針としてはxormがDBのドライバーを呼んだ際のクエリを記録・比較できるようにし、それを実行するテストを自動生成するというもの。

なぜそうしたかというと、理由は主に次の2点。

  • できるだけ短時間でテストを追加したい
    • 対象となるテストが膨大なので一つずつ書いていられない
      • 不要なテストを精査することも厳しい
  • 実際のDBに接続したくない
    • 他のテストの影響を受けたくないし、与えたくない
    • CIで時間がかかるようになってしまうのも困る

ということで以下にだらだらと書いていく。

1. ダミードライバーを作る

DBはMySQLを使っていて、ドライバーは固定されていたのでまずはこれを変えられるようにする。

package database

func NewORM() ORM { // ORMはinterface
    // 略
    e, err := xorm.NewEngine("mysql", dsn)
    // 略
}

既存への影響を最小限にしたかったのでビルドタグで切り替えることにした。

driver.go

// +build !dummy
//go:build !dummy

package database

const driverName = "mysql"

driver_dummy.go

// +build dummy
//go:build dummy

package database

const driverName = "dummy"

ダミーの方は代わりとなるドライバーも実装しておく。
テストはrepositoryレイヤのメソッド単位にするつもりなのでPrepareExecorQueryの引数を記録したらその時点で処理は終わらせてしまう。
(できれば返ってくるエラーが一致するかもチェックしたいが難しそう)

func init() {
    sql.Register(driverName, dummyDriver{})
    core.RegisterDriver(driverName, dummyDriver{})
}

// databse/sql用
type dummyDriver struct{}

// https://pkg.go.dev/database/sql/driver#Driver
func (dummyDriver) Open(dsn string) (driver.Conn, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#DriverContext
func (dummyDriver) OpenConnector(dsn string) (driver.Connector, error) {
    return connector{}, nil
}

// https://pkg.go.dev/database/sql/driver#Connector
type connector struct{}

func (connector) Connect(ctx context.Context) (driver.Conn, error) {
    return conn{}, nil
}

func (connector) Driver() driver.Driver {
    return dummyDriver{}
}

// https://pkg.go.dev/database/sql/driver#Conn
type conn struct{}

func (conn) Prepare(query string) (driver.Stmt, error) {
    return &stmt{
        numInput: strings.Count(query, "?"),
        q:        dbtest.NewQueryLog(query),
    }, nil
}

func (conn) Close() error {
    return nil
}

func (conn) Begin() (driver.Tx, error) {
    return nil, errors.New("not implemented")
}

// https://pkg.go.dev/database/sql/driver#Stmt
type stmt struct {
    numInput int
    q        *QueryLog
}

func (s *stmt) Close() error {
    return nil
}

func (s *stmt) NumInput() int {
    return s.numInput
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
    s.q.SetArgs(args)
    return nil, errors.New("abort")
}

// https://pkg.go.dev/github.com/go-xorm/core#Driver
func (dummyDriver) Parse(string, string) (*core.Uri, error) {
    return &core.Uri{DbType: core.DbType("mysql")}, nil
}

2. テストコードを考える

以下のようなテンプレートを考えた。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        _, ... = repo.{{.MethodName}}(ctx, ...)
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

これに合わせてSQLの記録と比較をするためのdbtestパッケージを作る。

package dbtest

import (
    "database/sql/driver"
    "encoding/json"
    "os"
    "path/filepath"
    "reflect"
    "testing"

    "github.com/google/go-cmp/cmp"
)

var currentTest string

// サブテストの最初に呼ぶ
func RegisterTestName(t *testing.T) {
    currentTest = t.Name()
}

type QueryLog struct {
    Query string
    Args  []driver.Value
}

var queryLogs = map[string]*QueryLog{}

// サブテストごとにクエリを記録する
func NewQueryLog(query string) *QueryLog {
    q := &QueryLog{Query: query}
    queryLogs[currentTest] = q
    return q
}

func (q *QueryLog) SetArgs(args []driver.Value) {
    q.Args = args
}

// 記録したクエリをgoldenファイルと比較する
func CheckSQL(t *testing.T, opts ...cmp.Option) {
    t.Helper()

    got := queryLogs[t.Name()]

    if *updateGolden {
        writeGolden(t, got)
        return
    }

    // gotはint64、goldenはfloat64になるので常にfloat64で比較する
    // https://daisuzu.hatenablog.com/entry/2021/01/08/145459
    opts = append(opts, cmp.FilterValues(func(x, y interface{}) bool {
        return isNumber(x) && isNumber(y)
    }, cmp.Comparer(func(x, y interface{}) bool {
        return cmp.Equal(toFloat64(x), toFloat64(y))
    })))

    if diff := cmp.Diff(readGolden(t), got, opts...); diff != "" {
        t.Errorf("SQL mismatch (-want +got):\n%s", diff)
    }
}

3. 静的解析ツールを作る

あとはテンプレートに必要な情報を集めればOK。

repositoryレイヤは以下のようになっているため、

type XXXRepository struct {
    repo.RootRepository
}

func NewXXXRepository(ctx context.Context) XXXRepository {
    // 略
}

func (r XXXRepository) Method()
  1. フィールドにRootRepositoryがある型を探す
  2. 1.の型を返す関数(コンストラクタ)を探す
  3. 1.の型のメソッドを探す

の順に解析していけば必要な情報が揃えられる。

3-1. 対象のrepositoryを探す

型や関数が定義されているファイルや行の位置によってうまく解析できないと困るので1〜3は個別に収集していくことにした。

まずは1のrepository本体。
型の名前は後続のAnalyzerで使い、ファイル名をテストファイルのprefixにする。

var repoCollector = &analysis.Analyzer{
    Name:       "repocollector",
    Doc:        "collect repository definition",
    Run:        collectRepository,
    ResultType: reflect.TypeOf(new(Repository)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer},
}

type Repository struct {
    types map[string]string
}

func collectRepository(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

    nodeFilter := []ast.Node{
        (*ast.TypeSpec)(nil),
    }

    result := &Repository{types: make(map[string]string)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        ts := n.(*ast.TypeSpec)
        st, ok := ts.Type.(*ast.StructType)
        if !ok {
            return
        }
        if !hasRootRepository(st.Fields.List) {
            return
        }
        result.types[ts.Name.Name] = filepath.Base(pass.Fset.File(n.Pos()).Name())
    })

    return result, nil
}

3-2. repositoryのコンストラクタを探す

次はコンストラクタ(Newから始まる関数)。
関数名とParamsと、1つのものしか無かったが念のためResultsも全て保持しておく。

var newCollector = &analysis.Analyzer{
    Name:       "newcollector",
    Doc:        "collect constructor",
    Run:        collectConstructor,
    ResultType: reflect.TypeOf(new(Constructor)),
    Requires:   []*analysis.Analyzer{inspect.Analyzer, repoCollector},
}

type Constructor struct {
    funcs map[string]*Func
}

type Func struct {
    name    string
    params  []*ast.Field
    results []*ast.Field
}

func collectConstructor(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    result := &Constructor{funcs: make(map[string]*Func)}

    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Type.Results == nil {
            // 何も返さない関数はコンストラクタではない
            return
        }

        typeName, ok := repo.isConstructor(fd.Type.Results.List)
        if !ok {
            // collectRepository()で収集した型のみが対象
            return
        }

        result.funcs[typeName] = &Func{
            name:    fd.Name.Name,
            params:  fd.Type.Params.List,
            results: fd.Type.Results.List,
        }
    })

    return result, nil
}

3-3. メソッドを探す

最後はメソッド。
このAnalyzerでテストコードの生成も行う。

var analyzer = &analysis.Analyzer{
    Name:     "gensqltest",
    Doc:      "generate tests",
    Run:      run,
    Requires: []*analysis.Analyzer{inspect.Analyzer, repoCollector, newCollector},
}

type Cases []Case

type Case struct {
    PkgPath     string
    RepoName    string
    Constructor string
    Subtests    []Method
}

type Method struct {
    Name    string
    pkg     string
    params  []*ast.Field
    results []*ast.Field
}

func run(pass *analysis.Pass) (interface{}, error) {
    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    repo := pass.ResultOf[repoCollector].(*Repository)
    constructor := pass.ResultOf[newCollector].(*Constructor)

    nodeFilter := []ast.Node{
        (*ast.FuncDecl)(nil),
    }

    var order []string // 生成するRepositoryの順番を固定するために使う
    tests := make(map[string][]Method)
    inspect.Preorder(nodeFilter, func(n ast.Node) {
        fd := n.(*ast.FuncDecl)
        if fd.Recv == nil {
            // メソッドじゃない関数は対象外
            return
        }

        if !token.IsExported(fd.Name.Name) {
            // 非公開メソッドは対象外
            return
        }

        typeName := strings.TrimPrefix(types.ExprString(fd.Recv.List[0].Type), "*")
        if ok := repo.isMethod(typeName); !ok {
            // 一致するレシーバがないものは対象外
            return
        }

        if _, ok := tests[typeName]; !ok {
            // メソッドが見つかった順にテスト関数を出力する
            order = append(order, typeName)
        }

        m := Method{Name: fd.Name.Name, pkg: pass.Pkg.Name()}
        if fd.Type.Params != nil {
            m.params = fd.Type.Params.List
        }
        if fd.Type.Results != nil {
            m.results = fd.Type.Results.List
        }
        tests[typeName] = append(tests[typeName], m)
    })

    if len(order) == 0 {
        return nil, nil
    }

    testFiles := make(map[string]Cases)
    for _, v := range order {
        code := constructor.genConstructorCode(pass.Pkg.Name(), v)
        if code == "" {
            continue
        }
        testFiles[repo.testFileName(v)] = append(testFiles[repo.testFileName(v)], Case{
            PkgPath:     pass.Pkg.Path(), // importに追加する
            RepoName:    v,
            Constructor: code,
            Subtests:    tests[v],
        })
    }

    t := template.Must(template.New("test").Parse(tpl))
    for k, v := range testFiles {
        var out bytes.Buffer
        if err := t.Execute(&out, v); err != nil {
            log.Println(err)
            continue
        }

        // goimportsをかける
        b, err := imports.Process(k, out.Bytes(), nil)
        if err != nil {
            log.Println(err)
            continue
        }

        if err := os.WriteFile(k, b, 0600); err != nil {
            log.Println(err)
        }
    }

    return nil, nil
}

repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある

コンストラクタはイレギュラーがあった時に対応しやすいのでGo側で生成することにした。
基本的にはrepo := xxx.NewXXXRepository(ctx)repo := xxx.NewXXXRepository(orm)のどちらかになる。

_, ... = repo.{{.MethodName}}(ctx, ...)

また、こちらも同様にGo側で行全体を生成することにした。
(なんだかんだでかなり泥臭いコードになってしまったけど...)

func (m Method) Call() string {
    var b strings.Builder

    // 左辺を作る
    if len(m.results) > 0 {
        ret := strings.Repeat("_,", len(m.results))
        b.WriteString(ret[:len(ret)-1] + " = ")
    }

    // 右辺を作る(引数は適当な値を詰める)
    b.WriteString("repo." + m.Name + "(")
    args := make([]string, 0, len(m.params))
    for i, v := range m.params {
        switch t := v.Type.(type) {
        case *ast.Ident:
            switch t.Name {
            case "int", "int64", "float", "float64":
                for j := range v.Names {
                    // `(a, b int)` のようなケースへの対応
                    args = append(args, strconv.Itoa(i+j))
                }
            case "string":
                for _, vv := range v.Names {
                    args = append(args, strconv.Quote(vv.Name))
                }
            default:
                typ := types.ExprString(v.Type)
                if token.IsExported(typ) {
                    typ = m.pkg + "." + typ + "{}"
                }
                for range v.Names {
                    args = append(args, typ)
                }
            }
        case *ast.SelectorExpr:
            if t.Sel.Name == "Context" {
                args = append(args, "ctx")
            } else {
                for range v.Names {
                    args = append(args, types.ExprString(t)+"{}")
                }
            }
        case *ast.StarExpr:
            // 略
        case *ast.InterfaceType:
            // 略
        case *ast.MapType:
            // 略
        case *ast.ArrayType:
            // 略
        case *ast.Ellipsis:
            // 略
        }
    }
    b.WriteString(strings.Join(args, ",") + ")")

    return b.String()
}

なのでテンプレートは以下のようになった。

{{range .}}
func Test{{.RepoName}}(t *testing.T) {
    orm := database.NewORM()
    ctx := context.WithValue(context.Background(), contextKey, orm)
    {{.Constructor}}

    {{range .Subtests}}
    t.Run("{{.MethodName}}", func(t *testing.T) {
        dbtest.RegisterTestName(t)
        {{.Call}}
        dbtest.CheckSQL(t)
    })
    {{end -}}
}
{{end -}}

4. TODO

これで1000件弱のサブテストを生成してみたところ、ビルドできなかったり実行時にpanicするのが10件ほどあった。
一旦はコメントアウト状態でコードを生成したりt.Skipを差し込むようにしているが、手動で直すなり分岐を追加してちゃんと通るコードにしないといけない。
(というのもあってコードはだいぶ省略している)

それから比較でFAILするのも10数件あったのでcmp.Optionを使うか、その他の方法でPASSするようにしないといけない。

そして新たに追加されるrepositoryやメソッドをどうするかもまだ考えていない。

*1:+チームメンバーに自分の思考を共有できればと思ってたり