Skip to content

Commit 49b8a44

Browse files
authored
Merge pull request #3305 from dolthub/angela/date_functions
Fix datetime functions to return correct results for `0` and `false`
2 parents 16344ed + 718eb86 commit 49b8a44

File tree

17 files changed

+1079
-245
lines changed

17 files changed

+1079
-245
lines changed

enginetest/queries/function_queries.go

Lines changed: 702 additions & 10 deletions
Large diffs are not rendered by default.

enginetest/queries/insert_queries.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,6 +2349,26 @@ var InsertScripts = []ScriptTest{
23492349
},
23502350
},
23512351
},
2352+
{
2353+
Name: "inserting zero date",
2354+
Dialect: "mysql",
2355+
SetUpScript: []string{
2356+
"create table t(d date)",
2357+
"insert into t values ('0000-00-00')",
2358+
"create table t2(d datetime)",
2359+
"insert into t2 values ('0000-00-00')",
2360+
},
2361+
Assertions: []ScriptTestAssertion{
2362+
{
2363+
Query: "select * from t",
2364+
Expected: []sql.Row{{types.ZeroTime}},
2365+
},
2366+
{
2367+
Query: "select * from t2",
2368+
Expected: []sql.Row{{types.ZeroTime}},
2369+
},
2370+
},
2371+
},
23522372
}
23532373

