264 lines
6.1 KiB
Go
264 lines
6.1 KiB
Go
package main
|
||
|
||
import (
|
||
"errors"
|
||
"flag"
|
||
"fmt"
|
||
"io"
|
||
"io/ioutil"
|
||
"os"
|
||
"regexp"
|
||
"strings"
|
||
"text/template"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
|
||
"github.com/xwb1989/sqlparser"
|
||
|
||
"github.com/ehlxr/go-utils/utils/generator/tmpl"
|
||
"github.com/ehlxr/go-utils/utils/log"
|
||
)
|
||
|
||
// SqlData is the model given to the code templates (main exposes it to them
// under the "SqlData" key): the class and package names taken from the
// command line plus the parsed table name and its columns.
type SqlData struct {
	// ClassName is the class name exactly as passed via -clz;
	// ClassNameLower is the same name with its first letter lower-cased.
	ClassName, ClassNameLower string
	// PackageName comes from -pkg; TableName is the table named in the
	// CREATE TABLE statement.
	PackageName, TableName string
	// ColumnList holds one entry per column, in declaration order.
	ColumnList []SqlColumn
}

// SqlColumn describes a single table column in the forms the templates use.
type SqlColumn struct {
	// Name is the lowerCamelCase (Java-style) name, CnName a Chinese display
	// name (not populated by the code in this file), and Type the Java type.
	Name, CnName, Type string
	// GoName is the UpperCamelCase name and GoType the Go type.
	GoName, GoType string
	// IsNumber is meant to flag numeric columns.
	// NOTE(review): nothing in this file sets it, so it is always false.
	IsNumber bool
	// ColName and ColType are the raw column name and SQL type as parsed.
	ColName, ColType string
	// NameUpper is Name with its first letter upper-cased.
	NameUpper string
	// IsLast marks the final column, IsId the column named "id" (any case),
	// and IsDateTime a column whose SQL type is datetime.
	IsLast, IsId, IsDateTime bool
}
|
||
|
||
// flgSql is the SQL input handed to loadData: a file path, or "-" for stdin.
var flgSql = flag.String("sql", "", "path to sql file or '-' to read from stdin")

// flagClz is the class name used by the generated code.
var flagClz = flag.String("clz", "", "ClassName")

