Skip to content

Commit

Permalink
datetime: obey the evalengine's environment time (#14358)
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg authored Oct 25, 2023
1 parent 64085d8 commit cd2babb
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 77 deletions.
28 changes: 16 additions & 12 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func (t Time) FormatDecimal() decimal.Decimal {
return dec
}

func (t Time) ToDateTime() (out DateTime) {
return NewDateTimeFromStd(t.ToStdTime(time.Local))
func (t Time) ToDateTime(now time.Time) (out DateTime) {
return NewDateTimeFromStd(t.ToStdTime(now))
}

func (t Time) IsZero() bool {
Expand Down Expand Up @@ -421,9 +421,9 @@ func (t Time) toStdTime(year int, month time.Month, day int, loc *time.Location)
return time.Date(year, month, day, hours, minutes, secs, nsecs, loc)
}

func (t Time) ToStdTime(loc *time.Location) (out time.Time) {
year, month, day := time.Now().Date()
return t.toStdTime(year, month, day, loc)
func (t Time) ToStdTime(now time.Time) (out time.Time) {
year, month, day := now.Date()
return t.toStdTime(year, month, day, now.Location())
}

func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) {
Expand All @@ -444,20 +444,20 @@ func (d Date) ToStdTime(loc *time.Location) (out time.Time) {
return time.Date(d.Year(), time.Month(d.Month()), d.Day(), 0, 0, 0, 0, loc)
}

func (dt DateTime) ToStdTime(loc *time.Location) time.Time {
func (dt DateTime) ToStdTime(now time.Time) time.Time {
zerodate := dt.Date.IsZero()
zerotime := dt.Time.IsZero()

switch {
case zerodate && zerotime:
return time.Time{}
case zerodate:
return dt.Time.ToStdTime(loc)
return dt.Time.ToStdTime(now)
case zerotime:
return dt.Date.ToStdTime(loc)
return dt.Date.ToStdTime(now.Location())
default:
year, month, day := dt.Date.Year(), time.Month(dt.Date.Month()), dt.Date.Day()
return dt.Time.toStdTime(year, month, day, loc)
return dt.Time.toStdTime(year, month, day, now.Location())
}
}

Expand Down Expand Up @@ -527,7 +527,10 @@ func (dt DateTime) Compare(dt2 DateTime) int {
// if we're comparing a time to a datetime, we need to normalize them
// both into datetimes; this normalization is not trivial because negative
// times result in a date change, so let the standard library handle this
return dt.ToStdTime(time.Local).Compare(dt2.ToStdTime(time.Local))

// Using the current time is OK here since the comparison is relative
now := time.Now()
return dt.ToStdTime(now).Compare(dt2.ToStdTime(now))
}
if cmp := dt.Date.Compare(dt2.Date); cmp != 0 {
return cmp
Expand Down Expand Up @@ -559,9 +562,10 @@ func (dt DateTime) Round(p int) (r DateTime) {
r = dt
if n == 1e9 {
r.Time.nanosecond = 0
return NewDateTimeFromStd(r.ToStdTime(time.Local).Add(time.Second))
r.addInterval(&Interval{timeparts: timeparts{sec: 1}, unit: IntervalSecond})
} else {
r.Time.nanosecond = uint32(n)
}
r.Time.nanosecond = uint32(n)
return r
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ func (v *Value) MarshalDate() string {

func (v *Value) MarshalDateTime() string {
if dt, ok := v.DateTime(); ok {
return dt.ToStdTime(time.Local).Format("2006-01-02 15:04:05.000000")
return dt.ToStdTime(time.Now()).Format("2006-01-02 15:04:05.000000")
}
return ""
}
Expand Down
16 changes: 8 additions & 8 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now)
goto done
}
}
if elseOffset != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now)
} else {
env.vm.stack[env.vm.sp-stackDepth] = nil
}
Expand Down Expand Up @@ -1110,7 +1110,7 @@ func (asm *assembler) Convert_xD(offset int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
d := evalToDate(env.vm.stack[env.vm.sp-offset])
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
if d == nil {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1125,7 +1125,7 @@ func (asm *assembler) Convert_xD_nz(offset int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
d := evalToDate(env.vm.stack[env.vm.sp-offset])
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
if d == nil || d.isZero() {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1140,7 +1140,7 @@ func (asm *assembler) Convert_xDT(offset, prec int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
if dt == nil {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand All @@ -1155,7 +1155,7 @@ func (asm *assembler) Convert_xDT_nz(offset, prec int) {
// Need to explicitly check here or we otherwise
// store a nil wrapper in an interface vs. a direct
// nil.
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
if dt == nil || dt.isZero() {
env.vm.stack[env.vm.sp-offset] = nil
} else {
Expand Down Expand Up @@ -4252,7 +4252,7 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
}

tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{})
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
env.vm.sp--
return 1
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
Expand All @@ -4274,7 +4274,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col
goto baddate
}

env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col)
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col, env.now)
env.vm.sp--
return 1

Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ func TestCompilerSingle(t *testing.T) {
},
}

tz, _ := time.LoadLocation("Europe/Madrid")

for _, tc := range testCases {
t.Run(tc.expression, func(t *testing.T) {
expr, err := sqlparser.ParseExpr(tc.expression)
Expand All @@ -478,6 +480,7 @@ func TestCompilerSingle(t *testing.T) {
}

env := evalengine.EmptyExpressionEnv()
env.SetTime(time.Date(2023, 10, 24, 12, 0, 0, 0, tz))
env.Row = tc.values

expected, err := env.Evaluate(evalengine.Deoptimize(converted))
Expand Down
11 changes: 6 additions & 5 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package evalengine

import (
"strconv"
"time"
"unicode/utf8"

"vitess.io/vitess/go/hack"
Expand Down Expand Up @@ -167,7 +168,7 @@ func evalIsTruthy(e eval) boolean {
}
}

func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) {
if e == nil {
return nil, nil
}
Expand Down Expand Up @@ -199,9 +200,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64:
return evalToInt64(e).toUint64(), nil
case sqltypes.Date:
return evalToDate(e), nil
return evalToDate(e, now), nil
case sqltypes.Datetime, sqltypes.Timestamp:
return evalToDateTime(e, -1), nil
return evalToDateTime(e, -1, now), nil
case sqltypes.Time:
return evalToTime(e, -1), nil
default:
Expand Down Expand Up @@ -329,7 +330,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
return nil, err
}
// Separate return here to avoid nil wrapped in interface type
d := evalToDate(e)
d := evalToDate(e, time.Now())
if d == nil {
return nil, nil
}
Expand All @@ -340,7 +341,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
return nil, err
}
// Separate return here to avoid nil wrapped in interface type
dt := evalToDateTime(e, -1)
dt := evalToDateTime(e, -1, time.Now())
if dt == nil {
return nil, nil
}
Expand Down
22 changes: 12 additions & 10 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package evalengine

import (
"time"

"vitess.io/vitess/go/hack"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/datetime"
Expand Down Expand Up @@ -92,12 +94,12 @@ func (e *evalTemporal) toJSON() *evalJSON {
}
}

func (e *evalTemporal) toDateTime(l int) *evalTemporal {
func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime, sqltypes.Date:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)}
case sqltypes.Time:
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(), prec: uint8(l)}
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)}
default:
panic("unreachable")
}
Expand All @@ -118,15 +120,15 @@ func (e *evalTemporal) toTime(l int) *evalTemporal {
}
}

