go-utils/utils/generator/main.go
2018-02-12 09:12:58 +08:00

262 lines
6.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
"strings"
"text/template"
"github.com/sirupsen/logrus"
"github.com/xwb1989/sqlparser"
"github.com/ehlxr/go-utils/utils/generator/tmpl"
"github.com/ehlxr/go-utils/utils/log"
)
// 定义 sqlData 数据结构。
type SqlData struct {
ClassName string
ClassNameLower string
PackageName string
TableName string
ColumnList []SqlColumn
}
// 表属性
type SqlColumn struct {
Name string // 英文小写名称
CnName string // 输入的中文名称
Type string // 类型
GoName string // Golang 名称
GoType string // Golang 类型
IsNumber bool // 判断是否数字
ColName string
ColType string
NameUpper string
IsLast bool
IsId bool // 判断是否是 Id
IsDateTime bool // 判断是否是时间类型
}
var (
// Options
flgSql = flag.String("sql", "", "path to sql file or '-' to read from stdin")
flagClz = flag.String("clz", "", "ClassName")
flagPkg = flag.String("pkg", "", "PackageName")
)
func main() {
log.SetLogLevel(logrus.InfoLevel)
flag.Parse()
if *flgSql == "" {
log.Fatal("sql is nil")
}
if *flagClz == "" {
log.Fatal("ClassName is nil")
}
if *flagPkg == "" {
log.Fatal("PackageName is nil")
}
loadSql, err := loadData(*flgSql)
if err != nil {
log.Fatal(err)
}
sql := string(loadSql)
log.Debugf("input sql %s", sql)
// 规则引擎的要求,) 后面的不能解析。
last := strings.LastIndex(sql, ")")
if last > 0 {
sql = sql[:last+1]
}
//AUTO_INCREMENT 不解析,只解析标准的 sql
sql = strings.Replace(sql, "AUTO_INCREMENT", "", -1)
sql = strings.Replace(sql, "auto_increment", "", -1)
log.Debugf("new sql %s", sql)
sqlNode, err := sqlparser.ParseStrictDDL(sql)
if err != nil {
log.Fatal(err)
}
sqlDataTmp, err := getSqlData(sqlNode)
if err != nil {
log.Fatal(err)
}
log.Debugf("sql data: %v", sqlDataTmp)
for i, v := range sqlDataTmp.ColumnList {
log.Debugf("index: %d, value: %v", i, v)
}
// path, err := filepath.Abs("")
// if err != nil {
// log.Fatal(err)
// }
// pojo := filepath.Join(path, "tmpl/java/pojo.tmpl")
// log.Info(pojo)
data := make(map[string]interface{})
data["SqlData"] = sqlDataTmp
t := template.New("javapojo")
t = template.Must(t.Parse(tmpl.Pojo))
err = t.Execute(os.Stdout, data)
if err != nil {
log.Fatal(err)
}
t = template.New("javacontroller")
t = template.Must(t.Parse(tmpl.Controller))
err = t.Execute(os.Stdout, data)
if err != nil {
log.Fatal(err)
}
}
// Helper func: Read input from specified file or stdin
func loadData(p string) ([]byte, error) {
if p == "" {
return nil, fmt.Errorf("No path specified")
}
var rdr io.Reader
if p == "-" {
rdr = os.Stdin
} else {
if f, err := os.Open(p); err == nil {
rdr = f
defer f.Close()
} else {
return nil, err
}
}
return ioutil.ReadAll(rdr)
}
// 解析 sql 并返回 sqlData 数据
func getSqlData(sqlNode sqlparser.SQLNode) (*SqlData, error) {
node, ok := sqlNode.(*sqlparser.DDL)
if !ok {
return nil, errors.New("不是标准的创建 sql 语句")
}
//返回首字母小写
first := string((*flagClz)[0])
clzNameLower := strings.ToLower(first) + (*flagClz)[1:]
sqlData := &SqlData{
PackageName: *flagPkg,
ClassNameLower: clzNameLower,
ClassName: *flagClz,
TableName: node.Table.Name.String(),
ColumnList: getAllColumn(node)}
return sqlData, nil
}
// 返回属性数组
func getAllColumn(node *sqlparser.DDL) []SqlColumn {
columnList := []SqlColumn{}
for index, col := range node.TableSpec.Columns {
colName := col.Name.String()
colType := col.Type.Type
tmpColumn := SqlColumn{ColName: colName, ColType: colType}
// 设置首字母大写,驼峰命名
tmpColumn.Name, tmpColumn.GoName = getNameByTitle(colName)
// 设置类型
tmpColumn.Type, tmpColumn.GoType = getType(colType)
// 返回首字母大写
first := string(tmpColumn.Name[0])
nameUpper := strings.ToUpper(first) + tmpColumn.Name[1:]
tmpColumn.NameUpper = nameUpper
// 判断是否是最后一个数据
if index == (len(node.TableSpec.Columns) - 1) {
tmpColumn.IsLast = true
} else {
tmpColumn.IsLast = false
}
//判断数据是否 == Id
if strings.ToLower(colName) == "id" {
tmpColumn.IsId = true
} else {
tmpColumn.IsId = false
}
// 判断数据是否 DateTime 类型
if strings.ToLower(colType) == "datetime" {
tmpColumn.IsDateTime = true
} else {
tmpColumn.IsDateTime = false
}
// 添加数据
columnList = append(columnList, tmpColumn)
}
return columnList
}
// 获取 title 拆分的名称。java的 和 golang 的两个名称
func getNameByTitle(colName string) (string, string) {
if colName == "" {
return "", ""
}
tmp := strings.Replace(colName, "_", " ", -1)
tmp = strings.Title(tmp) //title 支持空格分开的都title。
tmp = strings.Replace(tmp, " ", "", -1)
//返回首字母小写
first := string(tmp[0])
tmp2 := strings.ToLower(first) + tmp[1:]
return tmp2, tmp
}
// 获取类型java 和 golang 的两个类型
func getType(typeName string) (string, string) {
if typeName == "" {
return "", ""
}
var stringReg = regexp.MustCompile(`varchar|char`) // Has digit(s)
var dateTimeReg = regexp.MustCompile(`datetime|date`) // Has digit(s)
var longReg = regexp.MustCompile(`bigint|long`) // Has digit(s)
var integerReg = regexp.MustCompile(`integer|int`) // Has digit(s)
var floatReg = regexp.MustCompile(`float`) // Has digit(s)
var doubleReg = regexp.MustCompile(`float`) // Has digit(s)
var tmp, goTmp string
switch {
case stringReg.MatchString(typeName):
tmp = "String"
goTmp = "string"
break
case dateTimeReg.MatchString(typeName):
tmp = "Date"
goTmp = "time.Time"
break
case longReg.MatchString(typeName):
tmp = "Long"
goTmp = "int64"
break
case integerReg.MatchString(typeName):
tmp = "Integer"
goTmp = "int32"
break
case floatReg.MatchString(typeName):
tmp = "Float"
goTmp = "float64"
break
case doubleReg.MatchString(typeName):
tmp = "Double"
goTmp = "float64"
break
}
return tmp, goTmp
}