// flagPkg is the package name used by the generated code.
var flagPkg = flag.String("pkg", "", "PackageName")
|
||
|
||
func main() {
|
||
log.SetLogLevel(logrus.InfoLevel)
|
||
flag.Parse()
|
||
|
||
if *flgSql == "" {
|
||
log.Fatal("sql is nil")
|
||
}
|
||
if *flagClz == "" {
|
||
log.Fatal("ClassName is nil")
|
||
}
|
||
if *flagPkg == "" {
|
||
log.Fatal("PackageName is nil")
|
||
}
|
||
|
||
loadSql, err := loadData(*flgSql)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
sql := string(loadSql)
|
||
log.Debugf("input sql %s", sql)
|
||
|
||
// 规则引擎的要求,) 后面的不能解析。
|
||
last := strings.LastIndex(sql, ")")
|
||
if last > 0 {
|
||
sql = sql[:last+1]
|
||
}
|
||
//AUTO_INCREMENT 不解析,只解析标准的 sql
|
||
sql = strings.Replace(sql, "AUTO_INCREMENT", "", -1)
|
||
sql = strings.Replace(sql, "auto_increment", "", -1)
|
||
log.Debugf("new sql %s", sql)
|
||
|
||
sqlNode, err := sqlparser.ParseStrictDDL(sql)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
sqlDataTmp, err := getSqlData(sqlNode)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
log.Debugf("sql data: %v", sqlDataTmp)
|
||
for i, v := range sqlDataTmp.ColumnList {
|
||
log.Debugf("index: %d, value: %v", i, v)
|
||
}
|
||
|
||
// path, err := filepath.Abs("")
|
||
// if err != nil {
|
||
// log.Fatal(err)
|
||
// }
|
||
// pojo := filepath.Join(path, "tmpl/java/pojo.tmpl")
|
||
// log.Info(pojo)
|
||
|
||
data := make(map[string]interface{})
|
||
data["SqlData"] = sqlDataTmp
|
||
|
||
t := template.New("javapojo")
|
||
t = template.Must(t.Parse(tmpl.Pojo))
|
||
err = t.Execute(os.Stdout, data)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
t = template.New("javacontroller")
|
||
t = template.Must(t.Parse(tmpl.Controller))
|
||
err = t.Execute(os.Stdout, data)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|
||
|
||
// loadData returns the raw input named by p.
//
//   - ""  -> error "No path specified"
//   - "+" -> the literal bytes "{}"
//   - "-" -> everything read from standard input
//   - any other value is opened as a file path and read in full
//
// Errors from opening or reading the source are returned unchanged.
func loadData(p string) ([]byte, error) {
	var src io.Reader

	switch p {
	case "":
		return nil, fmt.Errorf("No path specified")
	case "+":
		return []byte("{}"), nil
	case "-":
		src = os.Stdin
	default:
		file, openErr := os.Open(p)
		if openErr != nil {
			return nil, openErr
		}
		// Close only after ReadAll below has consumed the file.
		defer file.Close()
		src = file
	}

	return ioutil.ReadAll(src)
}
|
||
|
||
// 解析 sql 并返回 sqlData 数据
|
||
func getSqlData(sqlNode sqlparser.SQLNode) (*SqlData, error) {
|
||
node, ok := sqlNode.(*sqlparser.DDL)
|
||
if !ok {
|
||
return nil, errors.New("不是标准的创建 sql 语句")
|
||
}
|
||
|
||
//返回首字母小写
|
||
first := string((*flagClz)[0])
|
||
clzNameLower := strings.ToLower(first) + (*flagClz)[1:]
|
||
|
||
sqlData := &SqlData{
|
||
PackageName: *flagPkg,
|
||
ClassNameLower: clzNameLower,
|
||
ClassName: *flagClz,
|
||
TableName: node.Table.Name.String(),
|
||
ColumnList: getAllColumn(node)}
|
||
|
||
return sqlData, nil
|
||
}
|
||
|
||
// 返回属性数组
|
||
func getAllColumn(node *sqlparser.DDL) []SqlColumn {
|
||
columnList := []SqlColumn{}
|
||
for index, col := range node.TableSpec.Columns {
|
||
colName := col.Name.String()
|
||
colType := col.Type.Type
|
||
|
||
tmpColumn := SqlColumn{ColName: colName, ColType: colType}
|
||
// 设置首字母大写,驼峰命名
|
||
tmpColumn.Name, tmpColumn.GoName = getNameByTitle(colName)
|
||
// 设置类型
|
||
tmpColumn.Type, tmpColumn.GoType = getType(colType)
|
||
// 返回首字母大写
|
||
first := string(tmpColumn.Name[0])
|
||
nameUpper := strings.ToUpper(first) + tmpColumn.Name[1:]
|
||
tmpColumn.NameUpper = nameUpper
|
||
// 判断是否是最后一个数据
|
||
if index == (len(node.TableSpec.Columns) - 1) {
|
||
tmpColumn.IsLast = true
|
||
} else {
|
||
tmpColumn.IsLast = false
|
||
}
|
||
//判断数据是否 == Id
|
||
if strings.ToLower(colName) == "id" {
|
||
tmpColumn.IsId = true
|
||
} else {
|
||
tmpColumn.IsId = false
|
||
}
|
||
// 判断数据是否 DateTime 类型
|
||
if strings.ToLower(colType) == "datetime" {
|
||
tmpColumn.IsDateTime = true
|
||
} else {
|
||
tmpColumn.IsDateTime = false
|
||
}
|
||
// 添加数据
|
||
columnList = append(columnList, tmpColumn)
|
||
}
|
||
return columnList
|
||
}
|
||
|
||
// getNameByTitle converts a snake_case column name into its two camel-case
// forms: the Java field name (lowerCamelCase) and the Go name
// (UpperCamelCase). For example, "user_name" -> ("userName", "UserName").
//
// Underscores separate words and spaces are removed. Empty input, or input
// consisting only of underscores/spaces, returns ("", "") instead of
// panicking. The first character is lower-cased as a rune, so a multi-byte
// letter such as 'É' is not split into invalid UTF-8.
func getNameByTitle(colName string) (string, string) {
	spaced := strings.Replace(colName, "_", " ", -1)
	// strings.Title upper-cases the first letter of every space-separated
	// word. It is deprecated since Go 1.18, but its replacement lives in
	// golang.org/x/text, which this file does not depend on.
	goName := strings.Replace(strings.Title(spaced), " ", "", -1)
	if goName == "" {
		return "", ""
	}

	runes := []rune(goName)
	javaName := strings.ToLower(string(runes[0])) + string(runes[1:])
	return javaName, goName
}
|
||
|
||
// Column-type patterns used by getType. They are compiled once at package
// init instead of on every call. (?i) makes matching case-insensitive, so
// "INT" and "int" map alike.
var (
	stringTypeRe   = regexp.MustCompile(`(?i)varchar|char|text`)
	dateTimeTypeRe = regexp.MustCompile(`(?i)datetime|date`)
	longTypeRe     = regexp.MustCompile(`(?i)bigint|long`)
	integerTypeRe  = regexp.MustCompile(`(?i)integer|int`)
	floatTypeRe    = regexp.MustCompile(`(?i)float`)
	doubleTypeRe   = regexp.MustCompile(`(?i)double`)
)

// getType maps a SQL column type to its Java and Go type names.
//
// Patterns are tried in order and the first match wins. This is why
// "bigint" is tested before the plain "int" pattern, and why "datetime"
// still maps to Date. Unrecognized or empty types return ("", "").
func getType(typeName string) (string, string) {
	switch {
	case typeName == "":
		return "", ""
	case stringTypeRe.MatchString(typeName):
		return "String", "string"
	case dateTimeTypeRe.MatchString(typeName):
		return "Date", "time.Time"
	case longTypeRe.MatchString(typeName):
		return "Long", "int64"
	case integerTypeRe.MatchString(typeName):
		return "Integer", "int32"
	case floatTypeRe.MatchString(typeName):
		return "Float", "float64"
	case doubleTypeRe.MatchString(typeName):
		return "Double", "float64"
	default:
		return "", ""
	}
}
|