使用goast为xorm生成建表sql语句
缘起
最近在工作学习的过程中,接触到一些通过自动生成代码的方式来减少重复的工作量,以及自动生成文档或桩代码的方式,包括:
- 使用Kubernetes CRD,可以通过code-generator工具自动生成客户端代码以及一些其他工具函数
- 使用goa,可以通过定义的DSL自动生成服务端的框架代码,以及文档等
一般有以下几种方式来自动生成代码:
- 通过 goast 获取代码的抽象语法树,然后通过抽象语法树来生成对应的代码
- 通过自定义DSL的方式获取代码信息,然后生成代码
- 通过反射获取信息来生成代码
正好最近在学习使用golang的ast解析工具,遂通过实现一个简单的工具来加深理解。该工具将自动读取xorm的类型信息,并自动生成对应的建表sql语句。
目标
首先明确该工具的适用目标及范围:
- 只支持xorm框架
- 只支持少量的xorm框架特性: 包括created, updated, pk, unique, notnull
- 只支持少量的sql类型: 包括BIGINT, INT, VARCHAR, DATETIME
- 支持设定表名: 通过 +genTable spec
- 只支持mysql
测试数据如下:
package main
import "time"
// User is orm for user
// +genTable: user
type User struct {
Id int64
Name string `xorm:"unique notnull"`
Salt string
Age int
Passwd string `xorm:"varchar(200)"`
Created time.Time `xorm:"created"`
Update time.Time `xorm:"updated"`
}
/*
+genTable:
*/
// User2 and User3
type (
// User2
// +genTable: user2
User2 struct { // user2 line
Uid int `xorm:"pk 'uid'"`
}
// comment group
// User3
// +genTable:
User3 struct { // user3 line
}
)
type I1 interface {
}
最终输出的建表语句如下:
CREATE TABLE `user` (
`Id` BIGINT,
`Name` VARCHAR(255) NOT NULL,
`Salt` VARCHAR(255),
`Age` INT,
`Passwd` varchar(200),
`Created` DATETIME DEFAULT CURRENT_TIMESTAMP,
`Update` DATETIME ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY(`Id`),
UNIQUE KEY `Name` (`Name`)
);
CREATE TABLE `user2` (
`uid` INT,
PRIMARY KEY(`uid`)
);
实现
1. 生成 ast
ast 即抽象语法树, go/parser 包提供了工具来解析生成ast:
fs := token.NewFileSet()
f, err := parser.ParseFile(fs, file, src, parser.ParseComments)
其中:
- 通过token.NewFileSet来生成一个FileSet对象,这个对象会保存更详细的源码位置信息
- file是文件名, 当src为nil时会读取该文件的内容;当src是字节数组/字符串时会直接将src的内容作为文件内容解析
- 需要通过ParseComments参数告知parser解析注释,因为我们的 +genTable 是通过注释实现的
2. 获取所有的table struct
在得到 ast 之后,我们需要筛选出对应的表的struct,首先定义一个结构体保存具体的信息:
type tableStruct struct {
node *ast.StructType
tableName string
}
其中表名可能来自以下内容:
- 在多个类型上的 +genTable spec,这时不能跟表名,表示使用类型名作为表名,当有多个类型都直接使用类型名作为表名时可使用
- 在单个类型上的 +genTable spec,这时可以指定表名
2.1 解析 +genTable spec
首先提供一个函数来获取 +genTable spec:
func isGenTableDoc(doc *ast.CommentGroup) (string, bool) {
if doc == nil {
return "", false
}
for _, comment := range doc.List {
if comment == nil {
continue
}
value := strings.TrimSpace(comment.Text)
if strings.HasPrefix(value, "//") { // 如果是单行注释
value = strings.TrimSpace(value[2:])
if strings.HasPrefix(value, "+genTable:") {
return strings.TrimSpace(value[10:]), true
}
} else {
lines := strings.Split(value, "\n") // 多行注释,拆分后处理每一行
for _, line := range lines {
value := strings.TrimSpace(line)
if strings.HasPrefix(value, "+genTable:") {
return strings.TrimSpace(value[10:]), true
}
}
}
}
return "", false
}
该函数在获取到一个有效的 +genTable spec 之后即返回,并能处理多行注释的情况。
2.2 对 ast 迭代获取所有 table struct
首先定义一个函数处理 *ast.File 对象:
func filterFileTableStruct(file *ast.File) ([]*tableStruct, error) {
genTables := []*tableStruct{}
for _, decl := range file.Decls {
var n ast.Node = decl
if n, ok := n.(*ast.GenDecl); ok {
tables, err := filterGenDeclTableStruct(n)
if err != nil {
return nil, err
}
genTables = append(genTables, tables...)
}
}
return genTables, nil
}
该函数会迭代所有的Decls,并且只有在当类型为*ast.GenDecl时才进一步解析。因为所有的类型定义都被解析为*ast.GenDecl对象。
例如:
type S struct {}
type (
S1 struct{}
S2 struct{}
)
上面的每一个type关键字都被视为一个GenDecl对象。
然后处理每一个GenDecl对象,并解析出table struct:
func filterGenDeclTableStruct(n *ast.GenDecl) ([]*tableStruct, error) {
tableName, isGenTable := isGenTableDoc(n.Doc)
tables := []*tableStruct{}
for _, spec := range n.Specs {
if n, ok := spec.(*ast.TypeSpec); ok {
typeTableName, typeGenTable := isGenTableDoc(n.Doc)
if !typeGenTable && !isGenTable {
fmt.Println(n.Name, " doesn't have genTable spec, skip it")
continue
}
if tableName != "" && typeTableName != "" {
return nil, fmt.Errorf("%s has multi spec tableName", n.Name)
}
if tableName != "" && len(tables) > 0 {
return nil, fmt.Errorf("%s spec multi struct", n.Name)
}
structType, ok := n.Type.(*ast.StructType)
if !ok {
if typeTableName != "" {
return nil, fmt.Errorf("%s genTable spec on %v", n.Name, reflect.TypeOf(n.Type).Elem().Name())
}
fmt.Println(n.Name, " isn't struct, skip it")
continue
}
if typeTableName == "" {
typeTableName = tableName
}
if typeTableName == "" {
typeTableName = n.Name.Name
}
tables = append(tables, &tableStruct{
tableName: typeTableName,
node: structType,
})
}
}
if isGenTable && len(tables) == 0 {
return nil, fmt.Errorf("has genTable spec but on table exists")
}
return tables, nil
}
首先解析GenDecl的Doc注释,只有紧接着定义上的注释被视为Doc,而且因为单个的type关键字也被视为GenDecl,所以该Doc不会出现在TypeSpec上。
然后针对每一个TypeSpec进行处理,当TypeSpec是StructType时就保存下来。此处需要对genTable的有效性进行校验,并且如果没有设置自定义的表名时使用类型名作为表名。
3 从table struct中解析出对应的列
在获得了 table struct 之后,从struct tag中解析出如下的字段对象:
type Column struct {
pk string
unique string
notnull string
created string
updated string
name string
t string
}
首先迭代所有的ast.Field对象,生成对应的Column。
func genColumn(field *ast.Field) (*Column, error) {
col := &Column{}
if len(field.Names) > 1 {
return nil, fmt.Errorf("%v should only be one", field.Names)
}
if len(field.Names) == 1 {
col.name = field.Names[0].Name
} else {
ident, ok := field.Type.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("field doesn't have fieldNames, and type isn't Ident, skip it")
}
col.name = ident.Name
col.t = col.name
}
tag := ""
if field.Tag != nil {
var err error
tag, err = strconv.Unquote(field.Tag.Value)
if err != nil {
fmt.Println("unquote tag failed, %v", err)
os.Exit(1)
}
structTag, ok := reflect.StructTag(tag).Lookup("xorm")
if ok {
tag = structTag
} else {
tag = ""
}
}
tags := strings.Split(tag, " ")
for _, tagField := range tags {
if tagField == "" {
continue
}
switch tagField {
case "notnull":
col.notnull = "NOT NULL"
case "unique":
col.unique = "UNIQUE KEY"
case "pk":
col.pk = "PRIMARY KEY"
case "created":
col.created = "DEFAULT CURRENT_TIMESTAMP"
case "updated":
col.updated = "ON UPDATE CURRENT_TIMESTAMP"
case "int":
col.t = "INT"
case "bigint":
col.t = "BIGINT"
default:
if strings.HasPrefix(tagField, "'") {
if len(tagField) <= 2 {
return nil, fmt.Errorf("%v used as fieldname but doesn't have value", tagField)
}
if !strings.HasSuffix(tagField, "'") {
return nil, fmt.Errorf("%v used as fieldname but doesn't close quote", tagField)
}
tagField = tagField[1 : len(tagField)-1]
col.name = tagField
} else {
if strings.HasPrefix(tagField, "varchar") {
col.t = tagField
} else {
return nil, fmt.Errorf("unknown tag: %v", tagField)
}
}
}
}
if col.t == "" {
switch ident := field.Type.(type) {
case *ast.Ident:
switch ident.Name {
case "int64":
col.t = "BIGINT"
case "time.Time":
col.t = "DATETIME"
case "string":
col.t = "VARCHAR(255)"
case "int":
col.t = "INT"
default:
return nil, fmt.Errorf("%v type to sqltype failed, unknown: %v", col.name, ident.Name)
}
case *ast.SelectorExpr:
if ident.Sel.Name == "Time" {
if ident, ok := ident.X.(*ast.Ident); ok {
if ident.Name == "time" {
col.t = "DATETIME"
}
}
}
default:
return nil, fmt.Errorf("%v field type unknown: %v", col.name, reflect.TypeOf(field.Type).Elem().Name())
}
}
if col.t == "" {
return nil, fmt.Errorf("%v doesn't have type", col.name)
}
return col, nil
}
- 当指定了列名时使用指定的值,否则使用字段名,对于匿名字段则使用类型名
- 不支持单行定义多个字段
- 当指定了列类型时使用指定的值,否则使用字段类型对应的sql类型, 其中time.Time默认对应DATETIME类型
然后返回所有的列:
func genColumns(table *tableStruct) ([]*Column, error) {
columns := []*Column{}
for _, field := range table.node.Fields.List {
col, err := genColumn(field)
if err != nil {
return nil, err
}
columns = append(columns, col)
}
for _, col := range columns {
if col.name == "Id" && col.t == "BIGINT" {
col.pk = "PRIMARY KEY"
}
}
if 0 == len(columns) {
return nil, fmt.Errorf("no column")
}
return columns, nil
}
这里有如下要求:
- 一个表至少要有一个列
- 如果有列名为Id且类型为BIGINT则默认作为主键
4 生成sql语句
在获取了表名以及列属性列表之后,生成sql语句:
func genSql(table *tableStruct) (string, error) {
columns, err := genColumns(table)
if err != nil {
return "", err
}
// 开始生成sql
builder := strings.Builder{}
builder.WriteString("CREATE TABLE `")
builder.WriteString(table.tableName)
builder.WriteString("` (")
for i, col := range columns {
builder.WriteString("\n\t`")
builder.WriteString(col.name)
builder.WriteString("` ")
builder.WriteString(col.t)
if col.notnull != "" {
builder.WriteString(" NOT NULL")
}
if col.created != "" {
builder.WriteString(" " + col.created)
}
if col.updated != "" {
builder.WriteString(" " + col.updated)
}
if i != len(columns)-1 {
builder.WriteString(",")
}
}
for _, col := range columns {
if col.pk != "" {
builder.WriteString(",\n\t")
builder.WriteString("PRIMARY KEY(`")
builder.WriteString(col.name)
builder.WriteString("`)")
}
if col.unique != "" {
builder.WriteString(",\n\t")
builder.WriteString("UNIQUE KEY `")
builder.WriteString(col.name)
builder.WriteString("` (`")
builder.WriteString(col.name)
builder.WriteString("`)")
}
}
builder.WriteString("\n);")
return builder.String(), nil
}
总结
可以看到,通过goast我们可以提取非常详细的源码信息,但是相应的,这种操作具有相当的复杂度,并且不如反射能得到更多的运行时信息。