Skip to content

Commit

Permalink
security: support multiple CN support for TLS connections
Browse files Browse the repository at this point in the history
Signed-off-by: lhy1024 <[email protected]>
  • Loading branch information
lhy1024 committed Aug 12, 2024
1 parent fe1dbf8 commit 045723d
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 71 deletions.
34 changes: 10 additions & 24 deletions client/tlsutil/tlsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ type TLSInfo struct {
// should be left nil. In that case, tls.X509KeyPair will be used.
parseFunc func([]byte, []byte) (tls.Certificate, error)

// AllowedCN is a CN which must be provided by a client.
AllowedCN string
// AllowedCNs is a list of CNs which must be provided by a client.
AllowedCNs []string
}

// ClientConfig generates a tls.Config object for use by an HTTP client.
Expand Down Expand Up @@ -121,12 +121,14 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
cfg.CipherSuites = info.CipherSuites
}

if info.AllowedCN != "" {
if len(info.AllowedCNs) > 0 {
cfg.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
for _, chains := range verifiedChains {
if len(chains) != 0 {
if info.AllowedCN == chains[0].Subject.CommonName {
return nil
for _, allowedCN := range info.AllowedCNs {
if allowedCN == chains[0].Subject.CommonName {
return nil
}
}
}
}
Expand Down Expand Up @@ -162,8 +164,8 @@ type TLSConfig struct {
CertPath string `toml:"cert-path" json:"cert-path"`
// KeyPath is the path of file that contains X509 key in PEM format.
KeyPath string `toml:"key-path" json:"key-path"`
// CertAllowedCN is a CN which must be provided by a client
CertAllowedCN []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"`
// CertAllowedCNs is a CN which must be provided by a client
CertAllowedCNs []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"`

SSLCABytes []byte
SSLCertBytes []byte
Expand Down Expand Up @@ -194,16 +196,12 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) {
if len(s.CertPath) == 0 && len(s.KeyPath) == 0 {
return nil, nil
}
allowedCN, err := s.GetOneAllowedCN()
if err != nil {
return nil, err
}

tlsInfo := TLSInfo{
CertFile: s.CertPath,
KeyFile: s.KeyPath,
TrustedCAFile: s.CAPath,
AllowedCN: allowedCN,
AllowedCNs: s.CertAllowedCNs,
}

tlsConfig, err := tlsInfo.ClientConfig()
Expand All @@ -212,15 +210,3 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) {
}
return tlsConfig, nil
}

// GetOneAllowedCN only gets the first one CN.
func (s TLSConfig) GetOneAllowedCN() (string, error) {
switch len(s.CertAllowedCN) {
case 1:
return s.CertAllowedCN[0], nil
case 0:
return "", nil
default:
return "", errs.ErrSecurityConfig.FastGenByArgs("only supports one CN")
}
}
22 changes: 3 additions & 19 deletions pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ type TLSConfig struct {
CertPath string `toml:"cert-path" json:"cert-path"`
// KeyPath is the path of file that contains X509 key in PEM format.
KeyPath string `toml:"key-path" json:"key-path"`
// CertAllowedCN is a CN which must be provided by a client
CertAllowedCN []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"`
// CertAllowedCNs is a CN which must be provided by a client
CertAllowedCNs []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"`

SSLCABytes []byte
SSLCertBytes []byte
Expand All @@ -65,16 +65,12 @@ func (s TLSConfig) ToTLSInfo() (*transport.TLSInfo, error) {
if len(s.CertPath) == 0 && len(s.KeyPath) == 0 {
return nil, nil
}
allowedCN, err := s.GetOneAllowedCN()
if err != nil {
return nil, err
}

return &transport.TLSInfo{
CertFile: s.CertPath,
KeyFile: s.KeyPath,
TrustedCAFile: s.CAPath,
AllowedCN: allowedCN,
AllowedCNs: s.CertAllowedCNs,
}, nil
}

Expand Down Expand Up @@ -114,18 +110,6 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) {
return tlsConfig, nil
}

// GetOneAllowedCN only gets the first one CN.
func (s TLSConfig) GetOneAllowedCN() (string, error) {
switch len(s.CertAllowedCN) {
case 1:
return s.CertAllowedCN[0], nil
case 0:
return "", nil
default:
return "", errs.ErrSecurityConfig.FastGenByArgs("only supports one CN")
}
}

// GetClientConn returns a gRPC client connection.
// creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
Expand Down
14 changes: 7 additions & 7 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,21 +726,21 @@ func (c *Config) GenEmbedEtcdConfig() (*embed.Config, error) {
cfg.QuotaBackendBytes = int64(c.QuotaBackendBytes)
cfg.MaxRequestBytes = c.MaxRequestBytes

allowedCN, serr := c.Security.GetOneAllowedCN()
if serr != nil {
return nil, serr
}
cfg.ClientTLSInfo.ClientCertAuth = len(c.Security.CAPath) != 0
cfg.ClientTLSInfo.TrustedCAFile = c.Security.CAPath
cfg.ClientTLSInfo.CertFile = c.Security.CertPath
cfg.ClientTLSInfo.KeyFile = c.Security.KeyPath
// Client no need to set the CN. (cfg.ClientTLSInfo.AllowedCN = allowedCN)
// Keep compatibility with https://github.com/tikv/pd/pull/2305
// Only check client cert when there are multiple CNs.
if len(c.Security.CertAllowedCNs) > 1 {
cfg.ClientTLSInfo.AllowedCNs = c.Security.CertAllowedCNs
}
fmt.Println(" c.Security.CertAllowedCNs", c.Security.CertAllowedCNs)
cfg.PeerTLSInfo.ClientCertAuth = len(c.Security.CAPath) != 0
cfg.PeerTLSInfo.TrustedCAFile = c.Security.CAPath
cfg.PeerTLSInfo.CertFile = c.Security.CertPath
cfg.PeerTLSInfo.KeyFile = c.Security.KeyPath
//nolint:staticcheck
cfg.PeerTLSInfo.AllowedCN = allowedCN
cfg.PeerTLSInfo.AllowedCNs = c.Security.CertAllowedCNs
cfg.ForceNewCluster = c.ForceNewCluster
cfg.ZapLoggerBuilder = embed.NewZapCoreLoggerBuilder(c.Logger, c.Logger.Core(), c.LogProps.Syncer)
cfg.EnableGRPCGateway = c.EnableGRPCGateway
Expand Down
13 changes: 9 additions & 4 deletions tests/integrations/client/cert_opt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cert_dir="$2"

function generate_certs() {
if [[ ! -z "$cert_dir" ]]; then
cd "$cert_dir" || exit 255 # Change to the specified directory
cd "$cert_dir" || exit 255 # Change to the specified directory
fi

if ! [[ "$0" =~ "cert_opt.sh" ]]; then
Expand All @@ -21,10 +21,10 @@ function generate_certs() {
openssl req -new -x509 -key ca-key.pem -out ca.pem -days 1 -subj "/CN=ca"
# pd-server
openssl genpkey -algorithm RSA -out pd-server-key.pem
openssl req -new -key pd-server-key.pem -out pd-server.csr -subj "/CN=pd-server"
openssl req -new -key pd-server-key.pem -out pd-server.csr -subj "/CN=pd-server"

# Add IP address as a SAN
echo "subjectAltName = IP:127.0.0.1" > extfile.cnf
echo "subjectAltName = IP:127.0.0.1" >extfile.cnf
openssl x509 -req -in pd-server.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out pd-server.pem -days 1 -extfile extfile.cnf

# Clean up the temporary extension file
Expand All @@ -34,11 +34,16 @@ function generate_certs() {
openssl genpkey -algorithm RSA -out client-key.pem
openssl req -new -key client-key.pem -out client.csr -subj "/CN=client"
openssl x509 -req -in client.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out client.pem -days 1

# client2
openssl genpkey -algorithm RSA -out tidb-client-key.pem
openssl req -new -key tidb-client-key.pem -out tidb-client.csr -subj "/CN=tidb"
openssl x509 -req -in tidb-client.csr -CA ca.pem -CAkey ca-key.pem -CAcreateserial -out tidb-client.pem -days 1
}

function cleanup_certs() {
if [[ ! -z "$cert_dir" ]]; then
cd "$cert_dir" || exit 255 # Change to the specified directory
cd "$cert_dir" || exit 255 # Change to the specified directory
fi

rm -f ca.pem ca-key.pem ca.srl
Expand Down
102 changes: 85 additions & 17 deletions tests/integrations/client/client_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ var (
TrustedCAFile: "./cert/ca.pem",
}

testTiDBClientTLSInfo = transport.TLSInfo{
KeyFile: "./cert/tidb-client-key.pem",
CertFile: "./cert/tidb-client.pem",
TrustedCAFile: "./cert/ca.pem",
}

testTLSInfoExpired = transport.TLSInfo{
KeyFile: "./cert-expired/pd-server-key.pem",
CertFile: "./cert-expired/pd-server.pem",
Expand All @@ -63,27 +69,14 @@ var (
// when all certs are atomically replaced by directory renaming.
// And expects server to reject client requests, and vice versa.
func TestTLSReloadAtomicReplace(t *testing.T) {
re := require.New(t)

// generate certs
for _, path := range []string{certPath, certExpiredPath} {
if err := os.Mkdir(path, 0755); err != nil {
t.Fatal(err)
}
if err := exec.Command(certScript, "generate", path).Run(); err != nil {
t.Fatal(err)
}
cleanFunc := generateCerts(re, path)
defer cleanFunc()
}
defer func() {
for _, path := range []string{certPath, certExpiredPath} {
if err := exec.Command(certScript, "cleanup", path).Run(); err != nil {
t.Fatal(err)
}
if err := os.RemoveAll(path); err != nil {
t.Fatal(err)
}
}
}()

re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tmpDir := t.TempDir()
Expand Down Expand Up @@ -123,6 +116,20 @@ func TestTLSReloadAtomicReplace(t *testing.T) {
testTLSReload(ctx, re, cloneFunc, replaceFunc, revertFunc)
}

func generateCerts(re *require.Assertions, path string) func() {
err := os.Mkdir(path, 0755)
re.NoError(err)
err = exec.Command(certScript, "generate", path).Run()
re.NoError(err)

return func() {
err := exec.Command(certScript, "cleanup", path).Run()
re.NoError(err)
err = os.RemoveAll(path)
re.NoError(err)
}
}

func testTLSReload(
ctx context.Context,
re *require.Assertions,
Expand Down Expand Up @@ -275,3 +282,64 @@ func copyFile(src, dst string) error {
}
return w.Sync()
}

func TestMultiCN(t *testing.T) {
re := require.New(t)
cleanFunc := generateCerts(re, certPath)
defer cleanFunc()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tmpDir := t.TempDir()
os.RemoveAll(tmpDir)

certsDir := t.TempDir()
tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir)
re.NoError(terr)
clus, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, _ string) {
conf.Security.TLSConfig = grpcutil.TLSConfig{
KeyPath: tlsInfo.KeyFile,
CertPath: tlsInfo.CertFile,
CAPath: tlsInfo.TrustedCAFile,
CertAllowedCNs: []string{"tidb", "pd-server"},
}
conf.AdvertiseClientUrls = strings.ReplaceAll(conf.AdvertiseClientUrls, "http", "https")
conf.ClientUrls = strings.ReplaceAll(conf.ClientUrls, "http", "https")
conf.AdvertisePeerUrls = strings.ReplaceAll(conf.AdvertisePeerUrls, "http", "https")
conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https")
conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https")
})
re.NoError(err)
defer clus.Destroy()
err = clus.RunInitialServers()
re.NoError(err)
clus.WaitLeader()

testServers := clus.GetServers()
endpoints := make([]string, 0, len(testServers))
for _, s := range testServers {
endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls)
}

// cn TiDB is allowed
re.NoError(testAllowedCN(ctx, endpoints, testTiDBClientTLSInfo))

// cn client is not allowed
re.Error(testAllowedCN(ctx, endpoints, testClientTLSInfo))
}

func testAllowedCN(ctx context.Context, endpoints []string, tls transport.TLSInfo) error {
ctx1, cancel1 := context.WithTimeout(ctx, 3*time.Second)
defer cancel1()
cli, err := pd.NewClientWithContext(ctx1, endpoints, pd.SecurityOption{
CAPath: tls.TrustedCAFile,
CertPath: tls.CertFile,
KeyPath: tls.KeyFile,
}, pd.WithGRPCDialOptions(grpc.WithBlock()))
if err != nil {
return err
}
defer cli.Close()
_, err = cli.GetAllMembers(ctx1)
return err
}

0 comments on commit 045723d

Please sign in to comment.