code generator (unfinished)

This commit is contained in:
ehlxr 2018-01-31 14:59:55 +08:00
parent eeb4d31572
commit bbd5c5b060
7 changed files with 419 additions and 4 deletions

10
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: c0f781010f60ae6afa49239d110b06ba51e08cf5d3691d4fe99b6c8990c3f78f
updated: 2018-01-24T16:19:54.509676+08:00
hash: 834ddc9f3063d1e2c3476d57dab0adc2af56a14a2515327d9758a4d4dd677653
updated: 2018-01-30T17:25:41.80486+08:00
imports:
- name: github.com/dgrijalva/jwt-go
version: dbeaa9332f19a944acb5736b4456cfcc02140e29
@ -13,12 +13,14 @@ imports:
version: d682213848ed68c0a260ca37d6dd5ace8423f5ba
- name: github.com/x-cray/logrus-prefixed-formatter
version: bb2702d423886830dee131692131d35648c382e2
- name: github.com/xwb1989/sqlparser
version: da747e0c62c4e145ebd7484cf69f92c0dd192305
- name: golang.org/x/crypto
version: 3d37316aaa6bd9929127ac9a527abf408178ea7b
version: 1875d0a70c90e57f11972aefd42276df65e895b9
subpackages:
- ssh/terminal
- name: golang.org/x/sys
version: af50095a40f9041b3b38960738837185c26e9419
version: 3dbebcf8efb6a5011a60c2b4591c1022a759af8a
subpackages:
- unix
- windows

View File

@ -6,3 +6,5 @@ import:
version: ^1.0.4
- package: github.com/x-cray/logrus-prefixed-formatter
version: ^0.5.2
- package: github.com/xwb1989/sqlparser
branch: master

BIN
utils/generator/generator Executable file

Binary file not shown.

263
utils/generator/main.go Normal file
View File

