From a25ef3f7184aa9965f02d63aefe8daa240fba5c1 Mon Sep 17 00:00:00 2001 From: Peter Wang Date: Wed, 6 Sep 2023 10:30:10 +0800 Subject: [PATCH] Fix konnectivity issue Fix the context being cancelled before the proxy connection setup had completed, which was caused by GH-229 --- pkg/csi/client/grpc.go | 28 ++++++++++++++++++---------- pkg/csi/client/konnectivity.go | 13 ++++++------- pkg/csi/controllerserver.go | 3 ++- pkg/csi/nodeutils.go | 4 ++-- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/pkg/csi/client/grpc.go b/pkg/csi/client/grpc.go index 3b2c8897..c007620d 100644 --- a/pkg/csi/client/grpc.go +++ b/pkg/csi/client/grpc.go @@ -87,10 +87,14 @@ func (c *workerConnection) Close() error { return c.conn.Close() } -func connect(address string, timeout time.Duration, proxyOpts GrpcProxyClientOptions) (*grpc.ClientConn, error) { - log.V(6).Infof("New Connecting to %s", address) - ctx, cancel := context.WithTimeout(context.Background(), timeout) +// connect create new connection to the `address` +// if proxyOpts.ProxyUDSName or proxyOpts.ProxyHost is not empty, a proxied connection returned +// currently, the proxy server can only be konnnectivity +func connect(address string, timeout time.Duration, proxyOpts GrpcProxyClientOptions) (conn *grpc.ClientConn, err error) { + log.V(6).Infof("New Connecting to remote address %s ...", address) + connectCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() + // only for unit test var bufDialerFunc func(context.Context, string) (net.Conn, error) dialOptions := []grpc.DialOption{ @@ -103,18 +107,22 @@ func connect(address string, timeout time.Duration, proxyOpts GrpcProxyClientOpt } dialOptions = append(dialOptions, grpc.WithContextDialer(bufDialerFunc)) } + // end for unit test // setup konnectivity proxy connection + var proxyAddress string if proxyOpts.ProxyHost != "" || proxyOpts.ProxyUDSName != "" { - var err error var proxyDialer proxyFunc if proxyOpts.ProxyUDSName == "" { - proxyDialer, 
err = getKonnectivityMTLSDialer(ctx, address, timeout, proxyOpts) + proxyAddress = proxyOpts.ProxyHost + proxyDialer, err = getKonnectivityMTLSDialer(connectCtx, address, proxyOpts) } else { - proxyDialer, err = getKonnectivityUDSDialer(ctx, address, timeout, proxyOpts) + proxyAddress = proxyOpts.ProxyUDSName + proxyDialer, err = getKonnectivityUDSDialer(connectCtx, address, proxyOpts) } if err != nil { - return nil, fmt.Errorf("failed to setup konnectivity dialer: %w", err) + return nil, fmt.Errorf("failed to setup konnectivity dialer(%q): %w", address, err) } + log.Infof("connected to proxy server %q.", proxyAddress) dialOptions = append(dialOptions, grpc.WithContextDialer(proxyDialer)) } // if strings.HasPrefix(address, "/") { @@ -129,14 +137,14 @@ func connect(address string, timeout time.Duration, proxyOpts GrpcProxyClientOpt return net.DialTimeout("unix", addr, timeout) })) } - conn, err := grpc.Dial(address, dialOptions...) + conn, err = grpc.Dial(address, dialOptions...) if err != nil { return nil, err } for { - if !conn.WaitForStateChange(ctx, conn.GetState()) { - log.Warningf("Connection to %s timed out", address) + if !conn.WaitForStateChange(connectCtx, conn.GetState()) { + log.Warningf("Connection to %s timed out, subsequent calls might fail due to this.", address) return conn, nil // return nil, subsequent GetPluginInfo will show the real connection error } if conn.GetState() == connectivity.Ready { diff --git a/pkg/csi/client/konnectivity.go b/pkg/csi/client/konnectivity.go index f903a43b..1066f765 100644 --- a/pkg/csi/client/konnectivity.go +++ b/pkg/csi/client/konnectivity.go @@ -26,7 +26,6 @@ import ( "net/http" "os" "path/filepath" - "time" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -47,7 +46,7 @@ type GrpcProxyClientOptions struct { type proxyFunc func(ctx context.Context, addr string) (net.Conn, error) -func getKonnectivityUDSDialer(ctx context.Context, address string, timeout time.Duration, o GrpcProxyClientOptions) 
(func(ctx context.Context, addr string) (net.Conn, error), error) { +func getKonnectivityUDSDialer(ctx context.Context, address string,o GrpcProxyClientOptions) (func(ctx context.Context, addr string) (net.Conn, error), error) { log.Infof("using konnectivity UDS dialer") var proxyConn net.Conn @@ -57,15 +56,15 @@ func getKonnectivityUDSDialer(ctx context.Context, address string, timeout time. switch o.Mode { case "grpc": dialOption := grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { - c, err := net.DialTimeout("unix", o.ProxyUDSName, timeout) + c, err := net.Dial("unix", o.ProxyUDSName) if err != nil { log.ErrorS(err, "failed to create connection to uds", "name", o.ProxyUDSName) } return c, err }) tunnel, err := client.CreateSingleUseGrpcTunnelWithContext( - context.TODO(), - ctx, + ctx,// create context should follow grpc timeout configuration + context.TODO(), // tunnel context use context.TODO() o.ProxyUDSName, dialOption, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -112,7 +111,7 @@ func getKonnectivityUDSDialer(ctx context.Context, address string, timeout time. 
}, nil } -func getKonnectivityMTLSDialer(ctx context.Context, address string, _ time.Duration, o GrpcProxyClientOptions) (func(ctx context.Context, addr string) (net.Conn, error), error) { +func getKonnectivityMTLSDialer(ctx context.Context, address string, o GrpcProxyClientOptions) (func(ctx context.Context, addr string) (net.Conn, error), error) { log.Infof("using konnectivity mTLS dialer") tlsConfig, err := getClientTLSConfig(o.CACert, o.ClientCert, o.ClientKey, o.ProxyHost, nil) @@ -126,7 +125,7 @@ func getKonnectivityMTLSDialer(ctx context.Context, address string, _ time.Durat transportCreds := credentials.NewTLS(tlsConfig) dialOption := grpc.WithTransportCredentials(transportCreds) serverAddress := fmt.Sprintf("%s:%d", o.ProxyHost, o.ProxyPort) - tunnel, err := client.CreateSingleUseGrpcTunnelWithContext(context.TODO(), ctx, serverAddress, dialOption) + tunnel, err := client.CreateSingleUseGrpcTunnelWithContext(ctx, context.TODO(), serverAddress, dialOption) if err != nil { return nil, fmt.Errorf("failed to create tunnel %s, got %v", serverAddress, err) } diff --git a/pkg/csi/controllerserver.go b/pkg/csi/controllerserver.go index 01391dfe..3ea00b84 100644 --- a/pkg/csi/controllerserver.go +++ b/pkg/csi/controllerserver.go @@ -210,7 +210,7 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol return nil, status.Errorf(codes.Internal, "CreateVolume: fail to get lv %s from node %s: %s", req.Name, nodeName, err.Error()) } else { if lvName == "" { - log.Info("CreateVolume: volume %s not found, creating volume on node %s", volumeID, nodeName) + log.Infof("CreateVolume: volume %s is not found, creating volume on node %s", volumeID, nodeName) outstr, err := conn.CreateVolume(ctx, options) if err != nil { return nil, status.Errorf(codes.Internal, "CreateVolume: fail to create lv %s(options: %v): %s", utils.GetNameKey(vgName, volumeID), options, err.Error()) @@ -767,6 +767,7 @@ func (cs *controllerServer) 
newCreateSnapshotResponse(snapshotId, sourceVolumeId }, nil } +// getNodeConn creates a new connection to the lvmd for the `nodeSelected` node func (cs *controllerServer) getNodeConn(nodeSelected string) (client.Connection, error) { node, err := cs.nodeLister.Get(nodeSelected) if err != nil { diff --git a/pkg/csi/nodeutils.go b/pkg/csi/nodeutils.go index 4bde71a8..2ac0c101 100644 --- a/pkg/csi/nodeutils.go +++ b/pkg/csi/nodeutils.go @@ -476,7 +476,7 @@ func (ns *nodeServer) createLvm(vgName, volumeID, lvmType, unit string, pvSize i log.Errorf("createVolume:: lvcreate command %s error: %v", cmd, err) return err } - log.Infof("Successful Create Striping LVM volume: %s, with command: %s", volumeID, cmd) + log.Infof("Successfully Create Striping LVM volume: %s, with command: %s", volumeID, cmd) } else if lvmType == LinearType { cmd := fmt.Sprintf("%s lvcreate -n %s -L %d%s -Wy -y %s", localtype.NsenterCmd, volumeID, pvSize, unit, vgName) _, err := ns.osTool.RunCommand(cmd) @@ -484,7 +484,7 @@ func (ns *nodeServer) createLvm(vgName, volumeID, lvmType, unit string, pvSize i log.Errorf("createVolume:: lvcreate linear command %s error: %v", cmd, err) return err } - log.Infof("Successful Create Linear LVM volume: %s, with command: %s", volumeID, cmd) + log.Infof("Successfully Create Linear LVM volume: %s, with command: %s", volumeID, cmd) } return nil }