diff --git a/example/dump/main.go b/example/dump/main.go index a3bb22a..4720f5a 100644 --- a/example/dump/main.go +++ b/example/dump/main.go @@ -1,19 +1,27 @@ package main import ( + "database/sql" + "log" "os" - "github.com/jarvanstack/mysqldump" + "github.com/notyusta/mysqldump" ) func main() { + dsn := "root:rootpasswd@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=true&loc=Asia%2FJakarta" + db, err := sql.Open("mysql", dsn) + if err != nil { + log.Printf("[error] %v \n", err) + return + } - dsn := "root:rootpasswd@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" - + defer db.Close() f, _ := os.Create("dump.sql") _ = mysqldump.Dump( - dsn, // DSN + db, // DSN + "test", mysqldump.WithDropTable(), // Option: Delete table before create (Default: Not delete table) mysqldump.WithData(), // Option: Dump Data (Default: Only dump table schema) mysqldump.WithTables("test"), // Option: Dump Tables (Default: All tables) diff --git a/example/source/main.go b/example/source/main.go index c460e1c..2b4e40c 100644 --- a/example/source/main.go +++ b/example/source/main.go @@ -1,18 +1,27 @@ package main import ( + "database/sql" + "log" "os" - "github.com/jarvanstack/mysqldump" + "github.com/notyusta/mysqldump" ) func main() { - dns := "root:rootpasswd@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai" + dsn := "root:rootpasswd@tcp(localhost:3306)/dbname?charset=utf8mb4&parseTime=true&loc=Asia%2FJakarta" + db, err := sql.Open("mysql", dsn) + if err != nil { + log.Printf("[error] %v \n", err) + return + } + f, _ := os.Open("dump.sql") _ = mysqldump.Source( - dns, + db, + "test", f, mysqldump.WithMergeInsert(1000), // Option: Merge insert 1000 (Default: Not merge insert) mysqldump.WithDebug(), // Option: Print execute sql (Default: Not print execute sql) diff --git a/go.mod b/go.mod index fa0a6f3..8a7240b 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/jarvanstack/mysqldump +module github.com/notyusta/mysqldump go 1.18 diff --git a/mysqldump.go b/mysqldump.go index a26d21d..b646639 100644 --- a/mysqldump.go +++ b/mysqldump.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "io" - "log" "os" "strings" "time" @@ -13,10 +12,7 @@ import ( _ "github.com/go-sql-driver/mysql" ) -func init() { - // 打印 日志 行数 - log.SetFlags(log.Lshortfile | log.LstdFlags) -} +func init() {} type dumpOption struct { // 导出表数据 @@ -70,16 +66,10 @@ func WithWriter(writer io.Writer) DumpOption { } } -func Dump(dsn string, opts ...DumpOption) error { +func Dump(db *sql.DB, dbName string, opts ...DumpOption) error { // 打印开始 start := time.Now() - log.Printf("[info] [dump] start at %s\n", start.Format("2006-01-02 15:04:05")) // 打印结束 - defer func() { - end := time.Now() - log.Printf("[info] [dump] end at %s, cost %s\n", end.Format("2006-01-02 15:04:05"), end.Sub(start)) - }() - var err error var o dumpOption @@ -105,26 +95,21 @@ func Dump(dsn string, opts ...DumpOption) error { _, _ = buf.WriteString("-- ----------------------------\n") _, _ = buf.WriteString("-- MySQL Database Dump\n") _, _ = buf.WriteString("-- Start Time: " + start.Format("2006-01-02 15:04:05") + "\n") + _, _ = buf.WriteString("-- Database Name: " + dbName + "\n") _, _ = buf.WriteString("-- ----------------------------\n") - _, _ = buf.WriteString("\n\n") - // 连接数据库 - db, err := sql.Open("mysql", dsn) - if err != nil { - log.Printf("[error] %v \n", err) - return err - } - defer db.Close() + // Session / export settings + _, _ = buf.WriteString("/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;\n") + _, _ = buf.WriteString("/*!40101 SET NAMES utf8 */;\n") + _, _ = buf.WriteString("/*!50503 SET NAMES utf8mb4 */;\n") + _, _ = buf.WriteString("/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;\n") + _, _ = buf.WriteString("/*!40103 SET TIME_ZONE='+00:00' */;\n") + _, _ = buf.WriteString("/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;\n") + _, _ = buf.WriteString("/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;\n") + _, _ = buf.WriteString("/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;\n") - // 1. 获取数据库 - dbName, err := GetDBNameFromDSN(dsn) - if err != nil { - log.Printf("[error] %v \n", err) - return err - } _, err = db.Exec(fmt.Sprintf("USE `%s`", dbName)) if err != nil { - log.Printf("[error] %v \n", err) return err } @@ -133,7 +118,6 @@ func Dump(dsn string, opts ...DumpOption) error { if o.isAllTable { tmp, err := getAllTables(db) if err != nil { - log.Printf("[error] %v \n", err) return err } tables = tmp @@ -151,25 +135,31 @@ func Dump(dsn string, opts ...DumpOption) error { // 导出表结构 err = writeTableStruct(db, table, buf) if err != nil { - log.Printf("[error] %v \n", err) return err } - // 导出表数据 - if o.isData { - err = writeTableData(db, table, buf) + } + + allTotalRows := uint64(0) + if o.isData { + for _, table := range tables { + totalRows, err := writeTableData(db, table, buf) + allTotalRows += totalRows if err != nil { - log.Printf("[error] %v \n", err) return err } } } // 导出每个表的结构和数据 - + _, _ = buf.WriteString("SET FOREIGN_KEY_CHECKS=1;\n") _, _ = buf.WriteString("-- ----------------------------\n") _, _ = buf.WriteString("-- Dumped by mysqldump\n") + _, _ = buf.WriteString("-- Maintained by Yusta (https://github.com/NotYusta)\n") _, _ = buf.WriteString("-- Cost Time: " + time.Since(start).String() + "\n") + _, _ = buf.WriteString("-- Complete Time: " + time.Now().Format("2006-01-02 15:04:05") + "\n") + _, _ = buf.WriteString("-- Table Counts: " + fmt.Sprintf("%d", len(tables)) + "\n") + _, _ = buf.WriteString("-- Table Rows: " + fmt.Sprintf("%d", allTotalRows) + "\n") _, _ = buf.WriteString("-- ----------------------------\n") buf.Flush() @@ -203,6 +193,7 @@ func getAllTables(db *sql.DB) ([]string, error) { } tables = append(tables, table) } + return tables, nil } @@ -211,153 +202,89 @@ func writeTableStruct(db *sql.DB, table string, buf *bufio.Writer) error { _, _ = buf.WriteString("-- ----------------------------\n") _, _ = buf.WriteString(fmt.Sprintf("-- Table structure for %s\n", table)) _, _ = buf.WriteString("-- ----------------------------\n") - createTableSQL, err := getCreateTableSQL(db, table) if err != nil { - log.Printf("[error] %v \n", err) return err } - _, _ = buf.WriteString(createTableSQL) - _, _ = buf.WriteString(";") - - _, _ = buf.WriteString("\n\n") - _, _ = buf.WriteString("\n\n") + _, _ = buf.WriteString(fmt.Sprintf("%s;\n\n", createTableSQL)) return nil } // 禁止 golangci-lint 检查 // nolint: gocyclo -func writeTableData(db *sql.DB, table string, buf *bufio.Writer) error { +func writeTableData(db *sql.DB, table string, buf *bufio.Writer) (uint64, error) { + var totalRow uint64 + row := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM `%s`", table)) + row.Scan(&totalRow) // 导出表数据 _, _ = buf.WriteString("-- ----------------------------\n") - _, _ = buf.WriteString(fmt.Sprintf("-- Records of %s\n", table)) + _, _ = buf.WriteString(fmt.Sprintf("-- Records of %s (%d Rows)\n", table, totalRow)) _, _ = buf.WriteString("-- ----------------------------\n") - lineRows, err := db.Query(fmt.Sprintf("SELECT * FROM `%s`", table)) + rows, err := db.Query(fmt.Sprintf("SELECT * FROM `%s`", table)) if err != nil { - log.Printf("[error] %v \n", err) - return err + return totalRow, err } - defer lineRows.Close() + defer rows.Close() - var columns []string - columns, err = lineRows.Columns() + columns, err := rows.Columns() if err != nil { - log.Printf("[error] %v \n", err) - return err - } - columnTypes, err := lineRows.ColumnTypes() - if err != nil { - log.Printf("[error] %v \n", err) - return err + return totalRow, err } - var values [][]interface{} - for lineRows.Next() { - row := make([]interface{}, len(columns)) - rowPointers := make([]interface{}, len(columns)) - for i := range columns { - rowPointers[i] = &row[i] - } - err = lineRows.Scan(rowPointers...) - if err != nil { - log.Printf("[error] %v \n", err) - return err - } - values = append(values, row) + quotedColumns := make([]string, len(columns)) + for i, col := range columns { + quotedColumns[i] = "`" + col + "`" } + columnNames := strings.Join(quotedColumns, ",") + + if totalRow > 0 { + const batchSize = 1024 + var batch []string + count := 0 + + for rows.Next() { + data := make([]*sql.NullString, len(columns)) + ptrs := make([]interface{}, len(columns)) + for i := range data { + ptrs[i] = &data[i] + } + + if err := rows.Scan(ptrs...); err != nil { + return totalRow, err + } - for _, row := range values { - ssql := "INSERT INTO `" + table + "` VALUES (" - - for i, col := range row { - if col == nil { - ssql += "NULL" - } else { - Type := columnTypes[i].DatabaseTypeName() - // 去除 UNSIGNED 和空格 - Type = strings.Replace(Type, "UNSIGNED", "", -1) - Type = strings.Replace(Type, " ", "", -1) - switch Type { - case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "INTEGER", "BIGINT": - if bs, ok := col.([]byte); ok { - ssql += string(bs) - } else { - ssql += fmt.Sprintf("%d", col) - } - case "FLOAT", "DOUBLE": - if bs, ok := col.([]byte); ok { - ssql += string(bs) - } else { - ssql += fmt.Sprintf("%f", col) - } - case "DECIMAL", "DEC": - ssql += fmt.Sprintf("%s", col) - - case "DATE": - t, ok := col.(time.Time) - if !ok { - log.Println("DATE 类型转换错误") - return err - } - ssql += fmt.Sprintf("'%s'", t.Format("2006-01-02")) - case "DATETIME": - t, ok := col.(time.Time) - if !ok { - log.Println("DATETIME 类型转换错误") - return err - } - ssql += fmt.Sprintf("'%s'", t.Format("2006-01-02 15:04:05")) - case "TIMESTAMP": - t, ok := col.(time.Time) - if !ok { - log.Println("TIMESTAMP 类型转换错误") - return err - } - ssql += fmt.Sprintf("'%s'", t.Format("2006-01-02 15:04:05")) - case "TIME": - t, ok := col.([]byte) - if !ok { - log.Println("TIME 类型转换错误") - return err - } - ssql += fmt.Sprintf("'%s'", string(t)) - case "YEAR": - t, ok := col.([]byte) - if !ok { - log.Println("YEAR 类型转换错误") - return err - } - ssql += string(t) - case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT": - ssql += fmt.Sprintf("'%s'", strings.Replace(fmt.Sprintf("%s", col), "'", "''", -1)) - case "BIT", "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": - ssql += fmt.Sprintf("0x%X", col) - case "ENUM", "SET": - ssql += fmt.Sprintf("'%s'", col) - case "BOOL", "BOOLEAN": - if col.(bool) { - ssql += "true" - } else { - ssql += "false" - } - case "JSON": - ssql += fmt.Sprintf("'%s'", col) - default: - // unsupported type - log.Printf("unsupported type: %s", Type) - return fmt.Errorf("unsupported type: %s", Type) + dataStrings := make([]string, len(columns)) + for key, value := range data { + if value != nil && value.Valid { + escaped := strings.ReplaceAll(value.String, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "'", "''") + dataStrings[key] = "'" + escaped + "'" + } else { + dataStrings[key] = "NULL" } } - if i < len(row)-1 { - ssql += "," + + if len(dataStrings) != len(columns) { + return totalRow, fmt.Errorf("row has mismatched column count") + } + + batch = append(batch, "("+strings.Join(dataStrings, ",")+")") + count++ + + if count%batchSize == 0 { + buf.WriteString(fmt.Sprintf("INSERT INTO `%s` (%s) VALUES %s;\n", table, columnNames, strings.Join(batch, ","))) + batch = batch[:0] } } - ssql += ");\n" - _, _ = buf.WriteString(ssql) + + if len(batch) > 0 { + buf.WriteString(fmt.Sprintf("INSERT INTO `%s` (%s) VALUES %s;\n", table, columnNames, strings.Join(batch, ","))) + } + } - _, _ = buf.WriteString("\n\n") - return nil + _, _ = buf.WriteString("\n") + return totalRow, nil } diff --git a/source.go b/source.go index 25eea72..0be1a1f 100644 --- a/source.go +++ b/source.go @@ -6,9 +6,7 @@ import ( "errors" "fmt" "io" - "log" "strings" - "time" ) type sourceOption struct { @@ -53,10 +51,6 @@ func newDBWrapper(db *sql.DB, dryRun, debug bool) *dbWrapper { } func (db *dbWrapper) Exec(query string, args ...interface{}) (sql.Result, error) { - if db.debug { - log.Printf("[debug] [query]\n%s\n", query) - } - if db.dryRun { return nil, nil } @@ -66,44 +60,20 @@ func (db *dbWrapper) Exec(query string, args ...interface{}) (sql.Result, error) // Source 加载 // 禁止 golangci-lint 检查 // nolint: gocyclo -func Source(dsn string, reader io.Reader, opts ...SourceOption) error { +func Source(db *sql.DB, dbName string, reader io.Reader, opts ...SourceOption) error { // 打印开始 - start := time.Now() - log.Printf("[info] [source] start at %s\n", start.Format("2006-01-02 15:04:05")) - // 打印结束 - defer func() { - end := time.Now() - log.Printf("[info] [source] end at %s, cost %s\n", end.Format("2006-01-02 15:04:05"), end.Sub(start)) - }() - var err error - var db *sql.DB var o sourceOption for _, opt := range opts { opt(&o) } - dbName, err := GetDBNameFromDSN(dsn) - if err != nil { - log.Printf("[error] %v\n", err) - return err - } - - // Open database - db, err = sql.Open("mysql", dsn) - if err != nil { - log.Printf("[error] %v\n", err) - return err - } - defer db.Close() - // DB Wrapper dbWrapper := newDBWrapper(db, o.dryRun, o.debug) // Use database - _, err = dbWrapper.Exec(fmt.Sprintf("USE %s;", dbName)) + _, err = dbWrapper.Exec(fmt.Sprintf("USE `%s`", dbName)) if err != nil { - log.Printf("[error] %v\n", err) return err } @@ -115,7 +85,6 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { // 关闭事务 _, err = dbWrapper.Exec("SET autocommit=0;") if err != nil { - log.Printf("[error] %v\n", err) return err } @@ -125,7 +94,6 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { if err == io.EOF { break } - log.Printf("[error] %v\n", err) return err } @@ -134,7 +102,6 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { // 删除末尾的换行符 ssql = trim(ssql) if err != nil { - log.Printf("[error] [trim] %v\n", err) return err } @@ -148,14 +115,12 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { if err == io.EOF { break } - log.Printf("[error] %v\n", err) return err } ssql2 := string(line) ssql2 = trim(ssql2) if err != nil { - log.Printf("[error] [trim] %v\n", err) return err } if strings.HasPrefix(ssql2, "INSERT INTO") { @@ -168,14 +133,12 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { // 合并 INSERT ssql, err = mergeInsert(insertSQLs) if err != nil { - log.Printf("[error] [mergeInsert] %v\n", err) return err } } _, err = dbWrapper.Exec(ssql) if err != nil { - log.Printf("[error] %v\n", err) return err } } @@ -183,14 +146,12 @@ func Source(dsn string, reader io.Reader, opts ...SourceOption) error { // 提交事务 _, err = dbWrapper.Exec("COMMIT;") if err != nil { - log.Printf("[error] %v\n", err) return err } // 开启事务 _, err = dbWrapper.Exec("SET autocommit=1;") if err != nil { - log.Printf("[error] %v\n", err) return err } diff --git a/util.go b/util.go deleted file mode 100644 index ad084b4..0000000 --- a/util.go +++ /dev/null @@ -1,21 +0,0 @@ -package mysqldump - -import ( - "fmt" - "strings" -) - -//从dsn中提取出数据库名称,并将其作为结果返回。 -//如果无法解析出数据库名称,将返回一个错误。 - -func GetDBNameFromDSN(dsn string) (string, error) { - ss1 := strings.Split(dsn, "/") - if len(ss1) == 2 { - ss2 := strings.Split(ss1[1], "?") - if len(ss2) == 2 { - return ss2[0], nil - } - } - - return "", fmt.Errorf("dsn error: %s", dsn) -}