diff --git a/internal/locate/region_cache_test.go b/internal/locate/region_cache_test.go index caefb2ae53..9f53690144 100644 --- a/internal/locate/region_cache_test.go +++ b/internal/locate/region_cache_test.go @@ -51,6 +51,7 @@ import ( "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" ) @@ -1001,7 +1002,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() { s.Equal(ctxTiFlash.Peer.Id, s.peer1) ctxTiFlash.Peer.Role = metapb.PeerRole_Learner r := ctxTiFlash.Meta - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}} reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr) @@ -1601,7 +1602,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() { ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0) s.NotNil(ctx) s.NoError(err) - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}}) s.Error(err) s.False(shouldRetry) diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index 86c5ad8bd9..d4a8c2c7ad 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -45,6 +45,7 @@ import ( "sync/atomic" "time" + "github.com/tikv/client-go/v2/oracle" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -102,6 +103,7 @@ type RegionRequestSender struct { regionCache *RegionCache apiVersion kvrpcpb.APIVersion client client.Client + readTSValidator oracle.ReadTSValidator storeAddr string rpcError error replicaSelector *replicaSelector @@ -190,11 +192,12 @@ func RecordRegionRequestRuntimeStats(stats map[tikvrpc.CmdType]*RPCRuntimeStats, } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { return &RegionRequestSender{ - regionCache: regionCache, - apiVersion: regionCache.apiVersion, - client: client, + regionCache: regionCache, + apiVersion: regionCache.apiVersion, + client: client, + readTSValidator: readTSValidator, } } @@ -1088,6 +1091,11 @@ func (s *RegionRequestSender) SendReqCtx( } } + if err = s.validateReadTS(bo.GetCtx(), req); err != nil { + logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err)) + return nil, nil, err + } + // If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration // so TiKV can give up the requests whose response TiDB cannot receive due to timeout. if req.Context.MaxExecutionDurationMs == 0 { @@ -1870,6 +1878,44 @@ func (s *RegionRequestSender) onRegionError(bo *retry.Backoffer, ctx *RPCContext return false, nil } +func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error { + if req.StoreTp == tikvrpc.TiDB { + // Skip the checking if the store type is TiDB. + return nil + } + + var readTS uint64 + switch req.Type { + case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock: + readTS = req.GetStartTS() + + // TODO: Check transactional write requests that has implicit read. + // case tikvrpc.CmdPessimisticLock: + // readTS = req.PessimisticLock().GetForUpdateTs() + // case tikvrpc.CmdPrewrite: + // inner := req.Prewrite() + // readTS = inner.GetForUpdateTs() + // if readTS == 0 { + // readTS = inner.GetStartVersion() + // } + // case tikvrpc.CmdCheckTxnStatus: + // inner := req.CheckTxnStatus() + // // TiKV uses the greater one of these three fields to update the max_ts. + // readTS = inner.GetLockTs() + // if inner.GetCurrentTs() != math.MaxUint64 && inner.GetCurrentTs() > readTS { + // readTS = inner.GetCurrentTs() + // } + // if inner.GetCallerStartTs() != math.MaxUint64 && inner.GetCallerStartTs() > readTS { + // readTS = inner.GetCallerStartTs() + // } + // case tikvrpc.CmdCheckSecondaryLocks, tikvrpc.CmdCleanup, tikvrpc.CmdBatchRollback: + // readTS = req.GetStartTS() + default: + return nil + } + return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope}) +} + type staleReadMetricsCollector struct { } diff --git a/internal/locate/region_request3_test.go b/internal/locate/region_request3_test.go index 38bd4ae4ba..39266b0f68 100644 --- a/internal/locate/region_request3_test.go +++ b/internal/locate/region_request3_test.go @@ -42,6 +42,7 @@ import ( "time" "unsafe" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -80,7 +81,9 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) + + s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return")) } func (s *testRegionRequestToThreeStoresSuite) TearDownTest() { @@ -145,7 +148,8 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s } func (s *testRegionRequestToThreeStoresSuite) TestForwarding() { - s.regionRequestSender.regionCache.enableForwarding = true + sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{}) + sender.regionCache.enableForwarding = true // First get the leader's addr from region cache leaderStore, leaderAddr := s.loadAndGetLeaderStore() diff --git a/internal/locate/region_request_state_test.go b/internal/locate/region_request_state_test.go index 49f5a4af30..a2b408790d 100644 --- a/internal/locate/region_request_state_test.go +++ b/internal/locate/region_request_state_test.go @@ -33,6 +33,7 @@ import ( "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" ) @@ -75,7 +76,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) s.setClient() s.injection = testRegionCacheFSMSuiteInjection{ unavailableStoreIDs: make(map[uint64]struct{}), diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go index 812e945c99..14b85ed49d 100644 --- a/internal/locate/region_request_test.go +++ b/internal/locate/region_request_test.go @@ -58,7 +58,9 @@ import ( "github.com/tikv/client-go/v2/internal/client/mockserver" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" + pd "github.com/tikv/pd/client" "google.golang.org/grpc" ) @@ -72,6 +74,7 @@ type testRegionRequestToSingleStoreSuite struct { store uint64 peer uint64 region uint64 + pdCli pd.Client cache *RegionCache bo *retry.Backoffer regionRequestSender *RegionRequestSender @@ -82,11 +85,11 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() { s.mvccStore = mocktikv.MustNewMVCCStore() s.cluster = mocktikv.NewCluster(s.mvccStore) s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster) - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)} - s.cache = NewRegionCache(pdCli) + s.pdCli = &CodecPDClient{mocktikv.NewPDClient(s.cluster)} + s.cache = NewRegionCache(s.pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) } func (s *testRegionRequestToSingleStoreSuite) TearDownTest() { @@ -497,7 +500,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa }() cli := client.NewRPCClient() - sender := NewRegionRequestSender(s.cache, cli) + sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{}) req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ Key: []byte("key"), Value: []byte("value"), @@ -516,7 +519,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa Client: client.NewRPCClient(), redirectAddr: addr, } - sender = NewRegionRequestSender(s.cache, client1) + sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{}) sender.SendReq(s.bo, req, region.Region, 3*time.Second) // cleanup @@ -702,7 +705,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() { cancel() }() req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1}) - regionRequestSender := NewRegionRequestSender(s.cache, fnClient) + regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{}) regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf)) regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort) } diff --git a/oracle/oracle.go b/oracle/oracle.go index 7ace335ec0..ad79488ed0 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -36,6 +36,7 @@ package oracle import ( "context" + "fmt" "time" ) @@ -64,12 +65,17 @@ type Oracle interface { GetExternalTimestamp(ctx context.Context) (uint64, error) SetExternalTimestamp(ctx context.Context, ts uint64) error - // ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts - // that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read, - // etc. + ReadTSValidator +} + +// ReadTSValidator is the interface for providing the ability for verifying whether a timestamp is safe to be used +// for readings, as part of the `Oracle` interface. +type ReadTSValidator interface { + // ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts + // that has been allocated by the oracle, so that it's safe to use this ts to perform read operations. // Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot // has been GCed. - ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error + ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error } // Future is a future which promises to return a timestamp. @@ -121,3 +127,30 @@ func GoTimeToTS(t time.Time) uint64 { func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 { return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond)) } + +// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always let the validation pass. +// Only use this when using RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available +// and the validation is not necessary. +type NoopReadTSValidator struct{} + +// ValidateReadTS implements the ReadTSValidator interface. +func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error { + return nil +} + +// ErrFutureTSRead is returned when the read timestamp is set to a future time. +type ErrFutureTSRead struct { + ReadTS uint64 + CurrentTS uint64 +} + +func (e ErrFutureTSRead) Error() string { + return fmt.Sprintf("cannot set read timestamp to a future time, readTS: %d, currentTS: %d", e.ReadTS, e.CurrentTS) +} + +// ErrLatestStaleRead is returned when the read timestamp is set to max uint64 for stale read. +type ErrLatestStaleRead struct{} + +func (ErrLatestStaleRead) Error() string { + return "cannot set read ts to max uint64 for stale read" +} diff --git a/oracle/oracles/export_test.go b/oracle/oracles/export_test.go index f6479d5555..8b5e1d6337 100644 --- a/oracle/oracles/export_test.go +++ b/oracle/oracles/export_test.go @@ -62,9 +62,9 @@ func NewEmptyPDOracle() oracle.Oracle { func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) { switch o := oc.(type) { case *pdOracle: - now := &lastTSO{ts, ts} + now := &lastTSO{ts, oracle.GetTimeFromTS(ts)} lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, NewLastTSOPointer(now)) lastTSPointer := lastTSInterface.(*lastTSOPointer) - lastTSPointer.store(&lastTSO{tso: ts, arrival: ts}) + lastTSPointer.store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)}) } } diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index 1e6b747c98..e916286ac3 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -136,13 +137,23 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) return l.getExternalTimestamp(ctx) } -func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := l.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index 5d01e26754..97042e24df 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -122,14 +123,24 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or return o.GetTimestampAsync(ctx, opt) } -// ValidateSnapshotReadTS implements oracle.Oracle interface. -func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +// ValidateReadTS implements ReadTSValidator interface. +func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := o.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index c069967e03..e8fb75c502 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -37,6 +37,7 @@ package oracles import ( "context" "fmt" + "math" "strings" "sync" stdatomic "sync/atomic" @@ -151,7 +152,7 @@ type pdOracle struct { // When the low resolution ts is not new enough and there are many concurrent stane read / snapshot read // operations that needs to validate the read ts, we can use this to avoid too many concurrent GetTS calls by - // reusing a result for different `ValidateSnapshotReadTS` calls. This can be done because that + // reusing a result for different `ValidateReadTS` calls. This can be done because that // we don't require the ts for validation to be strictly the latest one. // Note that the result can't be reused for different txnScopes. The txnScope is used as the key. tsForValidation singleflight.Group @@ -160,7 +161,7 @@ type pdOracle struct { // lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. type lastTSO struct { tso uint64 - arrival uint64 + arrival time.Time } // lastTSOPointer wrap the lastTSO struct into a pointer. @@ -299,17 +300,13 @@ func (o *pdOracle) getTimestamp(ctx context.Context, txnScope string) (uint64, e return oracle.ComposeTS(physical, logical), nil } -func (o *pdOracle) getArrivalTimestamp() uint64 { - return oracle.GoTimeToTS(time.Now()) -} - func (o *pdOracle) setLastTS(ts uint64, txnScope string) { if txnScope == "" { txnScope = oracle.GlobalTxnScope } current := &lastTSO{ tso: ts, - arrival: o.getArrivalTimestamp(), + arrival: time.Now(), } lastTSInterface, ok := o.lastTSMap.Load(txnScope) if !ok { @@ -320,9 +317,12 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) { lastTSPointer := lastTSInterface.(*lastTSOPointer) for { last := lastTSPointer.load() - if current.tso <= last.tso || current.arrival <= last.arrival { + if current.tso <= last.tso { return } + if last.arrival.After(current.arrival) { + current.arrival = last.arrival + } if lastTSPointer.compareAndSwap(last, current) { return } @@ -587,8 +587,11 @@ func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64 if !ok { return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope) } - ts, arrivalTS := last.tso, last.arrival - arrivalTime := oracle.GetTimeFromTS(arrivalTS) + return o.getStaleTimestampWithLastTS(last, prevSecond) +} + +func (o *pdOracle) getStaleTimestampWithLastTS(last *lastTSO, prevSecond uint64) (uint64, error) { + ts, arrivalTime := last.tso, last.arrival physicalTime := oracle.GetTimeFromTS(ts) if uint64(physicalTime.Unix()) <= prevSecond { return 0, errors.Errorf("invalid prevSecond %v", prevSecond) @@ -643,22 +646,34 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op } } -func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { - latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) - // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. +func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + + latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope) + // If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check. // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. - if err != nil || readTS > latestTS { + if !exists || readTS > latestTSInfo.tso { currentTS, err := o.getCurrentTSForValidation(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } - o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + if isStaleRead { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + } if readTS > currentTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } - } else { - estimatedCurrentTS, err := o.getStaleTimestamp(opt.TxnScope, 0) + } else if isStaleRead { + estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0) if err != nil { logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval", zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) @@ -669,6 +684,9 @@ func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, op return nil } +// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments the update interval of low resolution +// ts, if necessary, to suite the usage of stale read. +// This method is not supposed to be called when performing non-stale-read operations. func (o *pdOracle) adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS uint64, currentTS uint64, now time.Time) { requiredStaleness := oracle.GetTimeFromTS(currentTS).Sub(oracle.GetTimeFromTS(readTS)) diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 72b6dfab23..27ec8991a9 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -237,40 +237,54 @@ func TestAdaptiveUpdateTSInterval(t *testing.T) { assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) } -func TestValidateSnapshotReadTS(t *testing.T) { - pdClient := MockPdClient{} - o, err := NewPdOracle(&pdClient, &PDOracleOptions{ - UpdateInterval: time.Second * 2, - }) - assert.NoError(t, err) - defer o.Close() - - ctx := context.Background() - opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} - ts, err := o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - assert.GreaterOrEqual(t, ts, uint64(1)) +func TestValidateReadTS(t *testing.T) { + testImpl := func(staleRead bool) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + assert.NoError(t, err) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // Always returns error for MaxUint64 + err = o.ValidateReadTS(ctx, math.MaxUint64, staleRead, opt) + if staleRead { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } - err = o.ValidateSnapshotReadTS(ctx, 1, opt) - assert.NoError(t, err) - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to - // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. - err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) - assert.NoError(t, err) - // It can't pass if the readTS is newer than previous ts + 2. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) - assert.Error(t, err) + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, 1, staleRead, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateReadTS(ctx, ts+1, staleRead, opt) + assert.NoError(t, err) + // It can't pass if the readTS is newer than previous ts + 2. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateReadTS(ctx, ts+2, staleRead, opt) + assert.Error(t, err) + + // Simulate other PD clients requests a timestamp. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateReadTS(ctx, ts+3, staleRead, opt) + assert.NoError(t, err) + } - // Simulate other PD clients requests a timestamp. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - pdClient.logicalTimestamp.Add(2) - err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) - assert.NoError(t, err) + testImpl(true) + testImpl(false) } type MockPDClientWithPause struct { @@ -292,7 +306,7 @@ func (c *MockPDClientWithPause) Resume() { c.mu.Unlock() } -func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { +func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { pdClient := &MockPDClientWithPause{} o, err := NewPdOracle(pdClient, &PDOracleOptions{ UpdateInterval: time.Second * 2, @@ -304,7 +318,7 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { asyncValidate := func(ctx context.Context, readTS uint64) chan error { ch := make(chan error, 1) go func() { - err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + err := o.ValidateReadTS(ctx, readTS, true, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ch <- err }() return ch @@ -313,7 +327,7 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { noResult := func(ch chan error) { select { case <-ch: - assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked") default: } } @@ -390,3 +404,79 @@ func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { } } } + +func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // Validating read ts for non-stale-read requests must not trigger updating the adaptive update interval of + // low resolution ts. + mustNoNotify := func() { + select { + case <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Fail(t, "expects not notifying shrinking update interval immediately, but message was received") + default: + } + } + + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, ts, false, opt) + assert.NoError(t, err) + mustNoNotify() + + // It loads `ts + 1` from the mock PD, and the check cannot pass. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.Error(t, err) + mustNoNotify() + + // Do the check again. It loads `ts + 2` from the mock PD, and the check passes. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.NoError(t, err) + mustNoNotify() +} + +func TestSetLastTSAlwaysPushTS(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + var wg sync.WaitGroup + cancel := make(chan struct{}) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for { + select { + case <-cancel: + return + default: + } + ts, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + lastTS, found := o.getLastTS(oracle.GlobalTxnScope) + assert.True(t, found) + assert.GreaterOrEqual(t, lastTS, ts) + } + }() + } + time.Sleep(time.Second) + close(cancel) + wg.Wait() +} diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index ed8a6092e3..dbc4a48bfc 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -48,6 +48,7 @@ import ( "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" "google.golang.org/grpc" @@ -658,7 +659,7 @@ func (c *Client) CompareAndSwap(ctx context.Context, key, previousValue, newValu func (c *Client) sendReq(ctx context.Context, key []byte, req *tikvrpc.Request, reverse bool) (*tikvrpc.Response, *locate.KeyLocation, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { var loc *locate.KeyLocation var err error @@ -752,7 +753,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw }) } - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) resp, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) @@ -802,7 +803,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw // TODO: Is there any better way to avoid duplicating code with func `sendReq` ? func (c *Client) sendDeleteRangeReq(ctx context.Context, startKey []byte, endKey []byte, opts *rawOptions) (*tikvrpc.Response, []byte, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { loc, err := c.regionCache.LocateKey(bo, startKey) if err != nil { @@ -900,7 +901,7 @@ func (c *Client) doBatchPut(bo *retry.Backoffer, batch kvrpc.Batch, opts *rawOpt Ttl: ttl, }) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) req.ApiVersion = c.apiVersion resp, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) diff --git a/tikv/kv.go b/tikv/kv.go index 019a6c2ae8..6b119c1c5d 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -419,7 +419,7 @@ func (s *KVStore) SupportDeleteRange() (supported bool) { // SendReq sends a request to locate. func (s *KVStore) SendReq(bo *Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration) (*tikvrpc.Response, error) { - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) return sender.SendReq(bo, req, regionID, timeout) } diff --git a/tikv/region.go b/tikv/region.go index 4fd0a25642..9303a02a89 100644 --- a/tikv/region.go +++ b/tikv/region.go @@ -40,6 +40,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" ) @@ -138,8 +139,8 @@ func GetStoreTypeByMeta(store *metapb.Store) tikvrpc.EndpointType { } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { - return locate.NewRegionRequestSender(regionCache, client) +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { + return locate.NewRegionRequestSender(regionCache, client, readTSValidator) } // LoadShuttingDown atomically loads ShuttingDown. diff --git a/tikv/split_region.go b/tikv/split_region.go index 413747a911..28701d6a00 100644 --- a/tikv/split_region.go +++ b/tikv/split_region.go @@ -148,7 +148,7 @@ func (s *KVStore) batchSendSingleRegion(bo *Backoffer, batch kvrpc.Batch, scatte RequestSource: util.RequestSourceFromCtx(bo.GetCtx()), }) - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) resp, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) batchResp := kvrpc.BatchResult{Response: resp} diff --git a/txnkv/transaction/commit.go b/txnkv/transaction/commit.go index 5dfff0e03f..4865f3bbcf 100644 --- a/txnkv/transaction/commit.go +++ b/txnkv/transaction/commit.go @@ -86,7 +86,7 @@ func (actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Backoffer, tBegin := time.Now() attempts := 0 - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) for { attempts++ reqBegin := time.Now() diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index e0eb669a11..9d1620a109 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -166,7 +166,7 @@ func (action actionPessimisticLock) handleSingleBatch(c *twoPhaseCommitter, bo * time.Sleep(300 * time.Millisecond) return errors.WithStack(&tikverr.ErrWriteConflict{WriteConflict: nil}) } - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) startTime := time.Now() resp, err := sender.SendReq(bo, req, batch.region, client.ReadTimeoutShort) diagCtx.reqDuration = time.Since(startTime) diff --git a/txnkv/transaction/prewrite.go b/txnkv/transaction/prewrite.go index 8b992cf080..feb4882366 100644 --- a/txnkv/transaction/prewrite.go +++ b/txnkv/transaction/prewrite.go @@ -237,7 +237,7 @@ func (action actionPrewrite) handleSingleBatch(c *twoPhaseCommitter, bo *retry.B attempts := 0 req := c.buildPrewriteRequest(batch, txnSize) - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) var resolvingRecordToken *int defer func() { if err != nil { diff --git a/txnkv/txnsnapshot/client_helper.go b/txnkv/txnsnapshot/client_helper.go index 34a6636d58..7c3765f3d5 100644 --- a/txnkv/txnsnapshot/client_helper.go +++ b/txnkv/txnsnapshot/client_helper.go @@ -40,6 +40,7 @@ import ( "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnlock" "github.com/tikv/client-go/v2/util" @@ -63,6 +64,7 @@ type ClientHelper struct { client client.Client resolveLite bool locate.RegionRequestRuntimeStats + oracle oracle.Oracle } // NewClientHelper creates a helper instance. @@ -74,6 +76,7 @@ func NewClientHelper(store kvstore, resolvedLocks *util.TSSet, committedLocks *u committedLocks: committedLocks, client: store.GetTiKVClient(), resolveLite: resolveLite, + oracle: store.GetOracle(), } } @@ -136,7 +139,7 @@ func (ch *ClientHelper) ResolveLocksDone(callerStartTS uint64, token int) { // SendReqCtx wraps the SendReqCtx function and use the resolved lock result in the kvrpcpb.Context. func (ch *ClientHelper) SendReqCtx(bo *retry.Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration, et tikvrpc.EndpointType, directStoreAddr string, opts ...locate.StoreSelectorOption) (*tikvrpc.Response, *locate.RPCContext, string, error) { - sender := locate.NewRegionRequestSender(ch.regionCache, ch.client) + sender := locate.NewRegionRequestSender(ch.regionCache, ch.client, ch.oracle) if len(directStoreAddr) > 0 { sender.SetStoreAddr(directStoreAddr) } diff --git a/txnkv/txnsnapshot/scan.go b/txnkv/txnsnapshot/scan.go index 25579de4ca..0183cf5203 100644 --- a/txnkv/txnsnapshot/scan.go +++ b/txnkv/txnsnapshot/scan.go @@ -197,7 +197,7 @@ func (s *Scanner) getData(bo *retry.Backoffer) error { zap.String("nextEndKey", kv.StrKey(s.nextEndKey)), zap.Bool("reverse", s.reverse), zap.Uint64("txnStartTS", s.startTS())) - sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient(), s.snapshot.store.GetOracle()) var reqEndKey, reqStartKey []byte var loc *locate.KeyLocation var resolvingRecordToken *int