23542374
var InsertDuplicateKeyKeyless = []ScriptTest{

enginetest/queries/queries.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8352,13 +8352,6 @@ from typestable`,
83528352
},
83538353
},
83548354
},
8355-
{
8356-
// TODO: This goes past MySQL's range
8357-
Query: "select dayname('0000-00-00')",
8358-
Expected: []sql.Row{
8359-
{"Saturday"},
8360-
},
8361-
},
83628355
{
83638356
Query: "select * from mytable order by dayname(i)",
83648357
Expected: []sql.Row{

enginetest/queries/script_queries.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6175,10 +6175,10 @@ CREATE TABLE tab3 (
61756175
"0",
61766176
float64(0),
61776177
float64(0),
6178-
time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC),
6178+
time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC),
61796179
types.Timespan(0),
6180-
time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC),
6181-
time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC),
6180+
time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC),
6181+
time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC),
61826182
0,
61836183
"",
61846184
"",

enginetest/queries/update_queries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ var UpdateErrorScripts = []ScriptTest{
10111011
},
10121012
}
10131013

1014-
var ZeroTime = time.Date(0000, time.January, 1, 0, 0, 0, 0, time.UTC)
1014+
var ZeroTime = time.Date(0000, 0, 0, 0, 0, 0, 0, time.UTC)
10151015
var Jan1Noon = time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC)
10161016
var Dec15_1_30 = time.Date(2023, time.December, 15, 1, 30, 0, 0, time.UTC)
10171017
var Oct2Midnight = time.Date(2020, time.October, 2, 0, 0, 0, 0, time.UTC)

sql/expression/arithmetic.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,13 @@ func convertValueToType(ctx *sql.Context, typ sql.Type, val interface{}, isTimeT
440440
// the value is interpreted as 0, but we need to match the type of the other valid value
441441
// to avoid additional conversion, the nil value is handled in each operation
442442
}
443+
if types.IsTime(typ) {
444+
time, ok := cval.(time.Time)
445+
if !ok || time.Equal(types.ZeroTime) {
446+
ctx.Warn(1292, "Incorrect datetime value: '%s'", val)
447+
return nil
448+
}
449+
}
443450
return cval
444451
}
445452

@@ -462,6 +469,9 @@ func convertTimeTypeToString(val interface{}) interface{} {
462469
}
463470

464471
func plus(lval, rval interface{}) (interface{}, error) {
472+
if lval == nil || rval == nil {
473+
return nil, nil
474+
}
465475
switch l := lval.(type) {
466476
case uint8:
467477
switch r := rval.(type) {
@@ -536,6 +546,9 @@ func plus(lval, rval interface{}) (interface{}, error) {
536546
}
537547

538548
func minus(lval, rval interface{}) (interface{}, error) {
549+
if lval == nil || rval == nil {
550+
return nil, nil
551+
}
539552
switch l := lval.(type) {
540553
case uint8:
541554
switch r := rval.(type) {

sql/expression/function/days.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ func (t *ToDays) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9898
return nil, nil
9999
}
100100
d := date.(time.Time)
101+
if d.Equal(types.ZeroTime) {
102+
return nil, nil
103+
}
101104

102105
// Using zeroTime.Sub(date) doesn't work because it overflows time.Duration
103106
// so we need to calculate the number of days manually

sql/expression/function/days_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ func TestToDays(t *testing.T) {
4545
arg: expression.NewLiteral("-10", types.Int32),
4646
exp: nil,
4747
},
48-
48+
{
49+
arg: expression.NewLiteral("0", types.Int32),
50+
exp: nil,
51+
},
4952
{
5053
arg: expression.NewLiteral("0000-00-00", types.Text),
5154
exp: nil,

sql/expression/function/extract.go

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
114114

115115
switch unit {
116116
case "DAY":
117-
return dateTime.Day(), nil
117+
return day(dateTime), nil
118118
case "HOUR":
119119
return dateTime.Hour(), nil
120120
case "MINUTE":
@@ -124,57 +124,66 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
124124
case "MICROSECOND":
125125
return dateTime.Nanosecond() / 1000, nil
126126
case "QUARTER":
127-
return (int(dateTime.Month())-1)/3 + 1, nil
127+
return quarter(dateTime), nil
128128
case "MONTH":
129-
return int(dateTime.Month()), nil
129+
return month(dateTime), nil
130130
case "WEEK":
131-
date, err := getDate(ctx, expression.UnaryExpression{Child: td.RightChild}, row)
132-
if err != nil {
133-
return nil, err
134-
}
135-
yyyy, ok := year(date).(int32)
131+
yyyy, ok := year(dateTime).(int)
136132
if !ok {
137133
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid year")
138134
}
139-
mm, ok := month(date).(int32)
135+
mm, ok := month(dateTime).(int)
140136
if !ok {
141137
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid month")
142138
}
143-
dd, ok := day(date).(int32)
139+
dd, ok := day(dateTime).(int)
144140
if !ok {
145141
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid day")
146142
}
147-
yearForWeek, week := calcWeek(yyyy, mm, dd, weekBehaviourYear)
148-
if yearForWeek < yyyy {
143+
yr := int32(yyyy)
144+
yearForWeek, week := calcWeek(yr, int32(mm), int32(dd), weekBehaviourYear)
145+
if yearForWeek < yr {
149146
week = 0
150-
} else if yearForWeek > yyyy {
147+
} else if yearForWeek > yr {
151148
week = 53
152149
}
153150
return int(week), nil
154151
case "YEAR":
155-
return dateTime.Year(), nil
152+
return year(dateTime), nil
156153
case "DAY_HOUR":
157-
dd := dateTime.Day() * 1_00
154+
dd, ok := day(dateTime).(int)
155+
if !ok {
156+
return nil, sql.ErrInvalidArgumentDetails.New("DAY_HOUR", "invalid day")
157+
}
158158
hh := dateTime.Hour()
159-
return dd + hh, nil
159+
return (dd * 1_00) + hh, nil
160160
case "DAY_MINUTE":
161-
dd := dateTime.Day() * 1_00_00
161+
dd, ok := day(dateTime).(int)
162+
if !ok {
163+
return nil, sql.ErrInvalidArgumentDetails.New("DAY_MINUTE", "invalid day")
164+
}
162165
hh := dateTime.Hour() * 1_00
163166
mm := dateTime.Minute()
164-
return dd + hh + mm, nil
167+
return (dd * 1_00_00) + hh + mm, nil
165168
case "DAY_SECOND":
166-
dd := dateTime.Day() * 1_00_00_00
169+
dd, ok := day(dateTime).(int)
170+
if !ok {
171+
return nil, sql.ErrInvalidArgumentDetails.New("DAY_SECOND", "invalid day")
172+
}
167173
hh := dateTime.Hour() * 1_00_00
168174
mm := dateTime.Minute() * 1_00
169175
ss := dateTime.Second()
170-
return dd + hh + mm + ss, nil
176+
return (dd * 1_00_00_00) + hh + mm + ss, nil
171177
case "DAY_MICROSECOND":
172-
dd := dateTime.Day() * 1_00_00_00_000000
178+
dd, ok := day(dateTime).(int)
179+
if !ok {
180+
return nil, sql.ErrInvalidArgumentDetails.New("DAY_MICROSECOND", "invalid day")
181+
}
173182
hh := dateTime.Hour() * 1_00_00_000000
174183
mm := dateTime.Minute() * 1_00_000000
175184
ss := dateTime.Second() * 1_000000
176185
mmmmmm := dateTime.Nanosecond() / 1000
177-
return dd + hh + mm + ss + mmmmmm, nil
186+
return (dd * 1_00_00_00_000000) + hh + mm + ss + mmmmmm, nil
178187
case "HOUR_MINUTE":
179188
hh := dateTime.Hour() * 1_00
180189
mm := dateTime.Minute()
@@ -204,10 +213,15 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
204213
mmmmmm := dateTime.Nanosecond() / 1000
205214
return ss + mmmmmm, nil
206215
case "YEAR_MONTH":
207-
yyyy := dateTime.Year() * 1_00
208-
dateTime.Month()
209-
mm := int(dateTime.Month())
210-
return yyyy + mm, nil
216+
yyyy, ok := year(dateTime).(int)
217+
if !ok {
218+
return nil, sql.ErrInvalidArgumentDetails.New("YEAR_MONTH", "invalid year")
219+
}
220+
mm, ok := month(dateTime).(int)
221+
if !ok {
222+
return nil, sql.ErrInvalidArgumentDetails.New("YEAR_MONTH", "invalid month")
223+
}
224+
return (yyyy * 1_00) + mm, nil
211225
default:
212226
return nil, fmt.Errorf("invalid time unit")
213227
}

sql/expression/function/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ var BuiltIns = []sql.Function{
320320
sql.Function0{Name: "uuid", Fn: NewUUIDFunc},
321321
sql.Function0{Name: "uuid_short", Fn: NewUUIDShortFunc},
322322
sql.FunctionN{Name: "uuid_to_bin", Fn: NewUUIDToBin},
323-
sql.FunctionN{Name: "week", Fn: NewWeek},
324323
sql.Function1{Name: "values", Fn: NewValues},
325324
sql.Function1{Name: "validate_password_strength", Fn: NewValidatePasswordStrength},
326325
sql.Function1{Name: "variance", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarPop(e) }},
@@ -338,6 +337,7 @@ var BuiltIns = []sql.Function{
338337
sql.Function1{Name: "vector_to_string", Fn: vector.NewVectorToString},
339338
sql.Function1{Name: "from_vector", Fn: vector.NewVectorToString},
340339
sql.Function1{Name: "vec_totext", Fn: vector.NewVectorToString},
340+
sql.FunctionN{Name: "week", Fn: NewWeek},
341341
sql.Function1{Name: "weekday", Fn: NewWeekday},
342342
sql.Function1{Name: "weekofyear", Fn: NewWeekOfYear},
343343
sql.Function1{Name: "year", Fn: NewYear},

0 commit comments

Comments
 (0)