diff --git a/cmd/envd-ssh/main.go b/cmd/envd-ssh/main.go index e3016c318..6a628f080 100644 --- a/cmd/envd-ssh/main.go +++ b/cmd/envd-ssh/main.go @@ -19,7 +19,6 @@ package main import ( "fmt" "os" - "strconv" "github.com/cockroachdb/errors" rawssh "github.com/gliderlabs/ssh" @@ -28,7 +27,6 @@ import ( "github.com/tensorchord/envd/pkg/config" "github.com/tensorchord/envd/pkg/remote/sshd" - "github.com/tensorchord/envd/pkg/ssh" "github.com/tensorchord/envd/pkg/version" ) @@ -37,6 +35,7 @@ const ( flagDebug = "debug" flagAuthKey = "authorized-keys" flagNoAuth = "no-auth" + flagPort = "port" ) func main() { @@ -64,6 +63,10 @@ func main() { Usage: "disable authentication", Value: false, }, + &cli.IntFlag{ + Name: flagPort, + Usage: "port to listen on", + }, } // Deal with debug flag. @@ -89,18 +92,11 @@ func sshServer(c *cli.Context) error { logrus.Fatal(err.Error()) } - port := ssh.DefaultSSHPort - // TODO(gaocegege): Set it as a flag. - if p, ok := os.LookupEnv(envPort); ok { - var err error - port, err = strconv.Atoi(p) - if err != nil { - return errors.Wrap(err, "failed to parse port") - } - - if port <= 1024 { - return errors.New("failed to parse port: port is reserved") - } + port := c.Int(flagPort) + if port == 0 { + return errors.New("port must be set") + } else if port <= 1024 { + return errors.New("failed to parse port: port is reserved") } noAuth := c.Bool(flagNoAuth) diff --git a/go.mod b/go.mod index 1f366e5e4..4927a4638 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/onsi/gomega v1.19.0 github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 + github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/pkg/sftp v1.13.4 github.com/sirupsen/logrus v1.8.1 github.com/spf13/viper v1.4.0 diff --git a/go.sum b/go.sum index c306cd799..a957f02f2 100644 --- a/go.sum +++ b/go.sum @@ -454,6 +454,8 @@ github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFSt github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/pkg/app/get_env_dep.go b/pkg/app/get_env_dep.go index 3158685ef..3a71d00a0 100644 --- a/pkg/app/get_env_dep.go +++ b/pkg/app/get_env_dep.go @@ -20,7 +20,6 @@ import ( "github.com/cockroachdb/errors" "github.com/olekukonko/tablewriter" - "github.com/sirupsen/logrus" "github.com/tensorchord/envd/pkg/envd" sshconfig "github.com/tensorchord/envd/pkg/ssh/config" "github.com/tensorchord/envd/pkg/types" @@ -43,11 +42,6 @@ var CommandGetEnvironmentDependency = &cli.Command{ Aliases: []string{"k"}, Value: sshconfig.GetPrivateKey(), }, - &cli.BoolFlag{ - Name: "full", - Usage: "Show full dependency information", - Aliases: []string{"f"}, - }, }, Action: getEnvironmentDependency, } @@ -61,20 +55,12 @@ func getEnvironmentDependency(clicontext *cli.Context) error { if err != nil { return errors.Wrap(err, "failed to create envd engine") } - full := clicontext.Bool("full") - if full { - output, err := envdEngine.ListEnvFullDependency(clicontext.Context, envName, clicontext.Path("private-key")) - if err != nil { - return errors.Wrap(err, "failed to list dependencies") - } - logrus.Infof("%s", output) - } else { - dep, err := envdEngine.ListEnvDependency(clicontext.Context, envName) - if err != nil { - return errors.Wrap(err, "failed to list dependencies") - } - renderDependencies(dep, os.Stdout) + + dep, err := envdEngine.ListEnvDependency(clicontext.Context, envName) + if err != nil { + return errors.Wrap(err, "failed to list dependencies") } + renderDependencies(dep, os.Stdout) return nil } diff --git a/pkg/app/up.go b/pkg/app/up.go index 38b6d89ff..8dfd78d56 100644 --- a/pkg/app/up.go +++ b/pkg/app/up.go @@ -32,6 +32,11 @@ import ( "github.com/tensorchord/envd/pkg/ssh" sshconfig "github.com/tensorchord/envd/pkg/ssh/config" "github.com/tensorchord/envd/pkg/util/fileutil" + "github.com/tensorchord/envd/pkg/util/netutil" +) + +const ( + localhost = "127.0.0.1" ) var CommandUp = &cli.Command{ @@ -155,8 +160,13 @@ func up(clicontext *cli.Context) error { } } + sshPort, err := netutil.GetFreePort() + if err != nil { + return errors.Wrap(err, "failed to get a free port") + } + containerID, containerIP, err := dockerClient.StartEnvd(clicontext.Context, - tag, ctr, buildContext, gpu, *ir.DefaultGraph, clicontext.Duration("timeout"), + tag, ctr, buildContext, gpu, sshPort, *ir.DefaultGraph, clicontext.Duration("timeout"), clicontext.StringSlice("volume")) if err != nil { return err @@ -164,14 +174,15 @@ func up(clicontext *cli.Context) error { logrus.Debugf("container %s is running", containerID) logrus.Debugf("Add entry %s to SSH config. at %s", buildContext, containerIP) - if err = sshconfig.AddEntry(ctr, containerIP, ssh.DefaultSSHPort, clicontext.Path("private-key")); err != nil { + if err = sshconfig.AddEntry( + ctr, localhost, sshPort, clicontext.Path("private-key")); err != nil { logrus.Infof("failed to add entry %s to your SSH config file: %s", ctr, err) return errors.Wrap(err, "failed to add entry to your SSH config file") } if !detach { sshClient, err := ssh.NewClient( - containerIP, "envd", ssh.DefaultSSHPort, true, clicontext.Path("private-key"), "") + localhost, "envd", sshPort, true, clicontext.Path("private-key"), "") if err != nil { return err } diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index 3cfdb7d1d..36f6524cd 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -38,6 +38,10 @@ import ( "github.com/tensorchord/envd/pkg/util/fileutil" ) +const ( + localhost = "127.0.0.1" +) + var ( interval = 1 * time.Second ) @@ -47,7 +51,8 @@ type Client interface { Load(ctx context.Context, r io.ReadCloser, quiet bool) error // Start creates the container for the given tag and container name. StartEnvd(ctx context.Context, tag, name, buildContext string, - gpuEnabled bool, g ir.Graph, timeout time.Duration, mountOptionsStr []string) (string, string, error) + gpuEnabled bool, sshPort int, g ir.Graph, timeout time.Duration, + mountOptionsStr []string) (string, string, error) StartBuildkitd(ctx context.Context, tag, name, mirror string) (string, error) IsRunning(ctx context.Context, name string) (bool, error) @@ -288,7 +293,7 @@ func (g generalClient) StartBuildkitd(ctx context.Context, // Start creates the container for the given tag and container name. func (c generalClient) StartEnvd(ctx context.Context, tag, name, buildContext string, - gpuEnabled bool, g ir.Graph, timeout time.Duration, mountOptionsStr []string) (string, string, error) { + gpuEnabled bool, sshPort int, g ir.Graph, timeout time.Duration, mountOptionsStr []string) (string, string, error) { logger := logrus.WithFields(logrus.Fields{ "tag": tag, "container": name, @@ -309,7 +314,8 @@ func (c generalClient) StartEnvd(ctx context.Context, tag, name, buildContext st base := fileutil.Base(buildContext) base = filepath.Join("/home/envd", base) config.WorkingDir = base - config.Entrypoint = append(config.Entrypoint, entrypointSH(g, config.WorkingDir)) + config.Entrypoint = append(config.Entrypoint, + entrypointSH(g, config.WorkingDir, sshPort)) mountOption := make([]mount.Mount, len(mountOptionsStr)+1) for i, option := range mountOptionsStr { @@ -343,12 +349,23 @@ func (c generalClient) StartEnvd(ctx context.Context, tag, name, buildContext st PortBindings: nat.PortMap{}, Mounts: mountOption, } + + // Configure ssh port. + natPort := nat.Port(fmt.Sprintf("%d/tcp", sshPort)) + hostConfig.PortBindings[natPort] = []nat.PortBinding{ + { + HostIP: localhost, + HostPort: strconv.Itoa(sshPort), + }, + } + config.ExposedPorts[natPort] = struct{}{} + // TODO(gaocegege): Avoid specific logic to set the port. if g.JupyterConfig != nil { natPort := nat.Port(fmt.Sprintf("%d/tcp", g.JupyterConfig.Port)) hostConfig.PortBindings[natPort] = []nat.PortBinding{ { - HostIP: "localhost", + HostIP: localhost, HostPort: strconv.Itoa(int(g.JupyterConfig.Port)), }, } @@ -361,7 +378,7 @@ func (c generalClient) StartEnvd(ctx context.Context, tag, name, buildContext st hostConfig.DeviceRequests = deviceRequests(-1) } - config.Labels = labels(name, g.JupyterConfig) + config.Labels = labels(name, g.JupyterConfig, sshPort) logger = logger.WithFields(logrus.Fields{ "entrypoint": config.Entrypoint, diff --git a/pkg/docker/entrypoint.go b/pkg/docker/entrypoint.go index 7dff99d17..15b9e8f5d 100644 --- a/pkg/docker/entrypoint.go +++ b/pkg/docker/entrypoint.go @@ -25,17 +25,17 @@ import ( const ( template = `set -e -/var/envd/bin/envd-ssh --authorized-keys %s & +/var/envd/bin/envd-ssh --authorized-keys %s --port %d & %s wait -n` ) -func entrypointSH(g ir.Graph, workingDir string) string { +func entrypointSH(g ir.Graph, workingDir string, sshPort int) string { if g.JupyterConfig != nil { cmds := jupyter.GenerateCommand(g, workingDir) return fmt.Sprintf(template, - config.ContainerauthorizedKeysPath, strings.Join(cmds, " ")) + config.ContainerauthorizedKeysPath, sshPort, strings.Join(cmds, " ")) } return fmt.Sprintf(template, - config.ContainerauthorizedKeysPath, "") + config.ContainerauthorizedKeysPath, sshPort, "") } diff --git a/pkg/docker/label.go b/pkg/docker/label.go index e28f2e8e8..11d2451a5 100644 --- a/pkg/docker/label.go +++ b/pkg/docker/label.go @@ -16,6 +16,7 @@ package docker import ( "fmt" + "strconv" "github.com/docker/docker/api/types/filters" @@ -23,9 +24,10 @@ import ( "github.com/tensorchord/envd/pkg/types" ) -func labels(name string, jupyterConfig *ir.JupyterConfig) map[string]string { +func labels(name string, jupyterConfig *ir.JupyterConfig, sshPort int) map[string]string { res := make(map[string]string) res[types.ContainerLabelName] = name + res[types.ContainerLabelSSHPort] = strconv.Itoa(sshPort) if jupyterConfig != nil { res[types.ContainerLabelJupyterAddr] = fmt.Sprintf("http://localhost:%d", jupyterConfig.Port) } diff --git a/pkg/envd/engine.go b/pkg/envd/engine.go index 756d89eda..1f27c42ce 100644 --- a/pkg/envd/engine.go +++ b/pkg/envd/engine.go @@ -19,7 +19,6 @@ import ( "github.com/cockroachdb/errors" "github.com/sirupsen/logrus" "github.com/tensorchord/envd/pkg/docker" - "github.com/tensorchord/envd/pkg/ssh" "github.com/tensorchord/envd/pkg/types" ) @@ -31,7 +30,6 @@ type Engine interface { ResumeEnvironment(ctx context.Context, env string) (string, error) ListEnvironment(ctx context.Context) ([]types.EnvdEnvironment, error) ListEnvDependency(ctx context.Context, env string) (*types.Dependency, error) - ListEnvFullDependency(ctx context.Context, env, SSHKeyPath string) (string, error) } type generalEngine struct { @@ -141,35 +139,3 @@ func (e generalEngine) ListEnvDependency( } return dep, nil } - -// ListEnvFullDependency attaches into the environment and gets the dependencies of the given environment. -func (e generalEngine) ListEnvFullDependency( - ctx context.Context, env, SSHKeyPath string) (string, error) { - logger := logrus.WithFields(logrus.Fields{ - "env": env, - "ssh-private-key": SSHKeyPath, - }) - logger.Debug("getting full dependencies") - ctr, err := e.dockerCli.GetContainer(ctx, env) - if err != nil { - return "", err - } - ctrIP := ctr.NetworkSettings.IPAddress - if ctrIP == "" { - return "", errors.New("failed to get the ip address of the container") - } - return e.getDependencyListFromSSH(ctx, ctrIP, SSHKeyPath) -} - -func (e generalEngine) getDependencyListFromSSH(ctx context.Context, ip, SSHKeyPath string) (string, error) { - sshClient, err := ssh.NewClient( - ip, "envd", ssh.DefaultSSHPort, true, SSHKeyPath, "") - if err != nil { - return "", errors.Wrap(err, "failed to create ssh client") - } - output, err := sshClient.ExecWithOutput("pip list") - if err != nil { - return "", errors.Wrap(err, "failed to get pip list") - } - return string(output), nil -} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 7c38513e9..08e02155f 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -37,8 +37,6 @@ import ( "golang.org/x/term" ) -const DefaultSSHPort = 2222 - type Client interface { Attach() error ExecWithOutput(cmd string) ([]byte, error) diff --git a/pkg/types/envd.go b/pkg/types/envd.go index b96207e5d..c27d873e3 100644 --- a/pkg/types/envd.go +++ b/pkg/types/envd.go @@ -47,22 +47,6 @@ type Dependency struct { PyPIPackages []string `json:"pypi_packages,omitempty"` } -const ( - ContainerLabelName = "ai.tensorchord.envd.name" - ContainerLabelJupyterAddr = "ai.tensorchord.envd.jupyter.address" - - ImageLabelVendor = "ai.tensorchord.envd.vendor" - ImageLabelGPU = "ai.tensorchord.envd.gpu" - ImageLabelAPT = "ai.tensorchord.envd.apt.packages" - ImageLabelPyPI = "ai.tensorchord.envd.pypi.packages" - ImageLabelR = "ai.tensorchord.envd.r.packages" - ImageLabelCUDA = "ai.tensorchord.envd.gpu.cuda" - ImageLabelCUDNN = "ai.tensorchord.envd.gpu.cudnn" - ImageLabelContext = "ai.tensorchord.envd.build.context" - - ImageVendorEnvd = "envd" -) - func NewImage(image types.ImageSummary) (*EnvdImage, error) { img := EnvdImage{ ImageSummary: image, diff --git a/pkg/types/label.go b/pkg/types/label.go new file mode 100644 index 000000000..0755c9aa2 --- /dev/null +++ b/pkg/types/label.go @@ -0,0 +1,18 @@ +package types + +const ( + ContainerLabelName = "ai.tensorchord.envd.name" + ContainerLabelJupyterAddr = "ai.tensorchord.envd.jupyter.address" + ContainerLabelSSHPort = "ai.tensorchord.envd.ssh.port" + + ImageLabelVendor = "ai.tensorchord.envd.vendor" + ImageLabelGPU = "ai.tensorchord.envd.gpu" + ImageLabelAPT = "ai.tensorchord.envd.apt.packages" + ImageLabelPyPI = "ai.tensorchord.envd.pypi.packages" + ImageLabelR = "ai.tensorchord.envd.r.packages" + ImageLabelCUDA = "ai.tensorchord.envd.gpu.cuda" + ImageLabelCUDNN = "ai.tensorchord.envd.gpu.cudnn" + ImageLabelContext = "ai.tensorchord.envd.build.context" + + ImageVendorEnvd = "envd" +) diff --git a/pkg/util/netutil/netutil.go b/pkg/util/netutil/netutil.go new file mode 100644 index 000000000..ea67fbd7d --- /dev/null +++ b/pkg/util/netutil/netutil.go @@ -0,0 +1,27 @@ +// Copyright 2022 The envd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netutil + +import ( + "github.com/phayes/freeport" +) + +func GetFreePort() (int, error) { + port, err := freeport.GetFreePort() + if err != nil { + return 0, err + } + return port, nil +}