前に書いた↓を使ってxormのバージョンを上げようと思っていたけど、生成されるSQLがおかしくなることがあると聞いたのでさらにテストを拡充することにした。
daisuzu.hatenablog.com
といってもまだ完成していないので、ここまでやったことを備忘録*1として残しておいて続きは連休明けにやる予定。
方針としてはxormがDBのドライバーを呼んだ際のクエリを記録・比較できるようにし、それを実行するテストを自動生成するというもの。
なぜそうしたかというと、理由は主に次の2点。
- できるだけ短時間でテストを追加したい
- 対象となるテストが膨大なので一つずつ書いていられない
- 実際のDBに接続したくない
- 他のテストの影響を受けたくないし、与えたくない
- CIで時間がかかるようになってしまうのも困る
ということで以下にだらだらと書いていく。
1. ダミードライバーを作る
DBはMySQLを使っていて、ドライバーは固定されていたのでまずはこれを変えられるようにする。
package database
func NewORM() ORM {
e, err := xorm.NewEngine("mysql", dsn)
}
既存への影響を最小限にしたかったのでビルドタグで切り替えることにした。
driver.go
package database
const driverName = "mysql"
driver_dummy.go
package database
const driverName = "dummy"
ダミーの方は代わりとなるドライバーも実装しておく。
テストはrepositoryレイヤのメソッド単位にするつもりなのでPrepare
→Exec
orQuery
の引数を記録したらその時点で処理は終わらせてしまう。
(できれば返ってくるエラーが一致するかもチェックしたいが難しそう)
func init() {
sql.Register(driverName, dummyDriver{})
core.RegisterDriver(driverName, dummyDriver{})
}
type dummyDriver struct{}
func (dummyDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New("not implemented")
}
func (dummyDriver) OpenConnector(dsn string) (driver.Connector, error) {
return connector{}, nil
}
type connector struct{}
func (connector) Connect(ctx context.Context) (driver.Conn, error) {
return conn{}, nil
}
func (connector) Driver() driver.Driver {
return dummyDriver{}
}
type conn struct{}
func (conn) Prepare(query string) (driver.Stmt, error) {
return &stmt{
numInput: strings.Count(query, "?"),
q: dbtest.NewQueryLog(query),
}, nil
}
func (conn) Close() error {
return nil
}
func (conn) Begin() (driver.Tx, error) {
return nil, errors.New("not implemented")
}
type stmt struct {
numInput int
q *QueryLog
}
func (s *stmt) Close() error {
return nil
}
func (s *stmt) NumInput() int {
return s.numInput
}
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
s.q.SetArgs(args)
return nil, errors.New("abort")
}
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
s.q.SetArgs(args)
return nil, errors.New("abort")
}
func (dummyDriver) Parse(string, string) (*core.Uri, error) {
return &core.Uri{DbType: core.DbType("mysql")}, nil
}
2. テストコードを考える
以下のようなテンプレートを考えた。
{{range .}}
func Test{{.RepoName}}(t *testing.T) {
orm := database.NewORM()
ctx := context.WithValue(context.Background(), contextKey, orm)
repo := xxx.New{{.RepoName}}(ctx)
{{range .Subtests}}
t.Run("{{.MethodName}}", func(t *testing.T) {
dbtest.RegisterTestName(t)
_, ... = repo.{{.MethodName}}(ctx, ...)
dbtest.CheckSQL(t)
})
{{end -}}
}
{{end -}}
これに合わせてSQLの記録と比較をするためのdbtestパッケージを作る。
package dbtest
import (
"database/sql/driver"
"encoding/json"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
)
var currentTest string
func RegisterTestName(t *testing.T) {
currentTest = t.Name()
}
type QueryLog struct {
Query string
Args []driver.Value
}
var queryLogs = map[string]*QueryLog{}
func NewQueryLog(query string) *QueryLog {
q := &QueryLog{Query: query}
queryLogs[currentTest] = q
return q
}
func (q *QueryLog) SetArgs(args []driver.Value) {
q.Args = args
}
func CheckSQL(t *testing.T, opts ...cmp.Option) {
t.Helper()
got := queryLogs[t.Name()]
if *updateGolden {
writeGolden(t, got)
return
}
opts = append(opts, cmp.FilterValues(func(x, y interface{}) bool {
return isNumber(x) && isNumber(y)
}, cmp.Comparer(func(x, y interface{}) bool {
return cmp.Equal(toFloat64(x), toFloat64(y))
})))
if diff := cmp.Diff(readGolden(t), got, opts...); diff != "" {
t.Errorf("SQL mismatch (-want +got):\n%s", diff)
}
}
3. 静的解析ツールを作る
あとはテンプレートに必要な情報を集めればOK。
repositoryレイヤは以下のようになっているため、
type XXXRepository struct {
repo.RootRepository
}
func NewXXXRepository(ctx context.Context) XXXRepository {
}
func (r XXXRepository) Method()
- フィールドに
RootRepository
がある型を探す
- 1.の型を返す関数(コンストラクタ)を探す
- 1.の型のメソッドを探す
の順に解析していけば必要な情報が揃えられる。
3-1. 対象のrepositoryを探す
型や関数が定義されているファイルや行の位置によってうまく解析できないと困るので1〜3は個別に収集していくことにした。
まずは1のrepository本体。
型の名前は後続のAnalyzerで使い、ファイル名をテストファイルのprefixにする。
var repoCollector = &analysis.Analyzer{
Name: "repocollector",
Doc: "collect repository definition",
Run: collectRepository,
ResultType: reflect.TypeOf(new(Repository)),
Requires: []*analysis.Analyzer{inspect.Analyzer},
}
type Repository struct {
types map[string]string
}
func collectRepository(pass *analysis.Pass) (interface{}, error) {
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
nodeFilter := []ast.Node{
(*ast.TypeSpec)(nil),
}
result := &Repository{types: make(map[string]string)}
inspect.Preorder(nodeFilter, func(n ast.Node) {
ts := n.(*ast.TypeSpec)
st, ok := ts.Type.(*ast.StructType)
if !ok {
return
}
if !hasRootRepository(st.Fields.List) {
return
}
result.types[ts.Name.Name] = filepath.Base(pass.Fset.File(n.Pos()).Name())
})
return result, nil
}
3-2. repositoryのコンストラクタを探す
次はコンストラクタ(Newから始まる関数)。
関数名とParams
と、1つのものしか無かったが念のためResults
も全て保持しておく。
var newCollector = &analysis.Analyzer{
Name: "newcollector",
Doc: "collect constructor",
Run: collectConstructor,
ResultType: reflect.TypeOf(new(Constructor)),
Requires: []*analysis.Analyzer{inspect.Analyzer, repoCollector},
}
type Constructor struct {
funcs map[string]*Func
}
type Func struct {
name string
params []*ast.Field
results []*ast.Field
}
func collectConstructor(pass *analysis.Pass) (interface{}, error) {
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
repo := pass.ResultOf[repoCollector].(*Repository)
nodeFilter := []ast.Node{
(*ast.FuncDecl)(nil),
}
result := &Constructor{funcs: make(map[string]*Func)}
inspect.Preorder(nodeFilter, func(n ast.Node) {
fd := n.(*ast.FuncDecl)
if fd.Type.Results == nil {
return
}
typeName, ok := repo.isConstructor(fd.Type.Results.List)
if !ok {
return
}
result.funcs[typeName] = &Func{
name: fd.Name.Name,
params: fd.Type.Params.List,
results: fd.Type.Results.List,
}
})
return result, nil
}
3-3. メソッドを探す
最後はメソッド。
このAnalyzerでテストコードの生成も行う。
var analyzer = &analysis.Analyzer{
Name: "gensqltest",
Doc: "generate tests",
Run: run,
Requires: []*analysis.Analyzer{inspect.Analyzer, repoCollector, newCollector},
}
type Cases []Case
type Case struct {
PkgPath string
RepoName string
Constructor string
Subtests []Method
}
type Method struct {
Name string
pkg string
params []*ast.Field
results []*ast.Field
}
func run(pass *analysis.Pass) (interface{}, error) {
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
repo := pass.ResultOf[repoCollector].(*Repository)
constructor := pass.ResultOf[newCollector].(*Constructor)
nodeFilter := []ast.Node{
(*ast.FuncDecl)(nil),
}
var order []string
tests := make(map[string][]Method)
inspect.Preorder(nodeFilter, func(n ast.Node) {
fd := n.(*ast.FuncDecl)
if fd.Recv == nil {
return
}
if !token.IsExported(fd.Name.Name) {
return
}
typeName := strings.TrimPrefix(types.ExprString(fd.Recv.List[0].Type), "*")
if ok := repo.isMethod(typeName); !ok {
return
}
if _, ok := tests[typeName]; !ok {
order = append(order, typeName)
}
m := Method{Name: fd.Name.Name, pkg: pass.Pkg.Name()}
if fd.Type.Params != nil {
m.params = fd.Type.Params.List
}
if fd.Type.Results != nil {
m.results = fd.Type.Results.List
}
tests[typeName] = append(tests[typeName], m)
})
if len(order) == 0 {
return nil, nil
}
testFiles := make(map[string]Cases)
for _, v := range order {
code := constructor.genConstructorCode(pass.Pkg.Name(), v)
if code == "" {
continue
}
testFiles[repo.testFileName(v)] = append(testFiles[repo.testFileName(v)], Case{
PkgPath: pass.Pkg.Path(),
RepoName: v,
Constructor: code,
Subtests: tests[v],
})
}
t := template.Must(template.New("test").Parse(tpl))
for k, v := range testFiles {
var out bytes.Buffer
if err := t.Execute(&out, v); err != nil {
log.Println(err)
continue
}
b, err := imports.Process(k, out.Bytes(), nil)
if err != nil {
log.Println(err)
continue
}
if err := os.WriteFile(k, b, 0600); err != nil {
log.Println(err)
}
}
return nil, nil
}
repo := xxx.New{{.RepoName}}(ctx) // 引数がormの場合もある
コンストラクタはイレギュラーがあった時に対応しやすいのでGo側で生成することにした。
基本的にはrepo := xxx.NewXXXRepository(ctx)
かrepo := xxx.NewXXXRepository(orm)
のどちらかになる。
_, ... = repo.{{.MethodName}}(ctx, ...)
また、こちらも同様にGo側で行全体を生成することにした。
(なんだかんだでかなり泥臭いコードになってしまったけど...)
func (m Method) Call() string {
var b strings.Builder
if len(m.results) > 0 {
ret := strings.Repeat("_,", len(m.results))
b.WriteString(ret[:len(ret)-1] + " = ")
}
b.WriteString("repo." + m.Name + "(")
args := make([]string, 0, len(m.params))
for i, v := range m.params {
switch t := v.Type.(type) {
case *ast.Ident:
switch t.Name {
case "int", "int64", "float", "float64":
for j := range v.Names {
args = append(args, strconv.Itoa(i+j))
}
case "string":
for _, vv := range v.Names {
args = append(args, strconv.Quote(vv.Name))
}
default:
typ := types.ExprString(v.Type)
if token.IsExported(typ) {
typ = m.pkg + "." + typ + "{}"
}
for range v.Names {
args = append(args, typ)
}
}
case *ast.SelectorExpr:
if t.Sel.Name == "Context" {
args = append(args, "ctx")
} else {
for range v.Names {
args = append(args, types.ExprString(t)+"{}")
}
}
case *ast.StarExpr:
case *ast.InterfaceType:
case *ast.MapType:
case *ast.ArrayType:
case *ast.Ellipsis:
}
}
b.WriteString(strings.Join(args, ",") + ")")
return b.String()
}
なのでテンプレートは以下のようになった。
{{range .}}
func Test{{.RepoName}}(t *testing.T) {
orm := database.NewORM()
ctx := context.WithValue(context.Background(), contextKey, orm)
{{.Constructor}}
{{range .Subtests}}
t.Run("{{.MethodName}}", func(t *testing.T) {
dbtest.RegisterTestName(t)
{{.Call}}
dbtest.CheckSQL(t)
})
{{end -}}
}
{{end -}}
4. TODO
これで1000件弱のサブテストを生成してみたところ、ビルドできなかったり実行時にpanicするのが10件ほどあった。
一旦はコメントアウト状態でコードを生成したりt.Skip
を差し込むようにしているが、手動で直すなり分岐を追加してちゃんと通るコードにしないといけない。
(というのもあってコードはだいぶ省略している)
それから比較でFAILするのも10数件あったのでcmp.Option
を使うか、その他の方法でPASSするようにしないといけない。
そして新たに追加されるrepositoryやメソッドをどうするかもまだ考えていない。