@ -0,0 +1,263 @@
package main
import (
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
"strings"
"text/template"
"github.com/sirupsen/logrus"
"github.com/xwb1989/sqlparser"
"github.com/ehlxr/go-utils/utils/generator/tmpl"
"github.com/ehlxr/go-utils/utils/log"
)
// 定义 sqlData 数据结构。
type SqlData struct {
ClassName string
ClassNameLower string
PackageName string
TableName string
ColumnList []SqlColumn
}
// 表属性
type SqlColumn struct {
Name string // 英文小写名称
CnName string // 输入的中文名称
Type string // 类型
GoName string // Golang 名称
GoType string // Golang 类型
IsNumber bool // 判断是否数字
ColName string
ColType string
NameUpper string
IsLast bool
IsId bool // 判断是否是 Id
IsDateTime bool // 判断是否是时间类型
}
var (
// Options
flgSql = flag.String("sql", "", "path to sql file or '-' to read from stdin")
flagClz = flag.String("clz", "", "ClassName")
flagPkg = flag.String("pkg", "", "PackageName")
)
func main() {
log.SetLogLevel(logrus.InfoLevel)
flag.Parse()
if *flgSql == "" {
log.Fatal("sql is nil")
}
if *flagClz == "" {
log.Fatal("ClassName is nil")
}
if *flagPkg == "" {
log.Fatal("PackageName is nil")
}
loadSql, err := loadData(*flgSql)
if err != nil {
log.Fatal(err)
}
sql := string(loadSql)
log.Debugf("input sql %s", sql)
// 规则引擎的要求,) 后面的不能解析。
last := strings.LastIndex(sql, ")")
if last > 0 {
sql = sql[:last+1]
}
//AUTO_INCREMENT 不解析,只解析标准的 sql
sql = strings.Replace(sql, "AUTO_INCREMENT", "", -1)
sql = strings.Replace(sql, "auto_increment", "", -1)
log.Debugf("new sql %s", sql)
sqlNode, err := sqlparser.ParseStrictDDL(sql)
if err != nil {
log.Fatal(err)
}
sqlDataTmp, err := getSqlData(sqlNode)
if err != nil {
log.Fatal(err)
}
log.Debugf("sql data: %v", sqlDataTmp)
for i, v := range sqlDataTmp.ColumnList {
log.Debugf("index: %d, value: %v", i, v)
}
// path, err := filepath.Abs("")
// if err != nil {
// log.Fatal(err)
// }
// pojo := filepath.Join(path, "tmpl/java/pojo.tmpl")
// log.Info(pojo)
data := make(map[string]interface{})
data["SqlData"] = sqlDataTmp
t := template.New("javapojo")
t = template.Must(t.Parse(tmpl.Pojo))
err = t.Execute(os.Stdout, data)
if err != nil {
log.Fatal(err)
}
t = template.New("javacontroller")
t = template.Must(t.Parse(tmpl.Controller))
err = t.Execute(os.Stdout, data)
if err != nil {
log.Fatal(err)
}
}
// Helper func: Read input from specified file or stdin
func loadData(p string) ([]byte, error) {
if p == "" {
return nil, fmt.Errorf("No path specified")
}
var rdr io.Reader
if p == "-" {
rdr = os.Stdin
} else if p == "+" {
return []byte("{}"), nil
} else {
if f, err := os.Open(p); err == nil {
rdr = f
defer f.Close()
} else {
return nil, err
}
}
return ioutil.ReadAll(rdr)
}
// 解析 sql 并返回 sqlData 数据
func getSqlData(sqlNode sqlparser.SQLNode) (*SqlData, error) {
node, ok := sqlNode.(*sqlparser.DDL)
if !ok {
return nil, errors.New("不是标准的创建 sql 语句")
}
//返回首字母小写
first := string((*flagClz)[0])
clzNameLower := strings.ToLower(first) + (*flagClz)[1:]
sqlData := &SqlData{
PackageName: *flagPkg,
ClassNameLower: clzNameLower,
ClassName: *flagClz,
TableName: node.Table.Name.String(),
ColumnList: getAllColumn(node)}
return sqlData, nil
}
// 返回属性数组
func getAllColumn(node *sqlparser.DDL) []SqlColumn {
columnList := []SqlColumn{}
for index, col := range node.TableSpec.Columns {
colName := col.Name.String()
colType := col.Type.Type
tmpColumn := SqlColumn{ColName: colName, ColType: colType}
// 设置首字母大写,驼峰命名
tmpColumn.Name, tmpColumn.GoName = getNameByTitle(colName)
// 设置类型
tmpColumn.Type, tmpColumn.GoType = getType(colType)
// 返回首字母大写
first := string(tmpColumn.Name[0])
nameUpper := strings.ToUpper(first) + tmpColumn.Name[1:]
tmpColumn.NameUpper = nameUpper
// 判断是否是最后一个数据
if index == (len(node.TableSpec.Columns) - 1) {
tmpColumn.IsLast = true
} else {
tmpColumn.IsLast = false
}
//判断数据是否 == Id
if strings.ToLower(colName) == "id" {
tmpColumn.IsId = true
} else {
tmpColumn.IsId = false
}
// 判断数据是否 DateTime 类型
if strings.ToLower(colType) == "datetime" {
tmpColumn.IsDateTime = true
} else {
tmpColumn.IsDateTime = false
}
// 添加数据
columnList = append(columnList, tmpColumn)
}
return columnList
}
// 获取 title 拆分的名称。java的 和 golang 的两个名称
func getNameByTitle(colName string) (string, string) {
if colName == "" {
return "", ""
}
tmp := strings.Replace(colName, "_", " ", -1)
tmp = strings.Title(tmp) //title 支持空格分开的都title。
tmp = strings.Replace(tmp, " ", "", -1)
//返回首字母小写
first := string(tmp[0])
tmp2 := strings.ToLower(first) + tmp[1:]
return tmp2, tmp
}
// 获取类型java 和 golang 的两个类型
func getType(typeName string) (string, string) {
if typeName == "" {
return "", ""
}
var stringReg = regexp.MustCompile(`varchar|char`) // Has digit(s)
var dateTimeReg = regexp.MustCompile(`datetime|date`) // Has digit(s)
var longReg = regexp.MustCompile(`bigint|long`) // Has digit(s)
var integerReg = regexp.MustCompile(`integer|int`) // Has digit(s)
var floatReg = regexp.MustCompile(`float`) // Has digit(s)
var doubleReg = regexp.MustCompile(`float`) // Has digit(s)
var tmp, goTmp string
switch {
case stringReg.MatchString(typeName):
tmp = "String"
goTmp = "string"
break
case dateTimeReg.MatchString(typeName):
tmp = "Date"
goTmp = "time.Time"
break
case longReg.MatchString(typeName):
tmp = "Long"
goTmp = "int64"
break
case integerReg.MatchString(typeName):
tmp = "Integer"
goTmp = "int32"
break
case floatReg.MatchString(typeName):
tmp = "Float"
goTmp = "float64"
break
case doubleReg.MatchString(typeName):
tmp = "Double"
goTmp = "float64"
break
}
return tmp, goTmp
}

1
utils/generator/sql.sql Normal file
View File

