diff --git a/internal/pkg/frontier/save.go b/internal/pkg/frontier/save.go index 3238bd80..071ef837 100644 --- a/internal/pkg/frontier/save.go +++ b/internal/pkg/frontier/save.go @@ -1,10 +1,8 @@ package frontier import ( - "encoding/gob" "os" "path" - "sync" "github.com/sirupsen/logrus" ) @@ -27,17 +25,15 @@ func (f *Frontier) Load() { } defer decodeFile.Close() - // Create a decoder - decoder := gob.NewDecoder(decodeFile) - - // We create the structure to load the file's content - var dump = new(sync.Map) - - // Decode the content of the file in the structure - decoder.Decode(&dump) - - // Copy the loaded data to our actual frontier - f.HostPool = dump + if err := SyncMapDecode(f.HostPool, decodeFile); err != nil { + f.LoggingChan <- &FrontierLogMessage{ + Fields: logrus.Fields{ + "err": err.Error(), + }, + Message: "unable to decode Frontier stats and host pool", + Level: logrus.WarnLevel, + } + } f.LoggingChan <- &FrontierLogMessage{ Fields: logrus.Fields{ @@ -53,13 +49,25 @@ func (f *Frontier) Save() { // Create a file for IO encodeFile, err := os.OpenFile(path.Join(f.JobPath, "frontier.gob"), os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - logrus.Warning(err) + f.LoggingChan <- &FrontierLogMessage{ + Fields: logrus.Fields{ + "err": err.Error(), + }, + Message: "unable to open Frontier file", + Level: logrus.WarnLevel, + } } defer encodeFile.Close() // Write to the file - var encoder = gob.NewEncoder(encodeFile) - if err := encoder.Encode(f.HostPool); err != nil { - logrus.Warning(err) + + if err := SyncMapEncode(f.HostPool, encodeFile); err != nil { + f.LoggingChan <- &FrontierLogMessage{ + Fields: logrus.Fields{ + "err": err.Error(), + }, + Message: "unable to save Frontier stats and host pool", + Level: logrus.WarnLevel, + } } } diff --git a/internal/pkg/frontier/utils.go b/internal/pkg/frontier/utils.go index adf05d42..183fefc2 100644 --- a/internal/pkg/frontier/utils.go +++ b/internal/pkg/frontier/utils.go @@ -2,10 +2,12 @@ package frontier import ( "bufio" + "encoding/gob" "errors" "fmt" "net/url" "os" + "sync" "github.com/gosuri/uilive" "github.com/sirupsen/logrus" @@ -67,3 +69,40 @@ func IsSeedList(path string) (seeds []Item, err error) { return seeds, nil } + +type Pair struct { + Key, Value interface{} +} + +func SyncMapEncode(m *sync.Map, file *os.File) error { + var pairs []Pair + + m.Range(func(key, value interface{}) bool { + pairs = append(pairs, Pair{key, value}) + return true + }) + + gob.Register(PoolItem{}) + + enc := gob.NewEncoder(file) + err := enc.Encode(pairs) + + return err +} + +func SyncMapDecode(m *sync.Map, file *os.File) error { + var pairs []Pair + gob.Register(PoolItem{}) + dec := gob.NewDecoder(file) + err := dec.Decode(&pairs) + + if err != nil { + return err + } + + for _, p := range pairs { + m.Store(p.Key, p.Value.(PoolItem)) + } + + return nil +}