From cabbeb006fdd93cbe91a8e1970590eaba87ab347 Mon Sep 17 00:00:00 2001 From: Connor Edwards <38229097+cedws@users.noreply.github.com> Date: Sun, 10 Nov 2024 16:59:56 +0000 Subject: [PATCH] feat(launcher): Implement CRC check when patching --- go.mod | 9 ++- go.sum | 7 ++ main.go | 236 +++++++++++++++++++++++++++++++++++++++++++------------- 3 files changed, 198 insertions(+), 54 deletions(-) diff --git a/go.mod b/go.mod index 874b1e8..c292e8f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,13 @@ go 1.23.2 require ( github.com/cedws/w101-client-go v0.0.0-20241108205848-13d029f4311a github.com/cedws/w101-proto-go v0.0.0-20241108201431-d1d25865844d + github.com/snksoft/crc v1.1.0 + github.com/stretchr/testify v1.4.0 ) -require golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum index 998126d..7821101 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,14 @@ github.com/cedws/w101-client-go v0.0.0-20241108205848-13d029f4311a h1:kNrbs1p8Kw github.com/cedws/w101-client-go v0.0.0-20241108205848-13d029f4311a/go.mod h1:LEEZPLzH2mdzndxQPSwg+7ho+Q1U+raLB6802UUvLzQ= github.com/cedws/w101-proto-go v0.0.0-20241108201431-d1d25865844d h1:mSxoQfRHZGYiWCHNzEl0fQEnv2BRrzVyjOvQx1V8470= github.com/cedws/w101-proto-go v0.0.0-20241108201431-d1d25865844d/go.mod h1:qRpqI9UDOQUr4luI6HRBDh/QNFVQwYhfSXmf8+UNwPQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/snksoft/crc v1.1.0 h1:HkLdI4taFlgGGG1KvsWMpz78PkOC9TkPVpTV/cuWn48= +github.com/snksoft/crc v1.1.0/go.mod h1:5/gUOsgAm7OmIhb6WJzw7w5g2zfJi4FrHYgGPdshE+A= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -15,5 +19,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/main.go b/main.go index 4bf86ec..2674196 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "os" "os/exec" "path/filepath" + "sync" "time" "github.com/cedws/w101-client-go/dml" @@ -20,6 +21,7 @@ import ( "github.com/cedws/w101-client-go/proto" "github.com/cedws/w101-proto-go/pkg/login" "github.com/cedws/w101-proto-go/pkg/patch" + "github.com/snksoft/crc" ) const ( @@ -32,6 +34,36 @@ var ( errTimeoutFileList = fmt.Errorf("timed out waiting for latest file list") ) +var makeHasherOnce = sync.OnceValue(func() fileHasher { + hash := *crc.NewHash(&crc.Parameters{ + Width: 32, + Polynomial: 0x4C11DB7, + ReflectIn: true, + ReflectOut: true, + Init: 0, + FinalXor: 0, + }) + + return fileHasher{hash} +}) + +type fileHasher struct { + crc.Hash +} + +func (f *fileHasher) Write(b []byte) (int, error) { + f.Hash.Update(b) + return len(b), nil +} + +func (f *fileHasher) CRC32() uint32 { + return f.Hash.CRC32() +} + +func (f *fileHasher) Reset() { + f.Hash.Reset() +} + type patchHandler struct { patch.PatchService fileListCh chan patch.LatestFileListV2 @@ -77,7 +109,7 @@ func main() { os.Exit(1) } - params := LaunchParams{ + params := launchParams{ Dir: dir, Username: username, Password: password, @@ -86,12 +118,28 @@ func main() { PatchServerAddr: patchServerAddr, } - if err := launch(ctx, params); err != nil { + patchClient := newPatchClient(params) + + if err := patchClient.launch(ctx, params); err != nil { log.Fatal(err) } } -type LaunchParams struct { +type patchClient struct { + launchParams + hasher *fileHasher +} + +func newPatchClient(params launchParams) *patchClient { + hasher := makeHasherOnce() + + return &patchClient{ + launchParams: params, + hasher: &hasher, + } +} + +type launchParams struct { Dir string Username string Password string @@ -100,8 +148,34 @@ type LaunchParams struct { PatchServerAddr string } -func launch(ctx context.Context, params LaunchParams) error { - fileList, err := latestFileList(ctx, params) +func (p *patchClient) launch(ctx context.Context, params launchParams) error { + if err := p.downloadBaseFiles(ctx); err != nil { + return err + } + + if !params.PatchOnly { + userID, ck2, err := p.requestCK2Token(ctx, params) + if err != nil { + return err + } + + if err := p.launchGraphicalClient(ctx, userID, ck2); err != nil { + return err + } + } + + return nil +} + +type patchFile struct { + URL string + Source string + Target string + CRC uint32 +} + +func (p *patchClient) downloadBaseFiles(ctx context.Context) error { + fileList, err := p.latestFileList(ctx) if err != nil { return err } @@ -117,35 +191,61 @@ func launch(ctx context.Context, params LaunchParams) error { return err } - for _, table := range *dmlTables { - if table.Name != "Base" { - continue - } + if err := p.processTables(ctx, fileList.URLPrefix, *dmlTables); err != nil { + return err + } - slog.Info("Downloading files for table", "table", table.Name) + return nil +} + +func (p *patchClient) processTables(ctx context.Context, urlPrefix string, tables []dml.Table) error { + for _, table := range tables { + if table.Name == "Base" { + slog.Info("Processing files for table", "table", table.Name) - for _, record := range table.Records { - if err := download(ctx, fileList.URLPrefix, record, params); err != nil { - return err + for _, record := range table.Records { + if err := p.processRecord(ctx, urlPrefix, record); err != nil { + return err + } } } } - if !params.PatchOnly { - userID, ck2, err := requestCK2Token(ctx, params) - if err != nil { - return err - } + return nil +} - if err := launchGraphicalClient(ctx, userID, ck2, params); err != nil { - return err - } +func (p *patchClient) processRecord(ctx context.Context, urlPrefix string, record dml.Record) error { + source := record["SrcFileName"].(string) + target := record["TarFileName"].(string) + crc := record["CRC"].(uint32) + + if target == "" { + target = source + } + + source = filepath.Clean(source) + target = filepath.Clean(target) + + fileURL, err := url.JoinPath(urlPrefix, source) + if err != nil { + return err + } + + patchFile := patchFile{ + URL: fileURL, + Source: source, + Target: target, + CRC: crc, + } + + if err := p.download(ctx, patchFile); err != nil { + return err } return nil } -func launchGraphicalClient(ctx context.Context, userID uint64, ck2 string, params LaunchParams) error { +func (p *patchClient) launchGraphicalClient(ctx context.Context, userID uint64, ck2 string) error { host, port, err := net.SplitHostPort(defaultLoginServer) if err != nil { return err @@ -155,17 +255,17 @@ func launchGraphicalClient(ctx context.Context, userID uint64, ck2 string, param "-L", host, port, "-U", ".." + fmt.Sprint(userID), string(ck2), - params.Username, + p.launchParams.Username, } slog.Info("Launching WizardGraphicalClient.exe", "args", args) cmd := exec.CommandContext(ctx, "./WizardGraphicalClient.exe", args...) - cmd.Dir = filepath.Join(params.Dir, "Bin") + cmd.Dir = filepath.Join(p.launchParams.Dir, "Bin") return cmd.Start() } -func requestCK2Token(ctx context.Context, params LaunchParams) (uint64, string, error) { +func (p *launchParams) requestCK2Token(ctx context.Context, params launchParams) (uint64, string, error) { authenRspCh := make(chan login.UserAuthenRsp) r := proto.NewMessageRouter() @@ -213,35 +313,34 @@ func requestCK2Token(ctx context.Context, params LaunchParams) (uint64, string, } } -func download(ctx context.Context, prefix string, record dml.Record, params LaunchParams) error { - srcFileName := record["SrcFileName"].(string) - tarFileName := record["TarFileName"].(string) - - if tarFileName == "" { - tarFileName = srcFileName +func (p *patchClient) download(ctx context.Context, patchFile patchFile) error { + ok, err := p.verifyFile(patchFile) + if err != nil { + return fmt.Errorf("error verifying file: %w", err) + } + if ok { + slog.Info("File OK", "crc", patchFile.CRC, "path", patchFile.Target) + return nil } - dirname := filepath.Dir(tarFileName) + dirname := filepath.Dir(patchFile.Target) - fulldir := filepath.Join(params.Dir, dirname) + fulldir := filepath.Join(p.launchParams.Dir, dirname) if err := os.MkdirAll(fulldir, 0755); err != nil { return err } - fileURL, err := url.JoinPath(prefix, srcFileName) - if err != nil { - return err - } - - slog.Info("Downloading file", "url", fileURL) + slog.Info("Downloading file", "url", patchFile.URL) - resp, err := request(ctx, fileURL) + resp, err := request(ctx, patchFile.URL) if err != nil { return err } defer resp.Close() - file, err := os.Create(filepath.Join(params.Dir, tarFileName)) + filePath := filepath.Join(p.launchParams.Dir, patchFile.Target) + + file, err := os.Create(filePath) if err != nil { return err } @@ -254,37 +353,50 @@ func download(ctx context.Context, prefix string, record dml.Record, params Laun return nil } -func request(ctx context.Context, url string) (io.ReadCloser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err +func (p *patchClient) verifyFile(patchFile patchFile) (bool, error) { + filePath := filepath.Join(p.launchParams.Dir, patchFile.Target) + slog.Info("Verifying file", "path", filePath) + + stat, err := os.Stat(filePath) + switch { + case os.IsNotExist(err): + return false, nil + case err != nil: + return false, err + case stat.IsDir(): + return false, nil + default: + // File exists, no error } - resp, err := http.DefaultClient.Do(req) + file, err := os.Open(filePath) if err != nil { - return nil, err + return false, err } + defer file.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + p.hasher.Reset() + if _, err := io.Copy(p.hasher, file); err != nil { + return false, err } + actualCRC := p.hasher.CRC32() - return resp.Body, err + return actualCRC == patchFile.CRC, nil } -func latestFileList(ctx context.Context, params LaunchParams) (*patch.LatestFileListV2, error) { +func (p *patchClient) latestFileList(ctx context.Context) (*patch.LatestFileListV2, error) { fileListCh := make(chan patch.LatestFileListV2) r := proto.NewMessageRouter() patch.RegisterPatchService(r, &patchHandler{fileListCh: fileListCh}) - protoClient, err := proto.Dial(ctx, params.PatchServerAddr, r) + protoClient, err := proto.Dial(ctx, p.launchParams.PatchServerAddr, r) if err != nil { return nil, err } defer protoClient.Close() - slog.Info("Connected to patch server", "server", params.PatchServerAddr) + slog.Info("Connected to patch server", "server", p.launchParams.PatchServerAddr) c := patch.NewPatchClient(protoClient) if err := c.LatestFileListV2(&patch.LatestFileListV2{}); err != nil { @@ -301,3 +413,21 @@ func latestFileList(ctx context.Context, params LaunchParams) (*patch.LatestFile return nil, ctx.Err() } } + +func request(ctx context.Context, url string) (io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return resp.Body, err +}