From e5497d21d2b99e24155d902821a6b4650225c100 Mon Sep 17 00:00:00 2001 From: husharp Date: Fri, 24 May 2024 15:23:18 +0800 Subject: [PATCH 1/2] add caller ID Signed-off-by: husharp --- tools/pd-ctl/pdctl/command/global.go | 8 ++++---- tools/pd-ctl/tests/global_test.go | 28 ++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index fa77df6a101..f7c04c3ca5c 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -33,7 +33,7 @@ import ( ) const ( - pdControlCallerID = "pd-ctl" + PDControlCallerID = "pd-ctl" clusterPrefix = "pd/api/v1/cluster" ) @@ -107,7 +107,7 @@ func initNewPDClient(cmd *cobra.Command, opts ...pd.ClientOption) error { if PDCli != nil { PDCli.Close() } - PDCli = pd.NewClient(pdControlCallerID, getEndpoints(cmd), opts...) + PDCli = pd.NewClient(PDControlCallerID, getEndpoints(cmd), opts...).WithCallerID(PDControlCallerID) return nil } @@ -122,7 +122,7 @@ func initNewPDClientWithTLS(cmd *cobra.Command, caPath, certPath, keyPath string // TODO: replace dialClient with the PD HTTP client completely. var dialClient = &http.Client{ - Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID), + Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, PDControlCallerID), } // RequireHTTPSClient creates a HTTPS client if the related flags are set @@ -153,7 +153,7 @@ func initHTTPSClient(caPath, certPath, keyPath string) error { } dialClient = &http.Client{ Transport: apiutil.NewCallerIDRoundTripper( - &http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID), + &http.Transport{TLSClientConfig: tlsConfig}, PDControlCallerID), } return nil } diff --git a/tools/pd-ctl/tests/global_test.go b/tools/pd-ctl/tests/global_test.go index f4f55e2af89..645f410c876 100644 --- a/tools/pd-ctl/tests/global_test.go +++ b/tools/pd-ctl/tests/global_test.go @@ -16,10 +16,12 @@ package tests import ( "context" + "encoding/json" "fmt" "net/http" "testing" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/utils/apiutil" @@ -27,22 +29,33 @@ import ( "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" cmd "github.com/tikv/pd/tools/pd-ctl/pdctl" + "github.com/tikv/pd/tools/pd-ctl/pdctl/command" "go.uber.org/zap" ) -const pdControlCallerID = "pd-ctl" - func TestSendAndGetComponent(t *testing.T) { re := require.New(t) handler := func(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { mux := http.NewServeMux() + mux.HandleFunc("/pd/api/v1/cluster", func(w http.ResponseWriter, r *http.Request) { + callerID := apiutil.GetCallerIDOnHTTP(r) + for k := range r.Header { + log.Info("header", zap.String("key", k)) + } + log.Info("caller id", zap.String("caller-id", callerID)) + re.Equal(command.PDControlCallerID, callerID) + cluster := &metapb.Cluster{Id: 1} + clusterBytes, err := json.Marshal(cluster) + re.NoError(err) + w.Write(clusterBytes) + }) mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { callerID := apiutil.GetCallerIDOnHTTP(r) for k := range r.Header { log.Info("header", zap.String("key", k)) } log.Info("caller id", zap.String("caller-id", callerID)) - re.Equal(pdControlCallerID, callerID) + re.Equal(command.PDControlCallerID, callerID) fmt.Fprint(w, callerID) }) info := apiutil.APIServiceGroup{ @@ -67,5 +80,12 @@ func TestSendAndGetComponent(t *testing.T) { args := []string{"-u", pdAddr, "health"} output, err := ExecuteCommand(cmd, args...) re.NoError(err) - re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output)) + re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output)) + + args = []string{"-u", pdAddr, "cluster"} + output, err = ExecuteCommand(cmd, args...) + re.NoError(err) + re.Equal(fmt.Sprintf("%s\n", `{ + "id": 1 +}`), string(output)) } From 4daa1a7da5d25c1c4d38e3c169b929ef95b7b22d Mon Sep 17 00:00:00 2001 From: husharp Date: Fri, 24 May 2024 15:59:00 +0800 Subject: [PATCH 2/2] add comment for test Signed-off-by: husharp --- tools/pd-ctl/tests/global_test.go | 33 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tools/pd-ctl/tests/global_test.go b/tools/pd-ctl/tests/global_test.go index 645f410c876..6987267ea54 100644 --- a/tools/pd-ctl/tests/global_test.go +++ b/tools/pd-ctl/tests/global_test.go @@ -22,7 +22,6 @@ import ( "testing" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/assertutil" @@ -30,31 +29,30 @@ import ( "github.com/tikv/pd/server" cmd "github.com/tikv/pd/tools/pd-ctl/pdctl" "github.com/tikv/pd/tools/pd-ctl/pdctl/command" - "go.uber.org/zap" ) func TestSendAndGetComponent(t *testing.T) { re := require.New(t) handler := func(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { mux := http.NewServeMux() + // check pd http sdk api mux.HandleFunc("/pd/api/v1/cluster", func(w http.ResponseWriter, r *http.Request) { callerID := apiutil.GetCallerIDOnHTTP(r) - for k := range r.Header { - log.Info("header", zap.String("key", k)) - } - log.Info("caller id", zap.String("caller-id", callerID)) re.Equal(command.PDControlCallerID, callerID) cluster := &metapb.Cluster{Id: 1} clusterBytes, err := json.Marshal(cluster) re.NoError(err) w.Write(clusterBytes) }) + // check http client api + // TODO: remove this comment after replacing dialClient with the PD HTTP client completely. mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { callerID := apiutil.GetCallerIDOnHTTP(r) - for k := range r.Header { - log.Info("header", zap.String("key", k)) - } - log.Info("caller id", zap.String("caller-id", callerID)) + re.Equal(command.PDControlCallerID, callerID) + fmt.Fprint(w, callerID) + }) + mux.HandleFunc("/pd/api/v1/stores", func(w http.ResponseWriter, r *http.Request) { + callerID := apiutil.GetCallerIDOnHTTP(r) re.Equal(command.PDControlCallerID, callerID) fmt.Fprint(w, callerID) }) @@ -77,15 +75,20 @@ func TestSendAndGetComponent(t *testing.T) { }() cmd := cmd.GetRootCmd() - args := []string{"-u", pdAddr, "health"} + args := []string{"-u", pdAddr, "cluster"} output, err := ExecuteCommand(cmd, args...) re.NoError(err) + re.Equal(fmt.Sprintf("%s\n", `{ + "id": 1 +}`), string(output)) + + args = []string{"-u", pdAddr, "health"} + output, err = ExecuteCommand(cmd, args...) + re.NoError(err) re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output)) - args = []string{"-u", pdAddr, "cluster"} + args = []string{"-u", pdAddr, "store"} output, err = ExecuteCommand(cmd, args...) re.NoError(err) - re.Equal(fmt.Sprintf("%s\n", `{ - "id": 1 -}`), string(output)) + re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output)) }