Skip to content

Commit

Permalink
Try to validate read ts for all RPC requests (#1513) (#1558)
Browse files Browse the repository at this point in the history
 

Signed-off-by: MyonKeminta <[email protected]>
Signed-off-by: you06 <[email protected]>
Signed-off-by: ekexium <[email protected]>

Co-authored-by: you06 <[email protected]>
Co-authored-by: MyonKeminta <[email protected]>
  • Loading branch information
3 people authored Jan 16, 2025
1 parent ccec7ef commit ec354dc
Show file tree
Hide file tree
Showing 20 changed files with 315 additions and 92 deletions.
5 changes: 3 additions & 2 deletions internal/locate/region_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/oracle"
pd "github.com/tikv/pd/client"
)

Expand Down Expand Up @@ -1001,7 +1002,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() {
s.Equal(ctxTiFlash.Peer.Id, s.peer1)
ctxTiFlash.Peer.Role = metapb.PeerRole_Learner
r := ctxTiFlash.Meta
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}}
reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr)

Expand Down Expand Up @@ -1601,7 +1602,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() {
ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0)
s.NotNil(ctx)
s.NoError(err)
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}})
s.Error(err)
s.False(shouldRetry)
Expand Down
54 changes: 50 additions & 4 deletions internal/locate/region_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"sync/atomic"
"time"

"github.com/tikv/client-go/v2/oracle"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -102,6 +103,7 @@ type RegionRequestSender struct {
regionCache *RegionCache
apiVersion kvrpcpb.APIVersion
client client.Client
readTSValidator oracle.ReadTSValidator
storeAddr string
rpcError error
replicaSelector *replicaSelector
Expand Down Expand Up @@ -190,11 +192,12 @@ func RecordRegionRequestRuntimeStats(stats map[tikvrpc.CmdType]*RPCRuntimeStats,
}

// NewRegionRequestSender creates a new sender.
func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender {
func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender {
return &RegionRequestSender{
regionCache: regionCache,
apiVersion: regionCache.apiVersion,
client: client,
regionCache: regionCache,
apiVersion: regionCache.apiVersion,
client: client,
readTSValidator: readTSValidator,
}
}

Expand Down Expand Up @@ -1088,6 +1091,11 @@ func (s *RegionRequestSender) SendReqCtx(
}
}

if err = s.validateReadTS(bo.GetCtx(), req); err != nil {
logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err))
return nil, nil, err
}

// If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration
// so TiKV can give up the requests whose response TiDB cannot receive due to timeout.
if req.Context.MaxExecutionDurationMs == 0 {
Expand Down Expand Up @@ -1870,6 +1878,44 @@ func (s *RegionRequestSender) onRegionError(bo *retry.Backoffer, ctx *RPCContext
return false, nil
}

func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error {
if req.StoreTp == tikvrpc.TiDB {
// Skip the checking if the store type is TiDB.
return nil
}

var readTS uint64
switch req.Type {
case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock:
readTS = req.GetStartTS()

// TODO: Check transactional write requests that has implicit read.
// case tikvrpc.CmdPessimisticLock:
// readTS = req.PessimisticLock().GetForUpdateTs()
// case tikvrpc.CmdPrewrite:
// inner := req.Prewrite()
// readTS = inner.GetForUpdateTs()
// if readTS == 0 {
// readTS = inner.GetStartVersion()
// }
// case tikvrpc.CmdCheckTxnStatus:
// inner := req.CheckTxnStatus()
// // TiKV uses the greater one of these three fields to update the max_ts.
// readTS = inner.GetLockTs()
// if inner.GetCurrentTs() != math.MaxUint64 && inner.GetCurrentTs() > readTS {
// readTS = inner.GetCurrentTs()
// }
// if inner.GetCallerStartTs() != math.MaxUint64 && inner.GetCallerStartTs() > readTS {
// readTS = inner.GetCallerStartTs()
// }
// case tikvrpc.CmdCheckSecondaryLocks, tikvrpc.CmdCleanup, tikvrpc.CmdBatchRollback:
// readTS = req.GetStartTS()
default:
return nil
}
return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope})
}

type staleReadMetricsCollector struct {
}

Expand Down
8 changes: 6 additions & 2 deletions internal/locate/region_request3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"time"
"unsafe"

"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/errorpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
Expand Down Expand Up @@ -80,7 +81,9 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})

s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return"))
}

func (s *testRegionRequestToThreeStoresSuite) TearDownTest() {
Expand Down Expand Up @@ -145,7 +148,8 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s
}

func (s *testRegionRequestToThreeStoresSuite) TestForwarding() {
s.regionRequestSender.regionCache.enableForwarding = true
sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{})
sender.regionCache.enableForwarding = true

// First get the leader's addr from region cache
leaderStore, leaderAddr := s.loadAndGetLeaderStore()
Expand Down
3 changes: 2 additions & 1 deletion internal/locate/region_request_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
)