func (e *evalTemporal) toDate() *evalTemporal {
func (e *evalTemporal) toDate(now time.Time) *evalTemporal {
switch e.SQLType() {
case sqltypes.Datetime:
dt := datetime.DateTime{Date: e.dt.Date}
return &evalTemporal{t: sqltypes.Date, dt: dt}
case sqltypes.Date:
return e
case sqltypes.Time:
dt := e.dt.Time.ToDateTime()
dt := e.dt.Time.ToDateTime(now)
dt.Time = datetime.Time{}
return &evalTemporal{t: sqltypes.Date, dt: dt}
default:
Expand All @@ -138,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
return e.dt.IsZero()
}

func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation) eval {
func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
var tmp *evalTemporal
var ok bool

Expand All @@ -150,7 +152,7 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
tmp = &evalTemporal{t: e.t}
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
tmp = e.toDateTime(int(e.prec))
tmp = e.toDateTime(int(e.prec), now)
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
}
if !ok {
Expand Down Expand Up @@ -324,10 +326,10 @@ func evalToTime(e eval, l int) *evalTemporal {
return nil
}

func evalToDateTime(e eval, l int) *evalTemporal {
func evalToDateTime(e eval, l int, now time.Time) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e.toDateTime(precision(l, int(e.prec)))
return e.toDateTime(precision(l, int(e.prec)), now)
case *evalBytes:
if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() {
return newEvalDateTime(t, l)
Expand Down Expand Up @@ -371,10 +373,10 @@ func evalToDateTime(e eval, l int) *evalTemporal {
return nil
}

func evalToDate(e eval) *evalTemporal {
func evalToDate(e eval, now time.Time) *evalTemporal {
switch e := e.(type) {
case *evalTemporal:
return e.toDate()
return e.toDate(now)
case *evalBytes:
if t, _ := datetime.ParseDate(e.string()); !t.IsZero() {
return newEvalDate(t)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/expr_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if dt := evalToDateTime(e, c.Length); dt != nil {
if dt := evalToDateTime(e, c.Length, env.now); dt != nil {
return dt, nil
}
return nil, nil
case "DATE":
if d := evalToDate(e); d != nil {
if d := evalToDate(e, env.now); d != nil {
return d, nil
}
return nil, nil
Expand Down
19 changes: 10 additions & 9 deletions go/vt/vtgate/evalengine/expr_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ func (env *ExpressionEnv) TypeOf(expr Expr, fields []*querypb.Field) (sqltypes.T
return ty, f, nil
}

func (env *ExpressionEnv) SetTime(now time.Time) {
// This function is called only once by NewExpressionEnv to ensure that all expressions in the same
// ExpressionEnv evaluate NOW() and similar SQL functions to the same value.
env.now = now
if tz := env.currentTimezone(); tz != nil {
env.now = env.now.In(tz)
}
}

// EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row
func EmptyExpressionEnv() *ExpressionEnv {
return NewExpressionEnv(context.Background(), nil, nil)
Expand All @@ -108,14 +117,6 @@ func EmptyExpressionEnv() *ExpressionEnv {
func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVariable, vc VCursor) *ExpressionEnv {
env := &ExpressionEnv{BindVars: bindVars, vc: vc}
env.user = callerid.ImmediateCallerIDFromContext(ctx)

// The current time for this ExpressionEnv is set only once, during creation.
// This is to ensure that all expressions in the same ExpressionEnv evaluate NOW()
// and similar SQL functions to the same value.
env.now = time.Now()

if tz := env.currentTimezone(); tz != nil {
env.now = env.now.In(tz)
}
env.SetTime(time.Now())
return env
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) {
return nil, nil
}
t, _ := c.typeof(env, nil)
return evalCoerce(result, t, ca.result().Collation)
return evalCoerce(result, t, ca.result().Collation, env.now)
}

func (c *CaseExpr) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
Expand Down
Loading

0 comments on commit cd2babb

Please sign in to comment.