diff --git a/client/tlsutil/tlsconfig.go b/client/tlsutil/tlsconfig.go index 5bf03dc4afc5..338c6050aa18 100644 --- a/client/tlsutil/tlsconfig.go +++ b/client/tlsutil/tlsconfig.go @@ -64,8 +64,8 @@ type TLSInfo struct { // should be left nil. In that case, tls.X509KeyPair will be used. parseFunc func([]byte, []byte) (tls.Certificate, error) - // AllowedCN is a CN which must be provided by a client. - AllowedCN string + // AllowedCNs is a list of CNs which must be provided by a client. + AllowedCNs []string } // ClientConfig generates a tls.Config object for use by an HTTP client. @@ -121,12 +121,14 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) { cfg.CipherSuites = info.CipherSuites } - if info.AllowedCN != "" { + if len(info.AllowedCNs) > 0 { cfg.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { for _, chains := range verifiedChains { if len(chains) != 0 { - if info.AllowedCN == chains[0].Subject.CommonName { - return nil + for _, allowedCN := range info.AllowedCNs { + if allowedCN == chains[0].Subject.CommonName { + return nil + } } } } @@ -162,8 +164,8 @@ type TLSConfig struct { CertPath string `toml:"cert-path" json:"cert-path"` // KeyPath is the path of file that contains X509 key in PEM format. KeyPath string `toml:"key-path" json:"key-path"` - // CertAllowedCN is a CN which must be provided by a client - CertAllowedCN []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"` + // CertAllowedCNs is a CN which must be provided by a client + CertAllowedCNs []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"` SSLCABytes []byte SSLCertBytes []byte @@ -194,16 +196,12 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { return nil, nil } - allowedCN, err := s.GetOneAllowedCN() - if err != nil { - return nil, err - } tlsInfo := TLSInfo{ CertFile: s.CertPath, KeyFile: s.KeyPath, TrustedCAFile: s.CAPath, - AllowedCN: allowedCN, + AllowedCNs: s.CertAllowedCNs, } tlsConfig, err := tlsInfo.ClientConfig() @@ -212,15 +210,3 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { } return tlsConfig, nil } - -// GetOneAllowedCN only gets the first one CN. -func (s TLSConfig) GetOneAllowedCN() (string, error) { - switch len(s.CertAllowedCN) { - case 1: - return s.CertAllowedCN[0], nil - case 0: - return "", nil - default: - return "", errs.ErrSecurityConfig.FastGenByArgs("only supports one CN") - } -} diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 5f852ff73593..df53dbee1934 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -52,8 +52,8 @@ type TLSConfig struct { CertPath string `toml:"cert-path" json:"cert-path"` // KeyPath is the path of file that contains X509 key in PEM format. KeyPath string `toml:"key-path" json:"key-path"` - // CertAllowedCN is a CN which must be provided by a client - CertAllowedCN []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"` + // CertAllowedCNs is a CN which must be provided by a client + CertAllowedCNs []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"` SSLCABytes []byte SSLCertBytes []byte @@ -65,16 +65,12 @@ func (s TLSConfig) ToTLSInfo() (*transport.TLSInfo, error) { if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { return nil, nil } - allowedCN, err := s.GetOneAllowedCN() - if err != nil { - return nil, err - } return &transport.TLSInfo{ CertFile: s.CertPath, KeyFile: s.KeyPath, TrustedCAFile: s.CAPath, - AllowedCN: allowedCN, + AllowedCNs: s.CertAllowedCNs, }, nil } @@ -114,18 +110,6 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { return tlsConfig, nil } -// GetOneAllowedCN only gets the first one CN. -func (s TLSConfig) GetOneAllowedCN() (string, error) { - switch len(s.CertAllowedCN) { - case 1: - return s.CertAllowedCN[0], nil - case 0: - return "", nil - default: - return "", errs.ErrSecurityConfig.FastGenByArgs("only supports one CN") - } -} - // GetClientConn returns a gRPC client connection. // creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be diff --git a/server/config/config.go b/server/config/config.go index eb38f326932a..bdfa4401b44c 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -726,21 +726,21 @@ func (c *Config) GenEmbedEtcdConfig() (*embed.Config, error) { cfg.QuotaBackendBytes = int64(c.QuotaBackendBytes) cfg.MaxRequestBytes = c.MaxRequestBytes - allowedCN, serr := c.Security.GetOneAllowedCN() - if serr != nil { - return nil, serr - } cfg.ClientTLSInfo.ClientCertAuth = len(c.Security.CAPath) != 0 cfg.ClientTLSInfo.TrustedCAFile = c.Security.CAPath cfg.ClientTLSInfo.CertFile = c.Security.CertPath cfg.ClientTLSInfo.KeyFile = c.Security.KeyPath - // Client no need to set the CN. (cfg.ClientTLSInfo.AllowedCN = allowedCN) + // Keep compatibility with https://github.com/tikv/pd/pull/2305 + // Only check client cert when there are multiple CNs. + if len(c.Security.CertAllowedCNs) > 1 { + cfg.ClientTLSInfo.AllowedCNs = c.Security.CertAllowedCNs + } + fmt.Println(" c.Security.CertAllowedCNs", c.Security.CertAllowedCNs) cfg.PeerTLSInfo.ClientCertAuth = len(c.Security.CAPath) != 0 cfg.PeerTLSInfo.TrustedCAFile = c.Security.CAPath cfg.PeerTLSInfo.CertFile = c.Security.CertPath cfg.PeerTLSInfo.KeyFile = c.Security.KeyPath - //nolint:staticcheck - cfg.PeerTLSInfo.AllowedCN = allowedCN + cfg.PeerTLSInfo.AllowedCNs = c.Security.CertAllowedCNs cfg.ForceNewCluster = c.ForceNewCluster cfg.ZapLoggerBuilder = embed.NewZapCoreLoggerBuilder(c.Logger, c.Logger.Core(), c.LogProps.Syncer) cfg.EnableGRPCGateway = c.EnableGRPCGateway diff --git a/tests/integrations/client/cert_opt.sh b/tests/integrations/client/cert_opt.sh index 3984e67f3ab0..02f72249db72 100755 --- a/tests/integrations/client/cert_opt.sh +++ b/tests/integrations/client/cert_opt.sh @@ -3,7 +3,7 @@ cert_dir="$2" function generate_certs() { if [[ ! -z "$cert_dir" ]]; then - cd "$cert_dir" || exit 255 # Change to the specified directory + cd "$cert_dir" || exit 255 # Change to the specified directory fi if ! [[ "$0" =~ "cert_opt.sh" ]]; then @@ -21,10 +21,10 @@ function generate_certs() { openssl req -new -x509 -key ca-key.pem -out ca.pem -days 1 -subj "/CN=ca" # pd-server openssl genpkey -algorithm RSA -out pd-server-key.pem - openssl req -new -key pd-server-key.pem -out pd-server.csr -subj "/CN=pd-server" + openssl req -new -key pd-server-key.pem -out pd-server.csr -subj "/CN=pd-server" # Add IP address as a SAN - echo "subjectAltName = IP:127.0.0.1" > extfile.cnf + echo "subjectAltName = IP:127.0.0.1" >extfile.cnf openssl x509 -req -in pd-server.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out pd-server.pem -days 1 -extfile extfile.cnf # Clean up the temporary extension file @@ -34,11 +34,16 @@ function generate_certs() { openssl genpkey -algorithm RSA -out client-key.pem openssl req -new -key client-key.pem -out client.csr -subj "/CN=client" openssl x509 -req -in client.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out client.pem -days 1 + + # client2 + openssl genpkey -algorithm RSA -out tidb-client-key.pem + openssl req -new -key tidb-client-key.pem -out tidb-client.csr -subj "/CN=tidb" + openssl x509 -req -in tidb-client.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out tidb-client.pem -days 1 } function cleanup_certs() { if [[ ! -z "$cert_dir" ]]; then - cd "$cert_dir" || exit 255 # Change to the specified directory + cd "$cert_dir" || exit 255 # Change to the specified directory fi rm -f ca.pem ca-key.pem ca.srl diff --git a/tests/integrations/client/client_tls_test.go b/tests/integrations/client/client_tls_test.go index aac864fdaa52..9e68a6561361 100644 --- a/tests/integrations/client/client_tls_test.go +++ b/tests/integrations/client/client_tls_test.go @@ -52,6 +52,12 @@ var ( TrustedCAFile: "./cert/ca.pem", } + testTiDBClientTLSInfo = transport.TLSInfo{ + KeyFile: "./cert/tidb-client-key.pem", + CertFile: "./cert/tidb-client.pem", + TrustedCAFile: "./cert/ca.pem", + } + testTLSInfoExpired = transport.TLSInfo{ KeyFile: "./cert-expired/pd-server-key.pem", CertFile: "./cert-expired/pd-server.pem", @@ -63,27 +69,14 @@ var ( // when all certs are atomically replaced by directory renaming. // And expects server to reject client requests, and vice versa. func TestTLSReloadAtomicReplace(t *testing.T) { + re := require.New(t) + // generate certs for _, path := range []string{certPath, certExpiredPath} { - if err := os.Mkdir(path, 0755); err != nil { - t.Fatal(err) - } - if err := exec.Command(certScript, "generate", path).Run(); err != nil { - t.Fatal(err) - } + cleanFunc := generateCerts(re, path) + defer cleanFunc() } - defer func() { - for _, path := range []string{certPath, certExpiredPath} { - if err := exec.Command(certScript, "cleanup", path).Run(); err != nil { - t.Fatal(err) - } - if err := os.RemoveAll(path); err != nil { - t.Fatal(err) - } - } - }() - re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() tmpDir := t.TempDir() @@ -123,6 +116,20 @@ func TestTLSReloadAtomicReplace(t *testing.T) { testTLSReload(ctx, re, cloneFunc, replaceFunc, revertFunc) } +func generateCerts(re *require.Assertions, path string) func() { + err := os.Mkdir(path, 0755) + re.NoError(err) + err = exec.Command(certScript, "generate", path).Run() + re.NoError(err) + + return func() { + err := exec.Command(certScript, "cleanup", path).Run() + re.NoError(err) + err = os.RemoveAll(path) + re.NoError(err) + } +} + func testTLSReload( ctx context.Context, re *require.Assertions, @@ -275,3 +282,64 @@ func copyFile(src, dst string) error { } return w.Sync() } + +func TestMultiCN(t *testing.T) { + re := require.New(t) + cleanFunc := generateCerts(re, certPath) + defer cleanFunc() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tmpDir := t.TempDir() + os.RemoveAll(tmpDir) + + certsDir := t.TempDir() + tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir) + re.NoError(terr) + clus, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, _ string) { + conf.Security.TLSConfig = grpcutil.TLSConfig{ + KeyPath: tlsInfo.KeyFile, + CertPath: tlsInfo.CertFile, + CAPath: tlsInfo.TrustedCAFile, + CertAllowedCNs: []string{"tidb", "pd-server"}, + } + conf.AdvertiseClientUrls = strings.ReplaceAll(conf.AdvertiseClientUrls, "http", "https") + conf.ClientUrls = strings.ReplaceAll(conf.ClientUrls, "http", "https") + conf.AdvertisePeerUrls = strings.ReplaceAll(conf.AdvertisePeerUrls, "http", "https") + conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https") + conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https") + }) + re.NoError(err) + defer clus.Destroy() + err = clus.RunInitialServers() + re.NoError(err) + clus.WaitLeader() + + testServers := clus.GetServers() + endpoints := make([]string, 0, len(testServers)) + for _, s := range testServers { + endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) + } + + // cn TiDB is allowed + re.NoError(testAllowedCN(ctx, endpoints, testTiDBClientTLSInfo)) + + // cn client is not allowed + re.Error(testAllowedCN(ctx, endpoints, testClientTLSInfo)) +} + +func testAllowedCN(ctx context.Context, endpoints []string, tls transport.TLSInfo) error { + ctx1, cancel1 := context.WithTimeout(ctx, 3*time.Second) + defer cancel1() + cli, err := pd.NewClientWithContext(ctx1, endpoints, pd.SecurityOption{ + CAPath: tls.TrustedCAFile, + CertPath: tls.CertFile, + KeyPath: tls.KeyFile, + }, pd.WithGRPCDialOptions(grpc.WithBlock())) + if err != nil { + return err + } + defer cli.Close() + _, err = cli.GetAllMembers(ctx1) + return err +}