264 lines
6.1 KiB
Go
264 lines
6.1 KiB
Go
|
package main
|
|||
|
|
|||
|
import (
|
|||
|
"errors"
|
|||
|
"flag"
|
|||
|
"fmt"
|
|||
|
"io"
|
|||
|
"io/ioutil"
|
|||
|
"os"
|
|||
|
"regexp"
|
|||
|
"strings"
|
|||
|
"text/template"
|
|||
|
|
|||
|
"github.com/sirupsen/logrus"
|
|||
|
|
|||
|
"github.com/xwb1989/sqlparser"
|
|||
|
|
|||
|
"github.com/ehlxr/go-utils/utils/generator/tmpl"
|
|||
|
"github.com/ehlxr/go-utils/utils/log"
|
|||
|
)
|
|||
|
|
|||
|
// 定义 sqlData 数据结构。
|
|||
|
type SqlData struct {
|
|||
|
ClassName string
|
|||
|
ClassNameLower string
|
|||
|
PackageName string
|
|||
|
TableName string
|
|||
|
ColumnList []SqlColumn
|
|||
|
}
|
|||
|
|
|||
|
// 表属性
|
|||
|
type SqlColumn struct {
|
|||
|
Name string // 英文小写名称
|
|||
|
CnName string // 输入的中文名称
|
|||
|
Type string // 类型
|
|||
|
GoName string // Golang 名称
|
|||
|
GoType string // Golang 类型
|
|||
|
IsNumber bool // 判断是否数字
|
|||
|
ColName string
|
|||
|
ColType string
|
|||
|
NameUpper string
|
|||
|
IsLast bool
|
|||
|
IsId bool // 判断是否是 Id
|
|||
|
IsDateTime bool // 判断是否是时间类型
|
|||
|
}
|
|||
|
|
|||
|
var (
|
|||
|
// Options
|
|||
|
flgSql = flag.String("sql", "", "path to sql file or '-' to read from stdin")
|
|||
|
flagClz = flag.String("clz", "", "ClassName")
|
|||
|
flagPkg = flag.String("pkg", "", "PackageName")
|
|||
|
)
|
|||
|
|
|||
|
func main() {
|
|||
|
log.SetLogLevel(logrus.InfoLevel)
|
|||
|
flag.Parse()
|
|||
|
|
|||
|
if *flgSql == "" {
|
|||
|
log.Fatal("sql is nil")
|
|||
|
}
|
|||
|
if *flagClz == "" {
|
|||
|
log.Fatal("ClassName is nil")
|
|||
|
}
|
|||
|
if *flagPkg == "" {
|
|||
|
log.Fatal("PackageName is nil")
|
|||
|
}
|
|||
|
|
|||
|
loadSql, err := loadData(*flgSql)
|
|||
|
if err != nil {
|
|||
|
log.Fatal(err)
|
|||
|
}
|
|||
|
|
|||
|
sql := string(loadSql)
|
|||
|
log.Debugf("input sql %s", sql)
|
|||
|
|
|||
|
// 规则引擎的要求,) 后面的不能解析。
|
|||
|
last := strings.LastIndex(sql, ")")
|
|||
|
if last > 0 {
|
|||
|
sql = sql[:last+1]
|
|||
|
}
|
|||
|
//AUTO_INCREMENT 不解析,只解析标准的 sql
|
|||
|
sql = strings.Replace(sql, "AUTO_INCREMENT", "", -1)
|
|||
|
sql = strings.Replace(sql, "auto_increment", "", -1)
|
|||
|
log.Debugf("new sql %s", sql)
|
|||
|
|
|||
|
sqlNode, err := sqlparser.ParseStrictDDL(sql)
|
|||
|
if err != nil {
|
|||
|
log.Fatal(err)
|
|||
|
}
|
|||
|
|
|||
|
sqlDataTmp, err := getSqlData(sqlNode)
|
|||
|
if err != nil {
|
|||
|
log.Fatal(err)
|
|||
|
}
|
|||
|
|
|||
|
log.Debugf("sql data: %v", sqlDataTmp)
|
|||
|
for i, v := range sqlDataTmp.ColumnList {
|
|||
|
log.Debugf("index: %d, value: %v", i, v)
|
|||
|
}
|
|||
|
|
|||
|
// path, err := filepath.Abs("")
|
|||
|
// if err != nil {
|
|||
|
// log.Fatal(err)
|
|||
|
// }
|
|||
|
// pojo := filepath.Join(path, "tmpl/java/pojo.tmpl")
|
|||
|
// log.Info(pojo)
|
|||
|
|
|||
|
data := make(map[string]interface{})
|
|||
|
data["SqlData"] = sqlDataTmp
|
|||
|
|
|||
|
t := template.New("javapojo")
|
|||
|
t = template.Must(t.Parse(tmpl.Pojo))
|
|||
|
err = t.Execute(os.Stdout, data)
|
|||
|
if err != nil {
|
|||
|
log.Fatal(err)
|
|||
|
}
|
|||
|
|
|||
|
t = template.New("javacontroller")
|
|||
|
t = template.Must(t.Parse(tmpl.Controller))
|
|||
|
err = t.Execute(os.Stdout, data)
|
|||
|
if err != nil {
|
|||
|
log.Fatal(err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// Helper func: Read input from specified file or stdin
|
|||
|
func loadData(p string) ([]byte, error) {
|
|||
|
if p == "" {
|
|||
|
return nil, fmt.Errorf("No path specified")
|
|||
|
}
|
|||
|
|
|||
|
var rdr io.Reader
|
|||
|
if p == "-" {
|
|||
|
rdr = os.Stdin
|
|||
|
} else if p == "+" {
|
|||
|
return []byte("{}"), nil
|
|||
|
} else {
|
|||
|
if f, err := os.Open(p); err == nil {
|
|||
|
rdr = f
|
|||
|
defer f.Close()
|
|||
|
} else {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
}
|
|||
|
return ioutil.ReadAll(rdr)
|
|||
|
}
|
|||
|
|
|||
|
// 解析 sql 并返回 sqlData 数据
|
|||
|
func getSqlData(sqlNode sqlparser.SQLNode) (*SqlData, error) {
|
|||
|
node, ok := sqlNode.(*sqlparser.DDL)
|
|||
|
if !ok {
|
|||
|
return nil, errors.New("不是标准的创建 sql 语句")
|
|||
|
}
|
|||
|
|
|||
|
//返回首字母小写
|
|||
|
first := string((*flagClz)[0])
|
|||
|
clzNameLower := strings.ToLower(first) + (*flagClz)[1:]
|
|||
|
|
|||
|
sqlData := &SqlData{
|
|||
|
PackageName: *flagPkg,
|
|||
|
ClassNameLower: clzNameLower,
|
|||
|
ClassName: *flagClz,
|
|||
|
TableName: node.Table.Name.String(),
|
|||
|
ColumnList: getAllColumn(node)}
|
|||
|
|
|||
|
return sqlData, nil
|
|||
|
}
|
|||
|
|
|||
|
// 返回属性数组
|
|||
|
func getAllColumn(node *sqlparser.DDL) []SqlColumn {
|
|||
|
columnList := []SqlColumn{}
|
|||
|
for index, col := range node.TableSpec.Columns {
|
|||
|
colName := col.Name.String()
|
|||
|
colType := col.Type.Type
|
|||
|
|
|||
|
tmpColumn := SqlColumn{ColName: colName, ColType: colType}
|
|||
|
// 设置首字母大写,驼峰命名
|
|||
|
tmpColumn.Name, tmpColumn.GoName = getNameByTitle(colName)
|
|||
|
// 设置类型
|
|||
|
tmpColumn.Type, tmpColumn.GoType = getType(colType)
|
|||
|
// 返回首字母大写
|
|||
|
first := string(tmpColumn.Name[0])
|
|||
|
nameUpper := strings.ToUpper(first) + tmpColumn.Name[1:]
|
|||
|
tmpColumn.NameUpper = nameUpper
|
|||
|
// 判断是否是最后一个数据
|
|||
|
if index == (len(node.TableSpec.Columns) - 1) {
|
|||
|
tmpColumn.IsLast = true
|
|||
|
} else {
|
|||
|
tmpColumn.IsLast = false
|
|||
|
}
|
|||
|
//判断数据是否 == Id
|
|||
|
if strings.ToLower(colName) == "id" {
|
|||
|
tmpColumn.IsId = true
|
|||
|
} else {
|
|||
|
tmpColumn.IsId = false
|
|||
|
}
|
|||
|
// 判断数据是否 DateTime 类型
|
|||
|
if strings.ToLower(colType) == "datetime" {
|
|||
|
tmpColumn.IsDateTime = true
|
|||
|
} else {
|
|||
|
tmpColumn.IsDateTime = false
|
|||
|
}
|
|||
|
// 添加数据
|
|||
|
columnList = append(columnList, tmpColumn)
|
|||
|
}
|
|||
|
return columnList
|
|||
|
}
|
|||
|
|
|||
|
// 获取 title 拆分的名称。java的 和 golang 的两个名称
|
|||
|
func getNameByTitle(colName string) (string, string) {
|
|||
|
if colName == "" {
|
|||
|
return "", ""
|
|||
|
}
|
|||
|
tmp := strings.Replace(colName, "_", " ", -1)
|
|||
|
tmp = strings.Title(tmp) //title 支持空格分开的都title。
|
|||
|
tmp = strings.Replace(tmp, " ", "", -1)
|
|||
|
//返回首字母小写
|
|||
|
first := string(tmp[0])
|
|||
|
tmp2 := strings.ToLower(first) + tmp[1:]
|
|||
|
return tmp2, tmp
|
|||
|
}
|
|||
|
|
|||
|
// 获取类型,java 和 golang 的两个类型
|
|||
|
func getType(typeName string) (string, string) {
|
|||
|
if typeName == "" {
|
|||
|
return "", ""
|
|||
|
}
|
|||
|
var stringReg = regexp.MustCompile(`varchar|char`) // Has digit(s)
|
|||
|
var dateTimeReg = regexp.MustCompile(`datetime|date`) // Has digit(s)
|
|||
|
var longReg = regexp.MustCompile(`bigint|long`) // Has digit(s)
|
|||
|
var integerReg = regexp.MustCompile(`integer|int`) // Has digit(s)
|
|||
|
var floatReg = regexp.MustCompile(`float`) // Has digit(s)
|
|||
|
var doubleReg = regexp.MustCompile(`float`) // Has digit(s)
|
|||
|
|
|||
|
var tmp, goTmp string
|
|||
|
switch {
|
|||
|
case stringReg.MatchString(typeName):
|
|||
|
tmp = "String"
|
|||
|
goTmp = "string"
|
|||
|
break
|
|||
|
case dateTimeReg.MatchString(typeName):
|
|||
|
tmp = "Date"
|
|||
|
goTmp = "time.Time"
|
|||
|
break
|
|||
|
case longReg.MatchString(typeName):
|
|||
|
tmp = "Long"
|
|||
|
goTmp = "int64"
|
|||
|
break
|
|||
|
case integerReg.MatchString(typeName):
|
|||
|
tmp = "Integer"
|
|||
|
goTmp = "int32"
|
|||
|
break
|
|||
|
case floatReg.MatchString(typeName):
|
|||
|
tmp = "Float"
|
|||
|
goTmp = "float64"
|
|||
|
break
|
|||
|
case doubleReg.MatchString(typeName):
|
|||
|
tmp = "Double"
|
|||
|
goTmp = "float64"
|
|||
|
break
|
|||
|
}
|
|||
|
return tmp, goTmp
|
|||
|
}
|