Expand Down Expand Up @@ -75,7 +76,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})
s.setClient()
s.injection = testRegionCacheFSMSuiteInjection{
unavailableStoreIDs: make(map[uint64]struct{}),
Expand Down
15 changes: 9 additions & 6 deletions internal/locate/region_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ import (
"github.com/tikv/client-go/v2/internal/client/mockserver"
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
pd "github.com/tikv/pd/client"
"google.golang.org/grpc"
)

Expand All @@ -72,6 +74,7 @@ type testRegionRequestToSingleStoreSuite struct {
store uint64
peer uint64
region uint64
pdCli pd.Client
cache *RegionCache
bo *retry.Backoffer
regionRequestSender *RegionRequestSender
Expand All @@ -82,11 +85,11 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() {
s.mvccStore = mocktikv.MustNewMVCCStore()
s.cluster = mocktikv.NewCluster(s.mvccStore)
s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster)
pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster)}
s.cache = NewRegionCache(pdCli)
s.pdCli = &CodecPDClient{mocktikv.NewPDClient(s.cluster)}
s.cache = NewRegionCache(s.pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})
}

func (s *testRegionRequestToSingleStoreSuite) TearDownTest() {
Expand Down Expand Up @@ -497,7 +500,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
}()

cli := client.NewRPCClient()
sender := NewRegionRequestSender(s.cache, cli)
sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{})
req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{
Key: []byte("key"),
Value: []byte("value"),
Expand All @@ -516,7 +519,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
Client: client.NewRPCClient(),
redirectAddr: addr,
}
sender = NewRegionRequestSender(s.cache, client1)
sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{})
sender.SendReq(s.bo, req, region.Region, 3*time.Second)

// cleanup
Expand Down Expand Up @@ -702,7 +705,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() {
cancel()
}()
req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1})
regionRequestSender := NewRegionRequestSender(s.cache, fnClient)
regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf))
regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort)
}
Expand Down
41 changes: 37 additions & 4 deletions oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracle

import (
"context"
"fmt"
"time"
)

Expand Down Expand Up @@ -64,12 +65,17 @@ type Oracle interface {
GetExternalTimestamp(ctx context.Context) (uint64, error)
SetExternalTimestamp(ctx context.Context, ts uint64) error

// ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read,
// etc.
ReadTSValidator
}

// ReadTSValidator is the interface for providing the ability for verifying whether a timestamp is safe to be used
// for readings, as part of the `Oracle` interface.
type ReadTSValidator interface {
// ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform read operations.
// Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot
// has been GCed.
ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error
ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error
}

// Future is a future which promises to return a timestamp.
Expand Down Expand Up @@ -121,3 +127,30 @@ func GoTimeToTS(t time.Time) uint64 {
func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 {
return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond))
}

// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always let the validation pass.
// Only use this when using RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available
// and the validation is not necessary.
type NoopReadTSValidator struct{}

// ValidateReadTS implements the ReadTSValidator interface.
func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error {
return nil
}

// ErrFutureTSRead is returned when the read timestamp is set to a future time.
type ErrFutureTSRead struct {
ReadTS uint64
CurrentTS uint64
}

func (e ErrFutureTSRead) Error() string {
return fmt.Sprintf("cannot set read timestamp to a future time, readTS: %d, currentTS: %d", e.ReadTS, e.CurrentTS)
}

// ErrLatestStaleRead is returned when the read timestamp is set to max uint64 for stale read.
type ErrLatestStaleRead struct{}

func (ErrLatestStaleRead) Error() string {
return "cannot set read ts to max uint64 for stale read"
}
4 changes: 2 additions & 2 deletions oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func NewEmptyPDOracle() oracle.Oracle {
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
now := &lastTSO{ts, ts}
now := &lastTSO{ts, oracle.GetTimeFromTS(ts)}
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, NewLastTSOPointer(now))
lastTSPointer := lastTSInterface.(*lastTSOPointer)
lastTSPointer.store(&lastTSO{tso: ts, arrival: ts})
lastTSPointer.store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)})
}
}
15 changes: 13 additions & 2 deletions oracle/oracles/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracles

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -136,13 +137,23 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error)
return l.getExternalTimestamp(ctx)
}

func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := l.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return errors.Errorf("cannot set read timestamp to a future time")
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
return nil
}
17 changes: 14 additions & 3 deletions oracle/oracles/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracles

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -122,14 +123,24 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or
return o.GetTimestampAsync(ctx, opt)
}

// ValidateSnapshotReadTS implements oracle.Oracle interface.
func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
// ValidateReadTS implements ReadTSValidator interface.
func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := o.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return errors.Errorf("cannot set read timestamp to a future time")
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
return nil
}
Expand Down
Loading

0 comments on commit ec354dc

Please sign in to comment.