diff --git a/glide.lock b/glide.lock index cac6d94..d4d76a2 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: c0f781010f60ae6afa49239d110b06ba51e08cf5d3691d4fe99b6c8990c3f78f -updated: 2018-01-24T16:19:54.509676+08:00 +hash: 834ddc9f3063d1e2c3476d57dab0adc2af56a14a2515327d9758a4d4dd677653 +updated: 2018-01-30T17:25:41.80486+08:00 imports: - name: github.com/dgrijalva/jwt-go version: dbeaa9332f19a944acb5736b4456cfcc02140e29 @@ -13,12 +13,14 @@ imports: version: d682213848ed68c0a260ca37d6dd5ace8423f5ba - name: github.com/x-cray/logrus-prefixed-formatter version: bb2702d423886830dee131692131d35648c382e2 +- name: github.com/xwb1989/sqlparser + version: da747e0c62c4e145ebd7484cf69f92c0dd192305 - name: golang.org/x/crypto - version: 3d37316aaa6bd9929127ac9a527abf408178ea7b + version: 1875d0a70c90e57f11972aefd42276df65e895b9 subpackages: - ssh/terminal - name: golang.org/x/sys - version: af50095a40f9041b3b38960738837185c26e9419 + version: 3dbebcf8efb6a5011a60c2b4591c1022a759af8a subpackages: - unix - windows diff --git a/glide.yaml b/glide.yaml index eff8cc7..1cb9578 100644 --- a/glide.yaml +++ b/glide.yaml @@ -6,3 +6,5 @@ import: version: ^1.0.4 - package: github.com/x-cray/logrus-prefixed-formatter version: ^0.5.2 +- package: github.com/xwb1989/sqlparser + branch: master \ No newline at end of file diff --git a/utils/generator/generator b/utils/generator/generator new file mode 100755 index 0000000..3fa019d Binary files /dev/null and b/utils/generator/generator differ diff --git a/utils/generator/main.go b/utils/generator/main.go new file mode 100644 index 0000000..55cd63a --- /dev/null +++ b/utils/generator/main.go @@ -0,0 +1,263 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "regexp" + "strings" + "text/template" + + "github.com/sirupsen/logrus" + + "github.com/xwb1989/sqlparser" + + "github.com/ehlxr/go-utils/utils/generator/tmpl" + "github.com/ehlxr/go-utils/utils/log" +) + +// 定义 sqlData 数据结构。 +type SqlData struct { + ClassName string + ClassNameLower string + PackageName string + TableName string + ColumnList []SqlColumn +} + +// 表属性 +type SqlColumn struct { + Name string // 英文小写名称 + CnName string // 输入的中文名称 + Type string // 类型 + GoName string // Golang 名称 + GoType string // Golang 类型 + IsNumber bool // 判断是否数字 + ColName string + ColType string + NameUpper string + IsLast bool + IsId bool // 判断是否是 Id + IsDateTime bool // 判断是否是时间类型 +} + +var ( + // Options + flgSql = flag.String("sql", "", "path to sql file or '-' to read from stdin") + flagClz = flag.String("clz", "", "ClassName") + flagPkg = flag.String("pkg", "", "PackageName") +) + +func main() { + log.SetLogLevel(logrus.InfoLevel) + flag.Parse() + + if *flgSql == "" { + log.Fatal("sql is nil") + } + if *flagClz == "" { + log.Fatal("ClassName is nil") + } + if *flagPkg == "" { + log.Fatal("PackageName is nil") + } + + loadSql, err := loadData(*flgSql) + if err != nil { + log.Fatal(err) + } + + sql := string(loadSql) + log.Debugf("input sql %s", sql) + + // 规则引擎的要求,) 后面的不能解析。 + last := strings.LastIndex(sql, ")") + if last > 0 { + sql = sql[:last+1] + } + //AUTO_INCREMENT 不解析,只解析标准的 sql + sql = strings.Replace(sql, "AUTO_INCREMENT", "", -1) + sql = strings.Replace(sql, "auto_increment", "", -1) + log.Debugf("new sql %s", sql) + + sqlNode, err := sqlparser.ParseStrictDDL(sql) + if err != nil { + log.Fatal(err) + } + + sqlDataTmp, err := getSqlData(sqlNode) + if err != nil { + log.Fatal(err) + } + + log.Debugf("sql data: %v", sqlDataTmp) + for i, v := range sqlDataTmp.ColumnList { + log.Debugf("index: %d, value: %v", i, v) + } + + // path, err := filepath.Abs("") + // if err != nil { + // log.Fatal(err) + // } + // pojo := filepath.Join(path, "tmpl/java/pojo.tmpl") + // log.Info(pojo) + + data := make(map[string]interface{}) + data["SqlData"] = sqlDataTmp + + t := template.New("javapojo") + t = template.Must(t.Parse(tmpl.Pojo)) + err = t.Execute(os.Stdout, data) + if err != nil { + log.Fatal(err) + } + + t = template.New("javacontroller") + t = template.Must(t.Parse(tmpl.Controller)) + err = t.Execute(os.Stdout, data) + if err != nil { + log.Fatal(err) + } +} + +// Helper func: Read input from specified file or stdin +func loadData(p string) ([]byte, error) { + if p == "" { + return nil, fmt.Errorf("No path specified") + } + + var rdr io.Reader + if p == "-" { + rdr = os.Stdin + } else if p == "+" { + return []byte("{}"), nil + } else { + if f, err := os.Open(p); err == nil { + rdr = f + defer f.Close() + } else { + return nil, err + } + } + return ioutil.ReadAll(rdr) +} + +// 解析 sql 并返回 sqlData 数据 +func getSqlData(sqlNode sqlparser.SQLNode) (*SqlData, error) { + node, ok := sqlNode.(*sqlparser.DDL) + if !ok { + return nil, errors.New("不是标准的创建 sql 语句") + } + + //返回首字母小写 + first := string((*flagClz)[0]) + clzNameLower := strings.ToLower(first) + (*flagClz)[1:] + + sqlData := &SqlData{ + PackageName: *flagPkg, + ClassNameLower: clzNameLower, + ClassName: *flagClz, + TableName: node.Table.Name.String(), + ColumnList: getAllColumn(node)} + + return sqlData, nil +} + +// 返回属性数组 +func getAllColumn(node *sqlparser.DDL) []SqlColumn { + columnList := []SqlColumn{} + for index, col := range node.TableSpec.Columns { + colName := col.Name.String() + colType := col.Type.Type + + tmpColumn := SqlColumn{ColName: colName, ColType: colType} + // 设置首字母大写,驼峰命名 + tmpColumn.Name, tmpColumn.GoName = getNameByTitle(colName) + // 设置类型 + tmpColumn.Type, tmpColumn.GoType = getType(colType) + // 返回首字母大写 + first := string(tmpColumn.Name[0]) + nameUpper := strings.ToUpper(first) + tmpColumn.Name[1:] + tmpColumn.NameUpper = nameUpper + // 判断是否是最后一个数据 + if index == (len(node.TableSpec.Columns) - 1) { + tmpColumn.IsLast = true + } else { + tmpColumn.IsLast = false + } + //判断数据是否 == Id + if strings.ToLower(colName) == "id" { + tmpColumn.IsId = true + } else { + tmpColumn.IsId = false + } + // 判断数据是否 DateTime 类型 + if strings.ToLower(colType) == "datetime" { + tmpColumn.IsDateTime = true + } else { + tmpColumn.IsDateTime = false + } + // 添加数据 + columnList = append(columnList, tmpColumn) + } + return columnList +} + +// 获取 title 拆分的名称。java的 和 golang 的两个名称 +func getNameByTitle(colName string) (string, string) { + if colName == "" { + return "", "" + } + tmp := strings.Replace(colName, "_", " ", -1) + tmp = strings.Title(tmp) //title 支持空格分开的都title。 + tmp = strings.Replace(tmp, " ", "", -1) + //返回首字母小写 + first := string(tmp[0]) + tmp2 := strings.ToLower(first) + tmp[1:] + return tmp2, tmp +} + +// 获取类型,java 和 golang 的两个类型 +func getType(typeName string) (string, string) { + if typeName == "" { + return "", "" + } + var stringReg = regexp.MustCompile(`varchar|char`) // Has digit(s) + var dateTimeReg = regexp.MustCompile(`datetime|date`) // Has digit(s) + var longReg = regexp.MustCompile(`bigint|long`) // Has digit(s) + var integerReg = regexp.MustCompile(`integer|int`) // Has digit(s) + var floatReg = regexp.MustCompile(`float`) // Has digit(s) + var doubleReg = regexp.MustCompile(`float`) // Has digit(s) + + var tmp, goTmp string + switch { + case stringReg.MatchString(typeName): + tmp = "String" + goTmp = "string" + break + case dateTimeReg.MatchString(typeName): + tmp = "Date" + goTmp = "time.Time" + break + case longReg.MatchString(typeName): + tmp = "Long" + goTmp = "int64" + break + case integerReg.MatchString(typeName): + tmp = "Integer" + goTmp = "int32" + break + case floatReg.MatchString(typeName): + tmp = "Float" + goTmp = "float64" + break + case doubleReg.MatchString(typeName): + tmp = "Double" + goTmp = "float64" + break + } + return tmp, goTmp +} diff --git a/utils/generator/sql.sql b/utils/generator/sql.sql new file mode 100644 index 0000000..79e11dc --- /dev/null +++ b/utils/generator/sql.sql @@ -0,0 +1 @@ +CREATE TABLE `user_info` (`id` bigint(20) NOT NULL PRIMARY KEY AUTO_INCREMENT,`name` varchar(200) NOT NULL,`password` varchar(200) NOT NULL,`status` tinyint(1) NOT NULL ,`type` tinyint(1) NOT NULL ,`create_time` datetime NOT NULL,`update_time` datetime NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 ; \ No newline at end of file diff --git a/utils/generator/tmpl/javacontroller.go b/utils/generator/tmpl/javacontroller.go new file mode 100644 index 0000000..d4f0f25 --- /dev/null +++ b/utils/generator/tmpl/javacontroller.go @@ -0,0 +1,121 @@ +package tmpl + +var Controller = ` +package {{.SqlData.PackageName}}.web; + +import com.google.common.base.Strings; +import {{.SqlData.PackageName}}.domain.{{.SqlData.ClassName}}; +import {{.SqlData.PackageName}}.common.page.Page; +import {{.SqlData.PackageName}}.common.Constants; +import {{.SqlData.PackageName}}.service.{{.SqlData.ClassName}}Service; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Controller; +import org.springframework.ui.Model; +import org.springframework.web.bind.annotation.*; +import javax.servlet.http.HttpServletResponse; +import java.util.HashMap; +import java.util.Map; + +@Controller +@RequestMapping(value = "/admin/{{.SqlData.ClassNameLower}}", method = {RequestMethod.GET, RequestMethod.POST}) +public class {{.SqlData.ClassName}}Controller { + + private static final Logger logger = LoggerFactory.getLogger({{.SqlData.ClassName}}Controller.class); + + @Autowired + private {{.SqlData.ClassName}}Service {{.SqlData.ClassNameLower}}Service; + + private static Integer pageSize = 15; + + @RequestMapping(value = "/edit", method = {RequestMethod.GET}) + public String edit(@RequestParam(value = "id", defaultValue = "0") Long id, + Model view) { + try { + {{.SqlData.ClassName}} {{.SqlData.ClassNameLower}} = null; + if (id != null && id.longValue() > 0) { + {{.SqlData.ClassNameLower}} = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}ById(id); + }else{ + {{.SqlData.ClassNameLower}} = new {{.SqlData.ClassName}}(); + } + view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}}); + + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + return "admin/{{.SqlData.ClassNameLower}}/edit"; + } + + @RequestMapping(value = "/view", method = {RequestMethod.GET}) + public String view(@RequestParam(value = "id", defaultValue = "0") Long id, + Model view) { + try { + {{.SqlData.ClassName}} {{.SqlData.ClassNameLower}} = null; + if (id != null && id.longValue() > 0) { + {{.SqlData.ClassNameLower}} = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}ById(id); + }else{ + {{.SqlData.ClassNameLower}} = new {{.SqlData.ClassName}}(); + } + view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}}); + + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + return "admin/{{.SqlData.ClassNameLower}}/view"; + } + + @RequestMapping(value = "/delete", method = {RequestMethod.DELETE}) + @ResponseBody + public String delete(@RequestParam(value = "id", defaultValue = "0") Long id, + Model view) { + try { + + long rows = {{.SqlData.ClassNameLower}}Service.delete{{.SqlData.ClassName}}(id); + + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + return String.format(Constants.WEB_IFRAME_SCRIPT, "删除成功!"); + } + + @RequestMapping(value = "/save", method = {RequestMethod.POST}) + @ResponseBody + public String save({{.SqlData.ClassName}} {{.SqlData.ClassNameLower}}, + Model view) { + try { + + long rows = {{.SqlData.ClassNameLower}}Service.save{{.SqlData.ClassName}}({{.SqlData.ClassNameLower}}); + view.addAttribute("{{.SqlData.ClassNameLower}}", {{.SqlData.ClassNameLower}}); + + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + return String.format(Constants.WEB_IFRAME_SCRIPT, "保存成功!"); + } + + @RequestMapping(value = "/list", method = {RequestMethod.GET, RequestMethod.POST}) + public String list(@RequestParam(value = "page", defaultValue = "0") int page, + @RequestParam(value = "id", required = false) Long id, + Model view) { + try { + //查询 + Map search = new HashMap(); + if (id != null) { + search.put("id", id); + } + + Page<{{.SqlData.ClassName}}> pageData = {{.SqlData.ClassNameLower}}Service.query{{.SqlData.ClassName}}Page(page, pageSize,search); + //放入page对象。 + view.addAttribute("pageData", pageData); + view.addAttribute("id", id); + + + } catch (Exception e) { + logger.error(e.getMessage(), e); + } + return "admin/{{.SqlData.ClassNameLower}}/list"; + } + +} +` diff --git a/utils/generator/tmpl/javapojo.go b/utils/generator/tmpl/javapojo.go new file mode 100644 index 0000000..73c0d8d --- /dev/null +++ b/utils/generator/tmpl/javapojo.go @@ -0,0 +1,26 @@ +package tmpl + +var Pojo = ` +package {{.SqlData.PackageName}}.domain; + +import java.util.Date; + +/** +* @table {{.SqlData.ClassNameLower}} +*/ +public class {{.SqlData.ClassName}} { +private static final long serialVersionUID = 1L; + {{range $index, $column := .SqlData.ColumnList}} + private {{$column.Type}} {{$column.Name}}; {{if $column.CnName}} // $column.CnName {{end}} + {{end}} + {{range $index, $column := .SqlData.ColumnList}} + public {{$column.Type}} get{{$column.NameUpper}}() { + return this.{{$column.Name}}; + } + + public void set{{$column.NameUpper}}({{$column.Type}} {{$column.Name}}) { + this.{{$column.Name}} = {{$column.Name}}; + }{{end}} + +} +`