diff --git a/go/base/context.go b/go/base/context.go index 300ec1201..693db0572 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -77,10 +77,12 @@ func NewThrottleCheckResult(throttle bool, reason string, reasonHint ThrottleRea type MigrationContext struct { Uuid string - DatabaseName string - OriginalTableName string - AlterStatement string - AlterStatementOptions string // anything following the 'ALTER TABLE [schema.]table' from AlterStatement + DatabaseName string + OriginalTableName string + AlterStatement string + AlterStatementOptions string // anything following the 'ALTER TABLE [schema.]table' from AlterStatement + CreateTableStatement string + CreateTableStatementBody string // anything following the 'CREATE TABLE [schema.]table' from CreateTableStatement countMutex sync.Mutex countTableRowsCancelFunc func() diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 139703077..c5eb359fa 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -66,7 +66,8 @@ func main() { flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") - flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") + flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory if no '--create-table' statement provided)") + flag.StringVar(&migrationContext.CreateTableStatement, "create-table", "", "create table statement (mandatory if no '--alter' statement provided)") flag.BoolVar(&migrationContext.AttemptInstantDDL, "attempt-instant-ddl", false, "Attempt to use instant DDL for this migration first") storageEngine := flag.String("storage-engine", "innodb", "Specify table storage engine (default: 'innodb'). When 'rocksdb': the session transaction isolation level is changed from REPEATABLE_READ to READ_COMMITTED.") @@ -191,17 +192,29 @@ func main() { migrationContext.Log.Fatale(err) } - if migrationContext.AlterStatement == "" { - log.Fatal("--alter must be provided and statement must not be empty") + if migrationContext.AlterStatement == "" && migrationContext.CreateTableStatement == "" { + log.Fatal("--alter or --create-table must be provided and statement must not be empty") + } + + if migrationContext.AlterStatement != "" && migrationContext.CreateTableStatement != "" { + log.Fatal("--alter and --create-table cannot both be provided at the same time") + } + + var parser sql.Parser + if migrationContext.CreateTableStatement != "" { + confirmNoAlterStatementFlags(migrationContext) + parser = sql.NewParserFromCreateTableStatement(migrationContext.AlterStatement) + migrationContext.CreateTableStatementBody = parser.GetOptions() + } else { + parser = sql.NewParserFromAlterStatement(migrationContext.AlterStatement) + migrationContext.AlterStatementOptions = parser.GetOptions() } - parser := sql.NewParserFromAlterStatement(migrationContext.AlterStatement) - migrationContext.AlterStatementOptions = parser.GetAlterStatementOptions() if migrationContext.DatabaseName == "" { if parser.HasExplicitSchema() { migrationContext.DatabaseName = parser.GetExplicitSchema() } else { - log.Fatal("--database must be provided and database name must not be empty, or --alter must specify database name") + log.Fatal(fmt.Sprintf("--database must be provided and database name must not be empty, or --%v must specify database name", parser.Type().String())) } } @@ -213,7 +226,7 @@ func main() { if parser.HasExplicitTable() { migrationContext.OriginalTableName = parser.GetExplicitTable() } else { - log.Fatal("--table must be provided and table name must not be empty, or --alter must specify table name") + log.Fatal(fmt.Sprintf("--table must be provided and table name must not be empty, or --%v must specify table name", parser.Type().String())) } } migrationContext.Noop = !(*executeFlag) @@ -321,3 +334,21 @@ func main() { } fmt.Fprintln(os.Stdout, "# Done") } + +func confirmNoAlterStatementFlags(migrationContext *base.MigrationContext) { + if migrationContext.AttemptInstantDDL { + migrationContext.Log.Fatal("--attempt-instant-ddl cannot be used with --create-table") + } + if migrationContext.ApproveRenamedColumns { + migrationContext.Log.Fatal("--approve-renamed-columns cannot be used with --create-table") + } + if migrationContext.SkipRenamedColumns { + migrationContext.Log.Fatal("--skip-renamed-columns cannot be used with --create-table") + } + if migrationContext.DiscardForeignKeys { + migrationContext.Log.Fatal("--discard-foreign-keys cannot be used with --create-table") + } + if migrationContext.SkipForeignKeyChecks { + migrationContext.Log.Fatal("--skip-foreign-key-checks cannot be used with --create-table") + } +} diff --git a/go/logic/applier.go b/go/logic/applier.go index fa374a70f..c5688ee96 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -218,14 +218,24 @@ func (this *Applier) AttemptInstantDDL() error { return err } +func (this *Applier) generateCreateGhostTableQuery() string { + query := fmt.Sprintf(`create /* gh-ost */ table %s.%s `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetGhostTableName())) + if this.migrationContext.CreateTableStatementBody != "" { + query += this.migrationContext.CreateTableStatementBody + } else { + query += fmt.Sprintf(`like %s.%s`, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.OriginalTableName), + ) + } + return query +} + // CreateGhostTable creates the ghost table on the applier host func (this *Applier) CreateGhostTable() error { - query := fmt.Sprintf(`create /* gh-ost */ table %s.%s like %s.%s`, - sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetGhostTableName()), - sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.OriginalTableName), - ) + query := this.generateCreateGhostTableQuery() this.migrationContext.Log.Infof("Creating ghost table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), @@ -261,6 +271,9 @@ func (this *Applier) CreateGhostTable() error { // AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhost() error { + if this.migrationContext.AlterStatement == "" { + return nil + } query := fmt.Sprintf(`alter /* gh-ost */ table %s.%s %s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName()), diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index 36c28d9e8..2917b1e25 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -183,3 +183,26 @@ func TestApplierInstantDDL(t *testing.T) { test.S(t).ExpectEquals(stmt, "ALTER /* gh-ost */ TABLE `test`.`mytable` ADD INDEX (foo), ALGORITHM=INSTANT") }) } + +func TestApplierCreateGhostTable(t *testing.T) { + t.Run("default", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "mytable" + applier := NewApplier(migrationContext) + + stmt := applier.generateCreateGhostTableQuery() + test.S(t).ExpectEquals(stmt, "create /* gh-ost */ table `test`.`_mytable_gho` like `test`.`mytable`") + }) + + t.Run("withCustomCreateTable", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "mytable" + migrationContext.CreateTableStatementBody = "(id VARCHAR(24), PRIMARY KEY(id))" + applier := NewApplier(migrationContext) + + stmt := applier.generateCreateGhostTableQuery() + test.S(t).ExpectEquals(stmt, "create /* gh-ost */ table `test`.`_mytable_gho` (id VARCHAR(24), PRIMARY KEY(id))") + }) +} diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 2543f8e9a..99f88b616 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -52,6 +52,7 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [ env = append(env, fmt.Sprintf("GH_OST_GHOST_TABLE_NAME=%s", this.migrationContext.GetGhostTableName())) env = append(env, fmt.Sprintf("GH_OST_OLD_TABLE_NAME=%s", this.migrationContext.GetOldTableName())) env = append(env, fmt.Sprintf("GH_OST_DDL=%s", this.migrationContext.AlterStatement)) + env = append(env, fmt.Sprintf("GH_OST_CREATE_TABLE=%s", this.migrationContext.CreateTableStatement)) env = append(env, fmt.Sprintf("GH_OST_ELAPSED_SECONDS=%f", this.migrationContext.ElapsedTime().Seconds())) env = append(env, fmt.Sprintf("GH_OST_ELAPSED_COPY_SECONDS=%f", this.migrationContext.ElapsedRowCopyTime().Seconds())) estimatedRows := atomic.LoadInt64(&this.migrationContext.RowsEstimate) + atomic.LoadInt64(&this.migrationContext.RowsDeltaEstimate) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index fed7c944b..4e1ace4af 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -23,7 +23,8 @@ import ( ) var ( - ErrMigratorUnsupportedRenameAlter = errors.New("ALTER statement seems to RENAME the table. This is not supported, and you should run your RENAME outside gh-ost.") + ErrMigratorUnsupportedRenameAlter = errors.New("ALTER statement seems to RENAME the table. This is not supported, and you should run your RENAME outside of gh-ost.") + ErrMigratorUnsupportedRenameCreateTable = errors.New("CREATE TABLE statement seems to RENAME the table. This is not supported, and you should first run your RENAME outside of gh-ost.") ) type ChangelogState string @@ -69,7 +70,7 @@ const ( // Migrator is the main schema migration flow manager. type Migrator struct { appVersion string - parser *sql.AlterTableParser + parser sql.Parser inspector *Inspector applier *Applier eventsStreamer *EventsStreamer @@ -94,21 +95,27 @@ type Migrator struct { finishedMigrating int64 } +func getParser(context *base.MigrationContext) sql.Parser { + if context.CreateTableStatement != "" { + return sql.NewCreateTableParser(context.CreateTableStatement) + } + return sql.NewAlterTableParser(context.AlterStatement) +} + func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { migrator := &Migrator{ appVersion: appVersion, hooksExecutor: NewHooksExecutor(context), migrationContext: context, - parser: sql.NewAlterTableParser(), + parser: getParser(context), ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 3), rowCopyComplete: make(chan error), allEventsUpToLockProcessed: make(chan string), - - copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), - handledChangelogStates: make(map[string]bool), - finishedMigrating: 0, + copyRowsQueue: make(chan tableWriteFunc), + applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), + handledChangelogStates: make(map[string]bool), + finishedMigrating: 0, } return migrator } @@ -258,7 +265,14 @@ func (this *Migrator) listenOnPanicAbort() { this.migrationContext.Log.Fatale(err) } -// validateAlterStatement validates the `alter` statement meets criteria. +func (this *Migrator) validateStatement() error { + if this.parser.Type() == sql.ParserTypeAlterTable { + return this.validateAlterStatement() + } + return this.validateCreateTableStatement() +} + +// validateAlterStatement validates that the `alter` statement meets criteria. // At this time this means: // - column renames are approved // - no table rename allowed @@ -277,6 +291,23 @@ func (this *Migrator) validateAlterStatement() (err error) { return nil } +// validateCreateTableStatement validates that the `create table` statement meets criteria. +// At this time this means: +// - no table rename allowed +// - no foreign keys +func (this *Migrator) validateCreateTableStatement() (err error) { + if this.migrationContext.OriginalTableName != this.parser.GetExplicitTable() { + return ErrMigratorUnsupportedRenameCreateTable + } + if this.parser.HasForeignKeys() { + if this.migrationContext.DiscardForeignKeys { + return errors.New("CREATE TABLE includes FOREIGN KEYS but --discard-foreign-keys was set. You should instead remove the foreign keys from the table definition") + } + return errors.New("CREATE TABLE statement seems to FOREIGN KEYS on the table. Child-side foreign keys are not supported. Bailing out") + } + return nil +} + func (this *Migrator) countTableRows() (err error) { if !this.migrationContext.CountTableRows { // Not counting; we stay with an estimate @@ -336,15 +367,15 @@ func (this *Migrator) Migrate() (err error) { if err := this.hooksExecutor.onStartup(); err != nil { return err } - if err := this.parser.ParseAlterStatement(this.migrationContext.AlterStatement); err != nil { + if err := this.parser.ParseStatement(); err != nil { return err } - if err := this.validateAlterStatement(); err != nil { + if err := this.validateStatement(); err != nil { return err } // After this point, we'll need to teardown anything that's been started - // so we don't leave things hanging around + // so we don't leave things hanging around defer this.teardown() if err := this.initiateInspector(); err != nil { @@ -1139,18 +1170,20 @@ func (this *Migrator) initiateApplier() error { return err } - if err := this.applier.AlterGhost(); err != nil { - this.migrationContext.Log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") - return err - } - - if this.migrationContext.OriginalTableAutoIncrement > 0 && !this.parser.IsAutoIncrementDefined() { - // Original table has AUTO_INCREMENT value and the -alter statement does not indicate any override, - // so we should copy AUTO_INCREMENT value onto our ghost table. - if err := this.applier.AlterGhostAutoIncrement(); err != nil { - this.migrationContext.Log.Errorf("Unable to ALTER ghost table AUTO_INCREMENT value, see further error details. Bailing out") + if this.parser.Type() == sql.ParserTypeAlterTable { + if err := this.applier.AlterGhost(); err != nil { + this.migrationContext.Log.Errorf("Unable to ALTER ghost table, see further error details. Bailing out") return err } + + if this.migrationContext.OriginalTableAutoIncrement > 0 && !this.parser.IsAutoIncrementDefined() { + // Original table has AUTO_INCREMENT value and the -alter statement does not indicate any override, + // so we should copy AUTO_INCREMENT value onto our ghost table. + if err := this.applier.AlterGhostAutoIncrement(); err != nil { + this.migrationContext.Log.Errorf("Unable to ALTER ghost table AUTO_INCREMENT value, see further error details. Bailing out") + return err + } + } } this.applier.WriteChangelogState(string(GhostTableMigrated)) go this.applier.InitiateHeartbeat() diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go index 242a749a5..a47223d1f 100644 --- a/go/logic/migrator_test.go +++ b/go/logic/migrator_test.go @@ -114,8 +114,9 @@ func TestMigratorOnChangelogEvent(t *testing.T) { func TestMigratorValidateStatement(t *testing.T) { t.Run("add-column", func(t *testing.T) { migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = `ALTER TABLE test ADD test_new VARCHAR(64) NOT NULL` migrator := NewMigrator(migrationContext, "1.2.3") - tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test ADD test_new VARCHAR(64) NOT NULL`)) + tests.S(t).ExpectNil(migrator.parser.ParseStatement()) tests.S(t).ExpectNil(migrator.validateAlterStatement()) tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) @@ -123,8 +124,9 @@ func TestMigratorValidateStatement(t *testing.T) { t.Run("drop-column", func(t *testing.T) { migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = `ALTER TABLE test DROP abc` migrator := NewMigrator(migrationContext, "1.2.3") - tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test DROP abc`)) + tests.S(t).ExpectNil(migrator.parser.ParseStatement()) tests.S(t).ExpectNil(migrator.validateAlterStatement()) tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 1) @@ -134,8 +136,9 @@ func TestMigratorValidateStatement(t *testing.T) { t.Run("rename-column", func(t *testing.T) { migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = `ALTER TABLE test CHANGE test123 test1234 bigint unsigned` migrator := NewMigrator(migrationContext, "1.2.3") - tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) + tests.S(t).ExpectNil(migrator.parser.ParseStatement()) err := migrator.validateAlterStatement() tests.S(t).ExpectNotNil(err) @@ -145,9 +148,10 @@ func TestMigratorValidateStatement(t *testing.T) { t.Run("rename-column-approved", func(t *testing.T) { migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = `ALTER TABLE test CHANGE test123 test1234 bigint unsigned` migrator := NewMigrator(migrationContext, "1.2.3") migrator.migrationContext.ApproveRenamedColumns = true - tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) + tests.S(t).ExpectNil(migrator.parser.ParseStatement()) tests.S(t).ExpectNil(migrator.validateAlterStatement()) tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) @@ -155,8 +159,9 @@ func TestMigratorValidateStatement(t *testing.T) { t.Run("rename-table", func(t *testing.T) { migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = `ALTER TABLE test RENAME TO test_new` migrator := NewMigrator(migrationContext, "1.2.3") - tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test RENAME TO test_new`)) + tests.S(t).ExpectNil(migrator.parser.ParseStatement()) err := migrator.validateAlterStatement() tests.S(t).ExpectNotNil(err) diff --git a/go/sql/builder.go b/go/sql/builder.go index 7be428f93..dfeee885e 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -23,7 +23,7 @@ const ( ) // EscapeName will escape a db/table/column/... name by wrapping with backticks. -// It is not fool proof. I'm just trying to do the right thing here, not solving +// It is not full proof. I'm just trying to do the right thing here, not solving // SQL injection issues, which should be irrelevant for this tool. func EscapeName(name string) string { if unquoted, err := strconv.Unquote(name); err == nil { diff --git a/go/sql/create_table_parser.go b/go/sql/create_table_parser.go new file mode 100644 index 000000000..6bf141093 --- /dev/null +++ b/go/sql/create_table_parser.go @@ -0,0 +1,115 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package sql + +type CreateTableParser struct { + createTableStatementBody string + + explicitSchema string + explicitTable string + hasForeignKeys bool + isAutoIncrementDefined bool +} + +func NewCreateTableParser(statement string) *CreateTableParser { + return &CreateTableParser{ + createTableStatementBody: statement, + } +} + +func NewParserFromCreateTableStatement(alterStatement string) *CreateTableParser { + parser := NewCreateTableParser(alterStatement) + parser.ParseStatement() + return parser +} + +func (this *CreateTableParser) ParseStatement() (err error) { + for _, createTableRegexp := range createTableExplicitSchemaTableRegexps { + if submatch := createTableRegexp.FindStringSubmatch(this.createTableStatementBody); len(submatch) > 0 { + this.explicitSchema = submatch[1] + this.explicitTable = submatch[2] + this.createTableStatementBody = submatch[3] + break + } + } + for _, createTableRegexp := range createTableExplicitTableRegexps { + if submatch := createTableRegexp.FindStringSubmatch(this.createTableStatementBody); len(submatch) > 0 { + this.explicitTable = submatch[1] + this.createTableStatementBody = submatch[2] + break + } + } + for _, token := range tokenizeStatement(this.createTableStatementBody) { + token = sanitizeQuotesFromToken(token) + this.parseCreateTableToken(token) + } + return nil +} + +func (this *CreateTableParser) parseCreateTableToken(token string) { + { + // foreign key. + if foreignKeyTableRegexp.MatchString(token) { + this.hasForeignKeys = true + } + } + { + // auto_increment + if autoIncrementRegexp.MatchString(token) { + this.isAutoIncrementDefined = true + } + } +} + +func (this *CreateTableParser) Type() ParserType { + return ParserTypeCreateTable +} + +func (this *CreateTableParser) GetNonTrivialRenames() map[string]string { + return make(map[string]string) +} + +func (this *CreateTableParser) HasNonTrivialRenames() bool { + return len(this.GetNonTrivialRenames()) > 0 +} + +func (this *CreateTableParser) DroppedColumnsMap() map[string]bool { + return nil +} + +func (this *CreateTableParser) IsRenameTable() bool { + // We always return false because we need to check for table renames manually + // outside the parser. + return false +} + +func (this *CreateTableParser) IsAutoIncrementDefined() bool { + return this.isAutoIncrementDefined +} + +func (this *CreateTableParser) GetExplicitSchema() string { + return this.explicitSchema +} + +func (this *CreateTableParser) HasExplicitSchema() bool { + return this.GetExplicitSchema() != "" +} + +func (this *CreateTableParser) GetExplicitTable() string { + return this.explicitTable +} + +func (this *CreateTableParser) HasExplicitTable() bool { + return this.GetExplicitTable() != "" +} + +func (this *CreateTableParser) GetOptions() string { + return this.createTableStatementBody +} + +func (this *CreateTableParser) HasForeignKeys() bool { + return this.hasForeignKeys +} diff --git a/go/sql/parser.go b/go/sql/parser.go index 2ddc60f50..270781c96 100644 --- a/go/sql/parser.go +++ b/go/sql/parser.go @@ -16,6 +16,7 @@ var ( renameColumnRegexp = regexp.MustCompile(`(?i)\bchange\s+(column\s+|)([\S]+)\s+([\S]+)\s+`) dropColumnRegexp = regexp.MustCompile(`(?i)\bdrop\s+(column\s+|)([\S]+)$`) renameTableRegexp = regexp.MustCompile(`(?i)\brename\s+(to|as)\s+`) + foreignKeyTableRegexp = regexp.MustCompile(`(?i)\bforeign\s+key\s+`) autoIncrementRegexp = regexp.MustCompile(`(?i)\bauto_increment[\s]*=[\s]*([0-9]+)`) alterTableExplicitSchemaTableRegexps = []*regexp.Regexp{ // ALTER TABLE `scm`.`tbl` something @@ -36,6 +37,25 @@ var ( enumValuesRegexp = regexp.MustCompile("^enum[(](.*)[)]$") ) +var ( + createTableExplicitSchemaTableRegexps = []*regexp.Regexp{ + // CREATE TABLE `scm`.`tbl` something + regexp.MustCompile(`(?i)\bcreate\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `[.]` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // CREATE TABLE `scm`.tbl something + regexp.MustCompile(`(?i)\bcreate\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `[.]([\S]+)\s+(.*$)`), + // CREATE TABLE scm.`tbl` something + regexp.MustCompile(`(?i)\bcreate\s+table\s+([\S]+)[.]` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // CREATE TABLE scm.tbl something + regexp.MustCompile(`(?i)\bcreate\s+table\s+([\S]+)[.]([\S]+)\s+(.*$)`), + } + createTableExplicitTableRegexps = []*regexp.Regexp{ + // CREATE TABLE `tbl` something + regexp.MustCompile(`(?i)\bcreate\s+table\s+` + "`" + `([^` + "`" + `]+)` + "`" + `\s+(.*$)`), + // CREATE TABLE tbl something + regexp.MustCompile(`(?i)\bcreate\s+table\s+([\S]+)\s+(.*$)`), + } +) + type AlterTableParser struct { columnRenameMap map[string]string droppedColumns map[string]bool @@ -49,20 +69,21 @@ type AlterTableParser struct { explicitTable string } -func NewAlterTableParser() *AlterTableParser { +func NewAlterTableParser(statement string) Parser { return &AlterTableParser{ - columnRenameMap: make(map[string]string), - droppedColumns: make(map[string]bool), + alterStatementOptions: statement, + columnRenameMap: make(map[string]string), + droppedColumns: make(map[string]bool), } } -func NewParserFromAlterStatement(alterStatement string) *AlterTableParser { - parser := NewAlterTableParser() - parser.ParseAlterStatement(alterStatement) +func NewParserFromAlterStatement(alterStatement string) Parser { + parser := NewAlterTableParser(alterStatement) + parser.ParseStatement() return parser } -func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tokens []string) { +func tokenizeStatement(alterStatement string) (tokens []string) { terminatingQuote := rune(0) f := func(c rune) bool { switch { @@ -89,7 +110,7 @@ func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tok return tokens } -func (this *AlterTableParser) sanitizeQuotesFromAlterStatement(alterStatement string) (strippedStatement string) { +func sanitizeQuotesFromToken(alterStatement string) (strippedStatement string) { strippedStatement = alterStatement strippedStatement = sanitizeQuotesRegexp.ReplaceAllString(strippedStatement, "''") return strippedStatement @@ -133,8 +154,7 @@ func (this *AlterTableParser) parseAlterToken(alterToken string) { } } -func (this *AlterTableParser) ParseAlterStatement(alterStatement string) (err error) { - this.alterStatementOptions = alterStatement +func (this *AlterTableParser) ParseStatement() (err error) { for _, alterTableRegexp := range alterTableExplicitSchemaTableRegexps { if submatch := alterTableRegexp.FindStringSubmatch(this.alterStatementOptions); len(submatch) > 0 { this.explicitSchema = submatch[1] @@ -150,14 +170,18 @@ func (this *AlterTableParser) ParseAlterStatement(alterStatement string) (err er break } } - for _, alterToken := range this.tokenizeAlterStatement(this.alterStatementOptions) { - alterToken = this.sanitizeQuotesFromAlterStatement(alterToken) + for _, alterToken := range tokenizeStatement(this.alterStatementOptions) { + alterToken = sanitizeQuotesFromToken(alterToken) this.parseAlterToken(alterToken) this.alterTokens = append(this.alterTokens, alterToken) } return nil } +func (this *AlterTableParser) Type() ParserType { + return ParserTypeAlterTable +} + func (this *AlterTableParser) GetNonTrivialRenames() map[string]string { result := make(map[string]string) for column, renamed := range this.columnRenameMap { @@ -200,10 +224,14 @@ func (this *AlterTableParser) HasExplicitTable() bool { return this.GetExplicitTable() != "" } -func (this *AlterTableParser) GetAlterStatementOptions() string { +func (this *AlterTableParser) GetOptions() string { return this.alterStatementOptions } +func (this *AlterTableParser) HasForeignKeys() bool { + return false +} + func ParseEnumValues(enumColumnType string) string { if submatch := enumValuesRegexp.FindStringSubmatch(enumColumnType); len(submatch) > 0 { return submatch[1] diff --git a/go/sql/parser_test.go b/go/sql/parser_test.go index df9284280..10990d271 100644 --- a/go/sql/parser_test.go +++ b/go/sql/parser_test.go @@ -17,29 +17,30 @@ func init() { log.SetLevel(log.ERROR) } -func TestParseAlterStatement(t *testing.T) { +func TestParseStatement(t *testing.T) { statement := "add column t int, engine=innodb" - parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectEquals(parser.GetOptions(), statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) test.S(t).ExpectFalse(parser.IsAutoIncrementDefined()) } -func TestParseAlterStatementTrivialRename(t *testing.T) { +func TestParseStatementTrivialRename(t *testing.T) { statement := "add column t int, change ts ts timestamp, engine=innodb" - parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectEquals(parser.GetOptions(), statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) test.S(t).ExpectFalse(parser.IsAutoIncrementDefined()) - test.S(t).ExpectEquals(len(parser.columnRenameMap), 1) - test.S(t).ExpectEquals(parser.columnRenameMap["ts"], "ts") + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(len(p.columnRenameMap), 1) + test.S(t).ExpectEquals(p.columnRenameMap["ts"], "ts") } -func TestParseAlterStatementWithAutoIncrement(t *testing.T) { +func TestParseStatementWithAutoIncrement(t *testing.T) { statements := []string{ "auto_increment=7", "auto_increment = 7", @@ -50,28 +51,29 @@ func TestParseAlterStatementWithAutoIncrement(t *testing.T) { "add column t int, change ts ts timestamp, engine=innodb auto_increment=73425", } for _, statement := range statements { - parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectEquals(parser.GetOptions(), statement) test.S(t).ExpectTrue(parser.IsAutoIncrementDefined()) } } -func TestParseAlterStatementTrivialRenames(t *testing.T) { +func TestParseStatementTrivialRenames(t *testing.T) { statement := "add column t int, change ts ts timestamp, CHANGE f `f` float, engine=innodb" - parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectEquals(parser.GetOptions(), statement) test.S(t).ExpectFalse(parser.HasNonTrivialRenames()) test.S(t).ExpectFalse(parser.IsAutoIncrementDefined()) - test.S(t).ExpectEquals(len(parser.columnRenameMap), 2) - test.S(t).ExpectEquals(parser.columnRenameMap["ts"], "ts") - test.S(t).ExpectEquals(parser.columnRenameMap["f"], "f") + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(len(p.columnRenameMap), 2) + test.S(t).ExpectEquals(p.columnRenameMap["ts"], "ts") + test.S(t).ExpectEquals(p.columnRenameMap["f"], "f") } -func TestParseAlterStatementNonTrivial(t *testing.T) { +func TestParseStatementNonTrivial(t *testing.T) { statements := []string{ `add column b bigint, change f fl float, change i count int, engine=innodb`, "add column b bigint, change column `f` fl float, change `i` `count` int, engine=innodb", @@ -83,11 +85,11 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { } for _, statement := range statements { - parser := NewAlterTableParser() - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) test.S(t).ExpectFalse(parser.IsAutoIncrementDefined()) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) + test.S(t).ExpectEquals(parser.GetOptions(), statement) renames := parser.GetNonTrivialRenames() test.S(t).ExpectEquals(len(renames), 2) test.S(t).ExpectEquals(renames["i"], "count") @@ -96,226 +98,237 @@ func TestParseAlterStatementNonTrivial(t *testing.T) { } func TestTokenizeAlterStatement(t *testing.T) { - parser := NewAlterTableParser() { alterStatement := "add column t int" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int"})) } { alterStatement := "add column t int, change column i int" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int"})) } { alterStatement := "add column t int, change column i int 'some comment'" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int 'some comment'"})) } { alterStatement := "add column t int, change column i int 'some comment, with comma'" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "change column i int 'some comment, with comma'"})) } { alterStatement := "add column t int, add column d decimal(10,2)" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "add column d decimal(10,2)"})) } { alterStatement := "add column t int, add column e enum('a','b','c')" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int", "add column e enum('a','b','c')"})) } { alterStatement := "add column t int(11), add column e enum('a','b','c')" - tokens := parser.tokenizeAlterStatement(alterStatement) + tokens := tokenizeStatement(alterStatement) test.S(t).ExpectTrue(reflect.DeepEqual(tokens, []string{"add column t int(11)", "add column e enum('a','b','c')"})) } } func TestSanitizeQuotesFromAlterStatement(t *testing.T) { - parser := NewAlterTableParser() { alterStatement := "add column e enum('a','b','c')" - strippedStatement := parser.sanitizeQuotesFromAlterStatement(alterStatement) + strippedStatement := sanitizeQuotesFromToken(alterStatement) test.S(t).ExpectEquals(strippedStatement, "add column e enum('','','')") } { alterStatement := "change column i int 'some comment, with comma'" - strippedStatement := parser.sanitizeQuotesFromAlterStatement(alterStatement) + strippedStatement := sanitizeQuotesFromToken(alterStatement) test.S(t).ExpectEquals(strippedStatement, "change column i int ''") } } -func TestParseAlterStatementDroppedColumns(t *testing.T) { +func TestParseStatementDroppedColumns(t *testing.T) { { - parser := NewAlterTableParser() statement := "drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(len(parser.droppedColumns), 1) - test.S(t).ExpectTrue(parser.droppedColumns["b"]) + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(len(p.droppedColumns), 1) + test.S(t).ExpectTrue(p.droppedColumns["b"]) } { - parser := NewAlterTableParser() statement := "drop column b, drop key c_idx, drop column `d`" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) - test.S(t).ExpectEquals(len(parser.droppedColumns), 2) - test.S(t).ExpectTrue(parser.droppedColumns["b"]) - test.S(t).ExpectTrue(parser.droppedColumns["d"]) + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(parser.GetOptions(), statement) + test.S(t).ExpectEquals(len(p.droppedColumns), 2) + test.S(t).ExpectTrue(p.droppedColumns["b"]) + test.S(t).ExpectTrue(p.droppedColumns["d"]) } { - parser := NewAlterTableParser() statement := "drop column b, drop key c_idx, drop column `d`, drop `e`, drop primary key, drop foreign key fk_1" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(len(parser.droppedColumns), 3) - test.S(t).ExpectTrue(parser.droppedColumns["b"]) - test.S(t).ExpectTrue(parser.droppedColumns["d"]) - test.S(t).ExpectTrue(parser.droppedColumns["e"]) + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(len(p.droppedColumns), 3) + test.S(t).ExpectTrue(p.droppedColumns["b"]) + test.S(t).ExpectTrue(p.droppedColumns["d"]) + test.S(t).ExpectTrue(p.droppedColumns["e"]) } { - parser := NewAlterTableParser() statement := "drop column b, drop bad statement, add column i int" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(len(parser.droppedColumns), 1) - test.S(t).ExpectTrue(parser.droppedColumns["b"]) + p := parser.(*AlterTableParser) + test.S(t).ExpectEquals(len(p.droppedColumns), 1) + test.S(t).ExpectTrue(p.droppedColumns["b"]) } } -func TestParseAlterStatementRenameTable(t *testing.T) { +func TestParseStatementRenameTable(t *testing.T) { { - parser := NewAlterTableParser() statement := "drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectFalse(parser.isRenameTable) + test.S(t).ExpectFalse(parser.IsRenameTable()) } { - parser := NewAlterTableParser() statement := "rename as something_else" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectTrue(parser.isRenameTable) + test.S(t).ExpectTrue(parser.IsRenameTable()) } { - parser := NewAlterTableParser() statement := "drop column b, rename as something_else" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.alterStatementOptions, statement) - test.S(t).ExpectTrue(parser.isRenameTable) + test.S(t).ExpectEquals(parser.GetOptions(), statement) + test.S(t).ExpectTrue(parser.IsRenameTable()) } { - parser := NewAlterTableParser() statement := "engine=innodb rename as something_else" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectTrue(parser.isRenameTable) + test.S(t).ExpectTrue(parser.IsRenameTable()) } { - parser := NewAlterTableParser() statement := "rename as something_else, engine=innodb" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectTrue(parser.isRenameTable) + test.S(t).ExpectTrue(parser.IsRenameTable()) } } -func TestParseAlterStatementExplicitTable(t *testing.T) { +func TestParseStatementExplicitTable(t *testing.T) { { - parser := NewAlterTableParser() statement := "drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "") - test.S(t).ExpectEquals(parser.explicitTable, "") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table tbl drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table `tbl` drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table `scm with spaces`.`tbl` drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm with spaces") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm with spaces") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table `scm`.`tbl with spaces` drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm") - test.S(t).ExpectEquals(parser.explicitTable, "tbl with spaces") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl with spaces") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table `scm`.tbl drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table scm.`tbl` drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table scm.tbl drop column b" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b"})) } { - parser := NewAlterTableParser() statement := "alter table scm.tbl drop column b, add index idx(i)" - err := parser.ParseAlterStatement(statement) + parser := NewAlterTableParser(statement) + err := parser.ParseStatement() test.S(t).ExpectNil(err) - test.S(t).ExpectEquals(parser.explicitSchema, "scm") - test.S(t).ExpectEquals(parser.explicitTable, "tbl") - test.S(t).ExpectEquals(parser.alterStatementOptions, "drop column b, add index idx(i)") - test.S(t).ExpectTrue(reflect.DeepEqual(parser.alterTokens, []string{"drop column b", "add index idx(i)"})) + test.S(t).ExpectEquals(parser.GetExplicitSchema(), "scm") + test.S(t).ExpectEquals(parser.GetExplicitTable(), "tbl") + test.S(t).ExpectEquals(parser.GetOptions(), "drop column b, add index idx(i)") + p := parser.(*AlterTableParser) + test.S(t).ExpectTrue(reflect.DeepEqual(p.alterTokens, []string{"drop column b", "add index idx(i)"})) } } diff --git a/go/sql/types.go b/go/sql/types.go index 3be1a44ca..f9a6d4b77 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -13,6 +13,41 @@ import ( "strings" ) +type ParserType int + +const ( + ParserTypeUnknown ParserType = iota + ParserTypeAlterTable + ParserTypeCreateTable +) + +func (t ParserType) String() string { + switch t { + case ParserTypeAlterTable: + return "alter" + case ParserTypeCreateTable: + return "create-table" + default: + return "" + } +} + +type Parser interface { + Type() ParserType + ParseStatement() error + GetNonTrivialRenames() map[string]string + HasNonTrivialRenames() bool + DroppedColumnsMap() map[string]bool + IsRenameTable() bool + IsAutoIncrementDefined() bool + GetExplicitSchema() string + HasExplicitSchema() bool + GetExplicitTable() string + HasExplicitTable() bool + HasForeignKeys() bool + GetOptions() string +} + type ColumnType int const ( diff --git a/localtests/create-table/create.sql b/localtests/create-table/create.sql new file mode 100644 index 000000000..138de0324 --- /dev/null +++ b/localtests/create-table/create.sql @@ -0,0 +1,29 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + a int not null, + b int not null, + primary key(id) +) auto_increment=1; + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin +insert into gh_ost_test (id, a, b) values (null, 2,3); +insert into gh_ost_test (id, a, b) values (null, 2,4); +insert into gh_ost_test (id, a, b) values (null, 2,5); +insert into gh_ost_test (id, a, b) values (null, 2,6); +insert into gh_ost_test (id, a, b) values (null, 2,7); +insert into gh_ost_test (id, a, b) values (null, 2,8); +insert into gh_ost_test (id, a, b) values (null, 2,9); +insert into gh_ost_test (id, a, b) values (null, 2,0); +insert into gh_ost_test (id, a, b) values (null, 2,1); +insert into gh_ost_test (id, a, b) values (null, 2,2); +end ;; diff --git a/localtests/create-table/extra_args b/localtests/create-table/extra_args new file mode 100644 index 000000000..e0885e421 --- /dev/null +++ b/localtests/create-table/extra_args @@ -0,0 +1,2 @@ +--alter="" +--create-table="(id int auto_increment, a int not null, b int not null, primary key(id), KEY new_index (a))"