@ -0,0 +1 @@
CREATE TABLE `user_info` (`id` bigint(20) NOT NULL PRIMARY KEY AUTO_INCREMENT,`name` varchar(200) NOT NULL,`password` varchar(200) NOT NULL,`status` tinyint(1) NOT NULL ,`type` tinyint(1) NOT NULL ,`create_time` datetime NOT NULL,`update_time` datetime NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 ;

View File

@ -0,0 +1,121 @@
package tmpl
var Controller = `
package {{.SqlData.PackageName}}.web;
import com.google.common.base.Strings;
import {{.SqlData.PackageName}}.domain.{{.SqlData.ClassName}};
import {{.SqlData.PackageName}}.common.page.Page;
import {{.SqlData.PackageName}}.common.Constants;
import {{.SqlData.PackageName}}.service.{{.SqlData.ClassName}}Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.*;
import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Map;
@Controller
@RequestMapping(value = "/admin/{{.SqlData.ClassNameLower}}", method = {RequestMethod.GET, RequestMethod.POST})
public class {{.SqlData.ClassName}}Controller {
private static final Logger logger = LoggerFactory.getLogger({{.SqlData.ClassName}}Controller.class);
@Autowired
private {{.SqlData.ClassName}}Service {{.SqlData.ClassNameLower}}Service;
private static Integer pageSize = 15;
@RequestMapping(value = "/edit", method = {RequestMethod.GET})
public String edit(@RequestParam(value = "id", defaultValue = "0") Long id,
Model view) {
try {
{{.SqlData.ClassName}} {{.SqlData.ClassNameLower}} = null;
if (id != null && id.longValue() > 0) {
{{.SqlData.ClassNameLower}} = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}ById(id);
}else{
{{.SqlData.ClassNameLower}} = new {{.SqlData.ClassName}}();
}
view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}});
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
return "admin/{{.SqlData.ClassNameLower}}/edit";
}
@RequestMapping(value = "/view", method = {RequestMethod.GET})
public String view(@RequestParam(value = "id", defaultValue = "0") Long id,
Model view) {
try {
{{.SqlData.ClassName}} {{.SqlData.ClassNameLower}} = null;
if (id != null && id.longValue() > 0) {
{{.SqlData.ClassNameLower}} = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}ById(id);
}else{
{{.SqlData.ClassNameLower}} = new {{.SqlData.ClassName}}();
}
view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}});
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
return "admin/{{.SqlData.ClassNameLower}}/view";
}
@RequestMapping(value = "/delete", method = {RequestMethod.DELETE})
@ResponseBody
public String delete(@RequestParam(value = "id", defaultValue = "0") Long id,
Model view) {
try {
long rows = {{.SqlData.ClassNameLower}}Service.delete{{.SqlData.ClassName}}(id);
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
return String.format(Constants.WEB_IFRAME_SCRIPT, "删除成功!");
}
@RequestMapping(value = "/save", method = {RequestMethod.POST})
@ResponseBody
public String save({{.SqlData.ClassName}} {{.SqlData.ClassNameLower}},
Model view) {
try {
long rows = {{.SqlData.ClassNameLower}}Service.save{{.SqlData.ClassName}}({{.SqlData.ClassNameLower}});
view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}});
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
return String.format(Constants.WEB_IFRAME_SCRIPT, "保存成功!");
}
@RequestMapping(value = "/list", method = {RequestMethod.GET, RequestMethod.POST})
public String list(@RequestParam(value = "page", defaultValue = "0") int page,
@RequestParam(value = "id", required = false) Long id,
Model view) {
try {
//查询
Map<String, Object> search = new HashMap<String, Object>();
if (id != null) {
search.put("id", id);
}
Page<{{.SqlData.ClassName}}> pageData = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}Page(page, pageSize,search);
//放入page对象。
view.addAttribute("pageData", pageData);
view.addAttribute("id", id);
} catch (Exception e) {
logger.error(e.getMessage(), e);
}
return "admin/{{.SqlData.ClassNameLower}}/list";
}
}
`

View File

@ -0,0 +1,26 @@
package tmpl
var Pojo = `
package {{.SqlData.PackageName}}.domain;
import java.util.Date;
/**
* @table {{.SqlData.ClassNameLower}}
*/
public class {{.SqlData.ClassName}} {
private static final long serialVersionUID = 1L;
{{range $index, $column := .SqlData.ColumnList}}
private {{$column.Type}} {{$column.Name}}; {{if $column.CnName}} // $column.CnName {{end}}
{{end}}
{{range $index, $column := .SqlData.ColumnList}}
public {{$column.Type}} get{{$column.NameUpper}}() {
return this.{{$column.Name}};
}
public void set{{$column.NameUpper}}({{$column.Type}} {{$column.Name}}) {
this.{{$column.Name}} = {{$column.Name}};
}{{end}}
}
`