Skip to content

Commit

Permalink
refactor: move the tmp path into 'Download()' method in 'Downloader'
Browse files Browse the repository at this point in the history
Signed-off-by: zongzhe <[email protected]>
  • Loading branch information
zong-zhe committed Sep 26, 2024
1 parent af04497 commit 3da2366
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
49 changes: 1 addition & 48 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"
"syscall"

goerr "errors"

"github.com/BurntSushi/toml"
"github.com/dominikbraun/graph"
Expand Down Expand Up @@ -992,20 +988,12 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
}
}

// create a tmp dir to download the oci package.
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
return nil, reporter.NewErrorEvent(reporter.Bug, err, fmt.Sprintf("failed to create temp dir '%s'.", tmpDir))
}
// clean the temp dir.
defer os.RemoveAll(tmpDir)

credCli, err := c.GetCredsClient()
if err != nil {
return nil, err
}
err = c.DepDownloader.Download(*downloader.NewDownloadOptions(
downloader.WithLocalPath(tmpDir),
downloader.WithLocalPath(localPath),
downloader.WithSource(dep.Source),
downloader.WithLogWriter(c.logWriter),
downloader.WithSettings(c.settings),
Expand All @@ -1016,41 +1004,6 @@ func (c *KpmClient) Download(dep *pkg.Dependency, homePath, localPath string) (*
return nil, err
}

// check the package in tmp dir is a valid kcl package.
_, err = pkg.FindFirstKclPkgFrom(tmpDir)
if err != nil {
return nil, err
}

// rename the tmp dir to the local path.
if utils.DirExists(localPath) {
err := os.RemoveAll(localPath)
if err != nil {
return nil, err
}
}

if runtime.GOOS != "windows" {
err = os.Rename(tmpDir, localPath)
if err != nil {
// check the error is caused by moving the file across file systems.
if goerr.Is(err, syscall.EXDEV) {
// If it is, use copy as a fallback.
err = copy.Copy(tmpDir, localPath)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
} else {
err = copy.Copy(tmpDir, localPath)
if err != nil {
return nil, err
}
}

// load the package from the local path.
dpkg, err := c.LoadPkgFromPath(localPath)
if err != nil {
Expand Down
54 changes: 51 additions & 3 deletions pkg/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"strings"
"syscall"

v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/otiai10/copy"
Expand Down Expand Up @@ -128,6 +130,17 @@ func NewOciDownloader(platform string) *DepDownloader {

func (d *DepDownloader) Download(opts DownloadOptions) error {

// create a tmp dir to download the oci package.
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
if opts.Source.Git != nil {
tmpDir = filepath.Join(tmpDir, constants.GitScheme)
}
// clean the temp dir.
defer os.RemoveAll(tmpDir)

var localPath string
if opts.EnableCache {
// TODO: After the new local storage structure is complete,
Expand Down Expand Up @@ -176,7 +189,7 @@ func (d *DepDownloader) Download(opts DownloadOptions) error {
localPath = opts.LocalPath
}

opts.LocalPath = localPath
opts.LocalPath = tmpDir
// Dispatch the download to the specific downloader by package source.
if opts.Source.Oci != nil || opts.Source.Registry != nil {
if opts.Source.Registry != nil {
Expand All @@ -185,14 +198,49 @@ func (d *DepDownloader) Download(opts DownloadOptions) error {
if d.OciDownloader == nil {
d.OciDownloader = &OciDownloader{}
}
return d.OciDownloader.Download(opts)
err := d.OciDownloader.Download(opts)
if err != nil {
return err
}
}

if opts.Source.Git != nil {
if d.GitDownloader == nil {
d.GitDownloader = &GitDownloader{}
}
return d.GitDownloader.Download(opts)
err := d.GitDownloader.Download(opts)
if err != nil {
return err
}
}

// rename the tmp dir to the local path.
if utils.DirExists(localPath) {
err := os.RemoveAll(localPath)
if err != nil {
return err
}
}

if runtime.GOOS != "windows" {
err = os.Rename(tmpDir, localPath)
if err != nil {
// check the error is caused by moving the file across file systems.
if errors.Is(err, syscall.EXDEV) {
// If it is, use copy as a fallback.
err = copy.Copy(tmpDir, localPath)
if err != nil {
return err
}
} else {
return err
}
}
} else {
err = copy.Copy(tmpDir, localPath)
if err != nil {
return err
}
}
return nil
}
Expand Down

0 comments on commit 3da2366

Please sign in to comment.