diff --git a/internal/pkg/cppdependencyscanner/depsscannerclient/BUILD.bazel b/internal/pkg/cppdependencyscanner/depsscannerclient/BUILD.bazel index e14a49c..8acd3cf 100644 --- a/internal/pkg/cppdependencyscanner/depsscannerclient/BUILD.bazel +++ b/internal/pkg/cppdependencyscanner/depsscannerclient/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//:__subpackages__"], deps = [ "//api/scandeps", + "//internal/pkg/diagnostics", "//internal/pkg/ipc", "@com_github_bazelbuild_remote_apis_sdks//go/pkg/command", "@com_github_bazelbuild_remote_apis_sdks//go/pkg/outerr", diff --git a/internal/pkg/cppdependencyscanner/depsscannerclient/depsscannerclient.go b/internal/pkg/cppdependencyscanner/depsscannerclient/depsscannerclient.go index 4408fa8..6120ac4 100644 --- a/internal/pkg/cppdependencyscanner/depsscannerclient/depsscannerclient.go +++ b/internal/pkg/cppdependencyscanner/depsscannerclient/depsscannerclient.go @@ -37,6 +37,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/bazelbuild/reclient/internal/pkg/diagnostics" "github.com/bazelbuild/reclient/internal/pkg/ipc" pb "github.com/bazelbuild/reclient/api/scandeps" @@ -115,9 +116,24 @@ var ( type connectFn func(ctx context.Context, address string) (pb.CPPDepsScannerClient, *pb.CapabilitiesResponse, error) // New creates new DepsScannerClient. -func New(ctx context.Context, executor executor, cacheDir string, cacheFileMaxMb int, useDepsCache bool, logDir string, depsScannerAddress, proxyServerAddress string, connTimeout time.Duration, connect connectFn) (*DepsScannerClient, error) { +func New(ctx context.Context, executor executor, cacheDir string, cacheFileMaxMb int, useDepsCache bool, logDir string, depsScannerAddress, proxyServerAddress string, connTimeout time.Duration, connect connectFn) (client *DepsScannerClient, err error) { + var addr string + defer func() { + if err == nil { + return + } + log.Infof("Found errors in scandeps startup, running diagnostics...") + cCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + in := &diagnostics.DiagnosticInputs{ + UDSAddr: depsScannerAddress, + } + diagnostics.Run(cCtx, in) + }() + log.Infof("Connecting to remote dependency scanner: %v", depsScannerAddress) - client := &DepsScannerClient{ + addr = depsScannerAddress + client = &DepsScannerClient{ address: depsScannerAddress, executor: executor, cacheDir: cacheDir, @@ -130,12 +146,12 @@ func New(ctx context.Context, executor executor, cacheDir string, cacheFileMaxMb if strings.HasPrefix(depsScannerAddress, "exec://") { executable := depsScannerAddress[7:] - addr, err := buildAddress(proxyServerAddress, findOpenPort) + addr, err = buildAddress(proxyServerAddress, findOpenPort) if err != nil { return nil, fmt.Errorf("Failed to build address for dependency scanner: %w", err) } client.address = addr - if err := client.startService(ctx, executable); err != nil { + if err = client.startService(ctx, executable); err != nil { return nil, fmt.Errorf("Failed to start dependency scanner: %w", err) } } @@ -150,7 +166,7 @@ func New(ctx context.Context, executor executor, cacheDir string, cacheFileMaxMb defer cancel() go func() { defer close(connectCh) - client, capabilities, err := connect(ctx, client.address) + client, capabilities, err := connect(ctx, addr) select { case connectCh <- connectResponse{ client: client, @@ -486,6 +502,23 @@ func (ds *DepsScannerClient) restartService(ctx context.Context, executable stri return err } +func removeUDSFile(addr string) { + udsPath := addr + if !strings.HasPrefix(addr, "unix") { + return + } + if strings.HasPrefix(addr, "unix://") { + udsPath = strings.TrimPrefix(addr, "unix://") + } else if strings.HasPrefix(addr, "unix:") { + udsPath = strings.TrimPrefix(addr, "unix:") + } + if err := os.Remove(udsPath); err != nil { + log.Warningf("Failed to remove UDS socket file at %v: %v", udsPath, err) + return + } + log.Infof("Successfully removed UDS socket file at %v", udsPath) +} + func (ds *DepsScannerClient) startService(ctx context.Context, executable string) error { ctx, ds.terminate = context.WithCancel(ctx) ds.executable = executable @@ -521,6 +554,8 @@ func (ds *DepsScannerClient) startService(ctx context.Context, executable string envVars["GLOG_log_dir"] = ds.logDir } + removeUDSFile(ds.address) + log.Infof("Starting service: %v", cmdArgs) cmd := &command.Command{Args: cmdArgs} cmd.InputSpec = &command.InputSpec{ diff --git a/internal/pkg/diagnostics/BUILD.bazel b/internal/pkg/diagnostics/BUILD.bazel new file mode 100644 index 0000000..78d3dcf --- /dev/null +++ b/internal/pkg/diagnostics/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "diagnostics", + srcs = [ + "diagnostics.go", + "simpleclientserver.go", + ], + importpath = "github.com/bazelbuild/reclient/internal/pkg/diagnostics", + visibility = ["//:__subpackages__"], + deps = ["@com_github_golang_glog//:glog"], +) + +go_test( + name = "diagnostics_test", + srcs = ["simpleclientserver_test.go"], + embed = [":diagnostics"], +) diff --git a/internal/pkg/diagnostics/diagnostics.go b/internal/pkg/diagnostics/diagnostics.go new file mode 100644 index 0000000..349d4bd --- /dev/null +++ b/internal/pkg/diagnostics/diagnostics.go @@ -0,0 +1,23 @@ +// Package diagnostics is used to provide diagnostic functionality to triage +// problems with reclient during failure scenarios. +package diagnostics + +import ( + "context" + + log "github.com/golang/glog" +) + +// DiagnosticInputs struct holds key state necessary for diagnostics to run. +type DiagnosticInputs struct { + UDSAddr string +} + +// Run runs the diagnostics. +func Run(ctx context.Context, in *DiagnosticInputs) { + if err := CheckUDSAddrWorks(ctx, in.UDSAddr); err != nil { + log.Errorf("DIAGNOSTIC_ERROR: UDS address check for %v had errors: %v", in.UDSAddr, err) + } else { + log.Infof("DIAGNOSTIC_SUCCESS: UDS address %v works with toy RPC server", in.UDSAddr) + } +} diff --git a/internal/pkg/diagnostics/simpleclientserver.go b/internal/pkg/diagnostics/simpleclientserver.go new file mode 100644 index 0000000..9e8f0a4 --- /dev/null +++ b/internal/pkg/diagnostics/simpleclientserver.go @@ -0,0 +1,112 @@ +package diagnostics + +import ( + "context" + "fmt" + "net" + "net/rpc" + "os" + "strings" + + log "github.com/golang/glog" +) + +// HelloService is a simple service to provide a HelloWorld RPC. +type HelloService struct{} + +// HelloRequest identifies the hello world request. +type HelloRequest struct{} + +// HelloResponse identifies the hello world response. +type HelloResponse struct { + Message string +} + +// HelloWorld is the RPC method that returns a hello world response. +func (s *HelloService) HelloWorld(req *HelloRequest, res *HelloResponse) error { + res.Message = "Hello, world!" + return nil +} + +// CheckUDSAddrWorks checks if we are able to start a basic RPC +// server at the given UDS address. Useful to diagnost scandeps_server +// startup timeout issues. +func CheckUDSAddrWorks(ctx context.Context, addr string) error { + if strings.HasPrefix(addr, "unix://") { + addr = strings.ReplaceAll(addr, "unix://", "") + } else if strings.HasPrefix(addr, "unix:") { + addr = strings.ReplaceAll(addr, "unix:", "") + } else { + return fmt.Errorf("addr must begin with a unix:// or unix: prefix: %v", addr) + } + os.RemoveAll(addr) + + serverStarted := make(chan bool, 1) + clientComplete := make(chan bool, 1) + errs := make(chan error, 1) + var listener net.Listener + + go func() { + helloService := new(HelloService) + err := rpc.Register(helloService) + if err != nil { + errs <- fmt.Errorf("error registering service: %v", err) + return + } + + listener, err = net.Listen("unix", addr) + if err != nil { + errs <- fmt.Errorf("listener error: %v", err) + return + } + + log.Infof("Diagnostic server listening on UNIX socket: %v", addr) + serverStarted <- true + rpc.Accept(listener) + }() + select { + case <-ctx.Done(): + return fmt.Errorf("context timeout reached in CheckUDSAddrWorks(%v)", addr) + + case <-serverStarted: + break + + case err := <-errs: + return fmt.Errorf("failed to start RPC server at %v: %v", addr, err) + } + + errs = make(chan error, 1) + go func() { + defer func() { + clientComplete <- true + }() + client, err := rpc.Dial("unix", addr) + if err != nil { + errs <- fmt.Errorf("error connecting to server: %v", err) + return + } + defer client.Close() + + req := &HelloRequest{} + res := &HelloResponse{} + + err = client.Call("HelloService.HelloWorld", req, res) + if err != nil { + errs <- fmt.Errorf("error calling HelloWorld: %v", err) + } + }() + log.V(1).Infof("DIAGNOSTICS: Waiting for client-server communication on address %v", addr) + select { + case <-ctx.Done(): + return fmt.Errorf("context timeout reached in CheckUDSAddrWorks(%v)", addr) + + case <-clientComplete: + break + + case err := <-errs: + return fmt.Errorf("failed to talk to RPC server at %v: %v", addr, err) + } + listener.Close() + + return nil +} diff --git a/internal/pkg/diagnostics/simpleclientserver_test.go b/internal/pkg/diagnostics/simpleclientserver_test.go new file mode 100644 index 0000000..7f84e5b --- /dev/null +++ b/internal/pkg/diagnostics/simpleclientserver_test.go @@ -0,0 +1,34 @@ +package diagnostics + +import ( + "context" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestCheckUDSAddrWorks(t *testing.T) { + var addr string + if runtime.GOOS == "windows" { + addr = "unix:" + filepath.Join(os.TempDir(), "test.sock") + } else { + addr = "unix://" + filepath.Join(os.TempDir(), "test.sock") + } + + if err := CheckUDSAddrWorks(context.Background(), addr); err != nil { + t.Errorf("CheckUDSAddrWorks(%v) failed: %v", addr, err) + } +} +func TestCheckUDSAddrWorks_InvalidAddress(t *testing.T) { + var addr string + if runtime.GOOS == "windows" { + addr = "unix:X:\\tmp\\test.sock" + } else { + addr = "unix:///rooooot/test.sock" + } + + if err := CheckUDSAddrWorks(context.Background(), addr); err == nil { + t.Errorf("CheckUDSAddrWorks(%v) expected error but succeeded", addr) + } +}