diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 00000000..5836f67f --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,222 @@ +package main + +import ( + "fmt" + "io" + "log/slog" + "os" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/primevprotocol/mev-commit/pkg/node" + "github.com/urfave/cli/v2" + "gopkg.in/yaml.v2" +) + +const ( + defaultP2PPort = 13522 + defaultHTTPPort = 13523 +) + +var ( + optionConfig = &cli.StringFlag{ + Name: "config", + Usage: "path to config file", + Required: true, + EnvVars: []string{"MEV_COMMIT_CONFIG"}, + } +) + +func main() { + app := &cli.App{ + Name: "mev-commit", + Usage: "Entry point for mev-commit", + Commands: []*cli.Command{ + { + Name: "start", + Usage: "Start mev-commit", + Flags: []cli.Flag{ + optionConfig, + }, + Action: func(c *cli.Context) error { + return start(c) + }, + }, + { + Name: "create-key", + Action: func(c *cli.Context) error { + return createKey(c) + }, + }, + }} + + if err := app.Run(os.Args); err != nil { + fmt.Fprintf(app.Writer, "exited with error: %v\n", err) + } +} + +func createKey(c *cli.Context) error { + privKey, err := crypto.GenerateKey() + if err != nil { + return err + } + + if len(c.Args().Slice()) != 1 { + return fmt.Errorf("usage: mev-commit create-key ") + } + + outputFile := c.Args().Slice()[0] + + f, err := os.Create(outputFile) + if err != nil { + return err + } + + defer f.Close() + + if err := crypto.SaveECDSA(outputFile, privKey); err != nil { + return err + } + + fmt.Fprintf(c.App.Writer, "Private key saved to file: %s\n", outputFile) + return nil +} + +type config struct { + PrivKeyFile string `yaml:"priv_key_file" json:"priv_key_file"` + Secret string `yaml:"secret" json:"secret"` + PeerType string `yaml:"peer_type" json:"peer_type"` + P2PPort int `yaml:"p2p_port" json:"p2p_port"` + HTTPPort int `yaml:"http_port" json:"http_port"` + LogFmt string `yaml:"log_fmt" json:"log_fmt"` + LogLevel string `yaml:"log_level" json:"log_level"` + Bootnodes []string `yaml:"bootnodes" json:"bootnodes"` +} + +func checkConfig(cfg *config) error { + if cfg.PrivKeyFile == "" { + return fmt.Errorf("priv_key_file is required") + } + + if cfg.Secret == "" { + return fmt.Errorf("secret is required") + } + + if cfg.PeerType == "" { + return fmt.Errorf("peer_type is required") + } + + if cfg.P2PPort == 0 { + cfg.P2PPort = defaultP2PPort + } + + if cfg.HTTPPort == 0 { + cfg.HTTPPort = defaultHTTPPort + } + + if cfg.LogFmt == "" { + cfg.LogFmt = "text" + } + + if cfg.LogLevel == "" { + cfg.LogLevel = "info" + } + + return nil +} + +func start(c *cli.Context) error { + configFile := c.String(optionConfig.Name) + fmt.Fprintf(c.App.Writer, "starting mev-commit with config file: %s\n", configFile) + + var cfg config + buf, err := os.ReadFile(configFile) + if err != nil { + return fmt.Errorf("failed to read config file at '%s': %w", configFile, err) + } + + if err := yaml.Unmarshal(buf, &cfg); err != nil { + return fmt.Errorf("failed to unmarshal config file at '%s': %w", configFile, err) + } + + if err := checkConfig(&cfg); err != nil { + return fmt.Errorf("invalid config file at '%s': %w", configFile, err) + } + + logger, err := newLogger(cfg.LogLevel, cfg.LogFmt, c.App.Writer) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + privKey, err := crypto.LoadECDSA(cfg.PrivKeyFile) + if err != nil { + return fmt.Errorf("failed to load private key from file '%s': %w", cfg.PrivKeyFile, err) + } + + nd, err := node.NewNode(&node.Options{ + PrivKey: privKey, + Secret: cfg.Secret, + PeerType: cfg.PeerType, + P2PPort: cfg.P2PPort, + HTTPPort: cfg.HTTPPort, + Logger: logger, + Bootnodes: cfg.Bootnodes, + }) + if err != nil { + return fmt.Errorf("failed starting node: %w", err) + } + + <-c.Done() + fmt.Fprintf(c.App.Writer, "shutting down...\n") + closed := make(chan struct{}) + + go func() { + defer close(closed) + + err := nd.Close() + if err != nil { + logger.Error("failed to close node", "error", err) + } + }() + + select { + case <-closed: + case <-time.After(5 * time.Second): + logger.Error("failed to close node in time") + } + + return nil +} + +func newLogger(lvl, logFmt string, sink io.Writer) (*slog.Logger, error) { + var ( + level = new(slog.LevelVar) // Info by default + handler slog.Handler + ) + + switch lvl { + case "debug": + level.Set(slog.LevelDebug) + case "info": + level.Set(slog.LevelInfo) + case "warn": + level.Set(slog.LevelWarn) + case "error": + level.Set(slog.LevelError) + default: + return nil, fmt.Errorf("invalid log level: %s", lvl) + } + + switch logFmt { + case "text": + handler = slog.NewTextHandler(sink, &slog.HandlerOptions{Level: level}) + case "none": + fallthrough + case "json": + handler = slog.NewJSONHandler(sink, &slog.HandlerOptions{Level: level}) + default: + return nil, fmt.Errorf("invalid log format: %s", logFmt) + } + + return slog.New(handler), nil +} diff --git a/go.mod b/go.mod index 573afdf8..34044831 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,11 @@ require ( github.com/libp2p/go-libp2p v0.31.0 github.com/libp2p/go-msgio v0.3.0 github.com/prometheus/client_golang v1.16.0 + github.com/urfave/cli/v2 v2.25.7 github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/crypto v0.12.0 golang.org/x/sync v0.3.0 + gopkg.in/yaml.v2 v2.4.0 ) require ( @@ -20,6 +22,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/containerd/cgroups v1.1.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/docker/go-units v0.5.0 // indirect @@ -81,8 +84,10 @@ require ( github.com/quic-go/quic-go v0.38.1 // indirect github.com/quic-go/webtransport-go v0.5.3 // indirect github.com/raulk/go-watchdog v1.3.0 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/fx v1.20.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go.sum b/go.sum index a362d5b6..d5be14cf 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -282,6 +284,8 @@ github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBO github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= @@ -321,12 +325,16 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/urfave/cli/v2 v2.25.7 h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs= +github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -487,6 +495,8 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/apiserver/api.go b/pkg/apiserver/api.go new file mode 100644 index 00000000..23ee95e7 --- /dev/null +++ b/pkg/apiserver/api.go @@ -0,0 +1,110 @@ +package apiserver + +import ( + "expvar" + "log/slog" + "net/http" + "net/http/pprof" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + defaultNamespace = "primev" +) + +// Service wraps http.Server with additional functionality for metrics and +// other common middlewares. +type Service struct { + metricsRegistry *prometheus.Registry + router *http.ServeMux + logger *slog.Logger +} + +// New creates a new Service. +func New( + version string, + logger *slog.Logger, +) *Service { + srv := &Service{ + router: http.NewServeMux(), + logger: logger, + metricsRegistry: newMetrics(version), + } + + srv.registerDebugEndpoints() + return srv +} + +func (a *Service) registerDebugEndpoints() { + // register metrics handler + a.router.Handle("/metrics", promhttp.HandlerFor(a.metricsRegistry, promhttp.HandlerOpts{})) + + // register pprof handlers + a.router.Handle( + "/debug/pprof", + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := r.URL + u.Path += "/" + http.Redirect(w, r, u.String(), http.StatusPermanentRedirect) + }), + ) + a.router.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) + a.router.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + a.router.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + a.router.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + a.router.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) + a.router.Handle("/debug/pprof/{profile}", http.HandlerFunc(pprof.Index)) + a.router.Handle("/debug/vars", expvar.Handler()) +} + +func newMetrics(version string) (r *prometheus.Registry) { + r = prometheus.NewRegistry() + + // register standard metrics + r.MustRegister( + collectors.NewProcessCollector(collectors.ProcessCollectorOpts{ + Namespace: defaultNamespace, + }), + collectors.NewGoCollector(), + prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: defaultNamespace, + Name: "info", + Help: "builder-boost information.", + ConstLabels: prometheus.Labels{ + "version": version, + }, + }), + ) + + return r +} + +// Router returns the router. +func (a *Service) Router() http.Handler { + return newAccessLogHandler(a.logger)(a.router) +} + +// ChainHandlers chains middlewares and handler. +func (a *Service) ChainHandlers( + path string, + handler http.Handler, + mws ...func(http.Handler) http.Handler, +) { + h := handler + for i := len(mws) - 1; i >= 0; i-- { + h = mws[i](h) + } + a.router.Handle(path, h) +} + +func (a *Service) MetricsRegistry() *prometheus.Registry { + return a.metricsRegistry +} + +// RegisterMetricsCollectors registers prometheus collectors. +func (a *Service) RegisterMetricsCollectors(cs ...prometheus.Collector) { + a.metricsRegistry.MustRegister(cs...) +} diff --git a/pkg/apiserver/api_test.go b/pkg/apiserver/api_test.go new file mode 100644 index 00000000..ae2d9ff7 --- /dev/null +++ b/pkg/apiserver/api_test.go @@ -0,0 +1,141 @@ +package apiserver_test + +import ( + "bytes" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/primevprotocol/mev-commit/pkg/apiserver" +) + +func newTestLogger(w io.Writer) *slog.Logger { + testLogger := slog.NewTextHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + return slog.New(testLogger) +} + +func TestAPIServer(t *testing.T) { + t.Parallel() + + t.Run("new and close", func(t *testing.T) { + var logBuf bytes.Buffer + s := apiserver.New( + "test", + newTestLogger(&logBuf), + ) + + srv := httptest.NewServer(s.Router()) + t.Cleanup(func() { + srv.Close() + }) + + r, err := http.NewRequest("GET", srv.URL+"/metrics", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Fatal(err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + var b bytes.Buffer + n, err := b.ReadFrom(resp.Body) + if err != nil { + t.Fatal(err) + } + + if n == 0 { + t.Fatal("expected non-zero body") + } + + if !strings.Contains(b.String(), "test") { + t.Fatalf("expected body to contain 'test', got %q", b.String()) + } + + if !strings.Contains(b.String(), "go_info") { + t.Fatalf("expected body to contain 'go_info', got %q", b.String()) + } + + if !strings.Contains(b.String(), "go_memstats") { + t.Fatalf("expected body to contain 'go_memstats', got %q", b.String()) + } + + if !strings.Contains(b.String(), "go_gc_duration_seconds") { + t.Fatalf("expected body to contain 'go_gc_duration_seconds', got %q", b.String()) + } + + if !strings.Contains(logBuf.String(), "api access") { + t.Fatalf("expected log to contain 'api access', got %q", logBuf.String()) + } + }) + + t.Run("chain handlers", func(t *testing.T) { + var logBuf bytes.Buffer + s := apiserver.New( + "test", + newTestLogger(&logBuf), + ) + + srv := httptest.NewServer(s.Router()) + t.Cleanup(func() { + srv.Close() + }) + + var orderedHandlerActions []int + s.ChainHandlers( + "/chain", + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orderedHandlerActions = append(orderedHandlerActions, 3) + w.WriteHeader(http.StatusOK) + }), + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orderedHandlerActions = append(orderedHandlerActions, 1) + next.ServeHTTP(w, r) + }) + }, + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + orderedHandlerActions = append(orderedHandlerActions, 2) + next.ServeHTTP(w, r) + }) + }, + ) + + r, err := http.NewRequest("GET", srv.URL+"/chain", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + if len(orderedHandlerActions) != 3 { + t.Fatalf("expected 3 handler actions, got %d", len(orderedHandlerActions)) + } + + for i, v := range []int{1, 2, 3} { + if orderedHandlerActions[i] != v { + t.Fatalf("expected handler action %d, got %d", v, orderedHandlerActions[i]) + } + } + }) +} diff --git a/pkg/apiserver/helpers.go b/pkg/apiserver/helpers.go new file mode 100644 index 00000000..d74863ca --- /dev/null +++ b/pkg/apiserver/helpers.go @@ -0,0 +1,68 @@ +package apiserver + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" +) + +// StatusResponse is a helper struct used to wrap a string message with a status code. +type StatusResponse struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// WriteResponse helper is used to write a response to the client with the given code +// and message. If the message is a string, it will be wrapped in a StatusResponse +// struct. Otherwise, the message will be encoded as JSON. +func WriteResponse(w http.ResponseWriter, code int, message any) error { + var b bytes.Buffer + switch message := message.(type) { + case string: + err := json.NewEncoder(&b).Encode(StatusResponse{Code: code, Message: message}) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return fmt.Errorf("failed to encode status response: %w", err) + } + default: + err := json.NewEncoder(&b).Encode(message) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return fmt.Errorf("failed to encode response: %w", err) + } + } + + w.WriteHeader(code) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, b.String()) + return nil +} + +// MethodHandler helper is used to wrap a handler and ensure that the request method +// matches the given method. If the method does not match, a 405 is returned. +func MethodHandler(method string, handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != method { + err := WriteResponse(w, http.StatusMethodNotAllowed, "method not allowed") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + handler(w, r) + } +} + +// BindJSON helper is used to bind the request body to the given type. +func BindJSON[T any](w http.ResponseWriter, r *http.Request) (T, error) { + var body T + + if r.Body == nil { + return body, errors.New("no body") + } + defer r.Body.Close() + + return body, json.NewDecoder(r.Body).Decode(&body) +} diff --git a/pkg/apiserver/helpers_test.go b/pkg/apiserver/helpers_test.go new file mode 100644 index 00000000..ae8a0ce9 --- /dev/null +++ b/pkg/apiserver/helpers_test.go @@ -0,0 +1,164 @@ +package apiserver_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/primevprotocol/mev-commit/pkg/apiserver" +) + +type testHandler struct { + called bool +} + +func (h *testHandler) Handle(_ http.ResponseWriter, _ *http.Request) { + h.called = true +} + +func TestMethodHandler(t *testing.T) { + t.Parallel() + + t.Run("method not allowed", func(t *testing.T) { + h := &testHandler{} + mh := apiserver.MethodHandler("GET", http.HandlerFunc(h.Handle)) + + r, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + mh.ServeHTTP(w, r) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status code %d, got %d", http.StatusMethodNotAllowed, w.Code) + } + + if h.called { + t.Fatal("handler should not have been called") + } + }) + + t.Run("method allowed", func(t *testing.T) { + h := &testHandler{} + mh := apiserver.MethodHandler("GET", http.HandlerFunc(h.Handle)) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + mh.ServeHTTP(w, r) + + if !h.called { + t.Fatal("handler should have been called") + } + }) +} + +func TestBindJSON(t *testing.T) { + t.Parallel() + + t.Run("bad request", func(t *testing.T) { + type v struct { + Foo string `json:"foo"` + } + + r, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + w := httptest.NewRecorder() + + if _, err := apiserver.BindJSON[v](w, r); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("ok", func(t *testing.T) { + type v struct { + Foo string `json:"foo"` + } + + b := bytes.NewBuffer([]byte(`{"foo":"bar"}`)) + + r, err := http.NewRequest("POST", "/", b) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + vv, err := apiserver.BindJSON[v](w, r) + if err != nil { + t.Fatal(err) + } + + if vv.Foo != "bar" { + t.Fatalf("expected foo to be %q, got %q", "bar", vv.Foo) + } + }) +} + +func TestWriteResponse(t *testing.T) { + t.Parallel() + + t.Run("string", func(t *testing.T) { + w := httptest.NewRecorder() + + if err := apiserver.WriteResponse(w, http.StatusOK, "foo"); err != nil { + t.Fatal(err) + } + + if w.Code != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, w.Code) + } + + resp := apiserver.StatusResponse{ + Code: http.StatusOK, + Message: "foo", + } + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(resp); err != nil { + t.Fatal(err) + } + buf.WriteByte('\n') + + if !bytes.Equal(w.Body.Bytes(), buf.Bytes()) { + t.Fatalf("expected body %q, got %q", buf.String(), w.Body.String()) + } + }) + + t.Run("struct", func(t *testing.T) { + type v struct { + Foo string `json:"foo"` + } + + rq := v{Foo: "bar"} + + w := httptest.NewRecorder() + + if err := apiserver.WriteResponse(w, http.StatusOK, rq); err != nil { + t.Fatal(err) + } + + if w.Code != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, w.Code) + } + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(rq); err != nil { + t.Fatal(err) + } + buf.WriteByte('\n') + + if !bytes.Equal(w.Body.Bytes(), buf.Bytes()) { + t.Fatalf("expected body %q, got %q", buf.String(), w.Body.String()) + } + }) +} diff --git a/pkg/apiserver/middleware.go b/pkg/apiserver/middleware.go new file mode 100644 index 00000000..3d7b6454 --- /dev/null +++ b/pkg/apiserver/middleware.go @@ -0,0 +1,34 @@ +package apiserver + +import ( + "log/slog" + "net/http" + "time" +) + +type responseStatusRecorder struct { + http.ResponseWriter + status int +} + +func (r *responseStatusRecorder) WriteHeader(status int) { + r.status = status + r.ResponseWriter.WriteHeader(status) +} + +func newAccessLogHandler(log *slog.Logger) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + recorder := &responseStatusRecorder{ResponseWriter: w} + + start := time.Now() + h.ServeHTTP(recorder, req) + log.Info("api access", + "status", recorder.status, + "method", req.Method, + "path", req.URL.Path, + "duration", time.Since(start), + ) + }) + } +} diff --git a/pkg/debugapi/debugapi.go b/pkg/debugapi/debugapi.go new file mode 100644 index 00000000..ddc6c744 --- /dev/null +++ b/pkg/debugapi/debugapi.go @@ -0,0 +1,68 @@ +package debugapi + +import ( + "log/slog" + "net/http" + + "github.com/primevprotocol/mev-commit/pkg/apiserver" + "github.com/primevprotocol/mev-commit/pkg/p2p" + "github.com/primevprotocol/mev-commit/pkg/p2p/libp2p" + "github.com/primevprotocol/mev-commit/pkg/topology" +) + +type APIServer interface { + ChainHandlers(string, http.Handler, ...func(http.Handler) http.Handler) +} + +func RegisterAPI( + srv APIServer, + topo topology.Topology, + p2pSvc *libp2p.Service, + logger *slog.Logger, +) { + d := &debugapi{ + topo: topo, + p2p: p2pSvc, + logger: logger, + } + + srv.ChainHandlers( + "/topology", + apiserver.MethodHandler("GET", d.handleTopology), + ) +} + +type debugapi struct { + topo topology.Topology + p2p *libp2p.Service + logger *slog.Logger +} + +type topologyResponse struct { + Self map[string]interface{} `json:"self"` + ConnectedPeers map[string][]p2p.Peer `json:"connected_peers"` +} + +func (d *debugapi) handleTopology(w http.ResponseWriter, r *http.Request) { + logger := d.logger.With("method", "handleTopology") + builders := d.topo.GetPeers(topology.Query{Type: p2p.PeerTypeBuilder}) + searchers := d.topo.GetPeers(topology.Query{Type: p2p.PeerTypeSearcher}) + + topoResp := topologyResponse{ + Self: d.p2p.Self(), + ConnectedPeers: make(map[string][]p2p.Peer), + } + + if len(builders) > 0 { + topoResp.ConnectedPeers["builders"] = builders + } + if len(searchers) > 0 { + topoResp.ConnectedPeers["searchers"] = searchers + } + + w.Header().Set("Content-Type", "application/json") + err := apiserver.WriteResponse(w, http.StatusOK, topoResp) + if err != nil { + logger.Error("error writing response", "err", err) + } +} diff --git a/pkg/discovery/discovery.go b/pkg/discovery/discovery.go index 69c8e783..087a1ed2 100644 --- a/pkg/discovery/discovery.go +++ b/pkg/discovery/discovery.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" + "github.com/ethereum/go-ethereum/common" "github.com/primevprotocol/mev-commit/pkg/p2p" "github.com/primevprotocol/mev-commit/pkg/p2p/msgpack" "golang.org/x/sync/semaphore" @@ -17,19 +18,19 @@ const ( type P2PService interface { p2p.Streamer - p2p.Addressbook Connect(context.Context, []byte) (p2p.Peer, error) } type Topology interface { AddPeers(...p2p.Peer) + IsConnected(common.Address) bool } type Discovery struct { topo Topology streamer P2PService logger *slog.Logger - checkPeers chan PeerInfo + checkPeers chan p2p.PeerInfo sem *semaphore.Weighted quit chan struct{} } @@ -44,7 +45,7 @@ func New( streamer: streamer, logger: logger.With("protocol", ProtocolName), sem: semaphore.NewWeighted(checkWorkers), - checkPeers: make(chan PeerInfo), + checkPeers: make(chan p2p.PeerInfo), quit: make(chan struct{}), } go d.checkAndAddPeers() @@ -64,14 +65,8 @@ func (d *Discovery) Protocol() p2p.ProtocolSpec { } } -type PeerInfo struct { - ID string - PeerType string - Underlay []byte -} - type peersList struct { - Peers []PeerInfo + Peers []p2p.PeerInfo } func (d *Discovery) handlePeersList(ctx context.Context, peer p2p.Peer, s p2p.Stream) error { @@ -84,6 +79,9 @@ func (d *Discovery) handlePeersList(ctx context.Context, peer p2p.Peer, s p2p.St } for _, p := range peers.Peers { + if d.topo.IsConnected(p.EthAddress) { + continue + } select { case d.checkPeers <- p: case <-ctx.Done(): @@ -96,7 +94,11 @@ func (d *Discovery) handlePeersList(ctx context.Context, peer p2p.Peer, s p2p.St return nil } -func (d *Discovery) BroadcastPeers(ctx context.Context, peer p2p.Peer, peers []PeerInfo) error { +func (d *Discovery) BroadcastPeers( + ctx context.Context, + peer p2p.Peer, + peers []p2p.PeerInfo, +) error { stream, err := d.streamer.NewStream(ctx, peer, ProtocolName, ProtocolVersion, "peersList") if err != nil { d.logger.Error("failed to create stream", "err", err, "to_peer", peer) diff --git a/pkg/discovery/discovery_test.go b/pkg/discovery/discovery_test.go index 7944064d..c8f9605b 100644 --- a/pkg/discovery/discovery_test.go +++ b/pkg/discovery/discovery_test.go @@ -28,6 +28,18 @@ func (t *testTopo) AddPeers(peers ...p2p.Peer) { t.peers = append(t.peers, peers...) } +func (t *testTopo) IsConnected(addr common.Address) bool { + t.mu.Lock() + defer t.mu.Unlock() + + for _, p := range t.peers { + if p.EthAddress == addr { + return true + } + } + return false +} + func newTestLogger(w io.Writer) *slog.Logger { testLogger := slog.NewTextHandler(w, &slog.HandlerOptions{ Level: slog.LevelDebug, @@ -55,12 +67,6 @@ func TestDiscovery(t *testing.T) { } return client, nil }), - p2ptest.WithAddressbookFunc(func(p p2p.Peer) ([]byte, error) { - if p.EthAddress != client.EthAddress { - return nil, errors.New("invalid peer") - } - return []byte("test"), nil - }), ) topo := &testTopo{} @@ -74,11 +80,10 @@ func TestDiscovery(t *testing.T) { svc.SetPeerHandler(server, d.Protocol()) - err := d.BroadcastPeers(context.Background(), server, []discovery.PeerInfo{ + err := d.BroadcastPeers(context.Background(), server, []p2p.PeerInfo{ { - ID: common.HexToAddress("0x1").Hex(), - PeerType: p2p.PeerTypeBuilder.String(), - Underlay: []byte("test"), + EthAddress: common.HexToAddress("0x1"), + Underlay: []byte("test"), }, }) if err != nil { diff --git a/pkg/node/node.go b/pkg/node/node.go new file mode 100644 index 00000000..cef1cd08 --- /dev/null +++ b/pkg/node/node.go @@ -0,0 +1,100 @@ +package node + +import ( + "crypto/ecdsa" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + + "github.com/primevprotocol/mev-commit/pkg/apiserver" + "github.com/primevprotocol/mev-commit/pkg/debugapi" + "github.com/primevprotocol/mev-commit/pkg/discovery" + "github.com/primevprotocol/mev-commit/pkg/p2p" + "github.com/primevprotocol/mev-commit/pkg/p2p/libp2p" + "github.com/primevprotocol/mev-commit/pkg/register" + "github.com/primevprotocol/mev-commit/pkg/topology" +) + +type Options struct { + Version string + PrivKey *ecdsa.PrivateKey + Secret string + PeerType string + Logger *slog.Logger + P2PPort int + HTTPPort int + Bootnodes []string +} + +type Node struct { + closers []io.Closer +} + +func NewNode(opts *Options) (*Node, error) { + reg := register.New() + + minStake, err := reg.GetMinimumStake() + if err != nil { + return nil, err + } + + srv := apiserver.New(opts.Version, opts.Logger) + + closers := make([]io.Closer, 0) + peerType := p2p.FromString(opts.PeerType) + + p2pSvc, err := libp2p.New(&libp2p.Options{ + PrivKey: opts.PrivKey, + Secret: opts.Secret, + PeerType: peerType, + Register: reg, + MinimumStake: minStake, + Logger: opts.Logger, + ListenPort: opts.P2PPort, + MetricsReg: srv.MetricsRegistry(), + BootstrapAddrs: opts.Bootnodes, + }) + if err != nil { + return nil, err + } + closers = append(closers, p2pSvc) + + topo := topology.New(p2pSvc, opts.Logger) + disc := discovery.New(topo, p2pSvc, opts.Logger) + closers = append(closers, disc) + + // Set the announcer for the topology service + topo.SetAnnouncer(disc) + // Set the notifier for the p2p service + p2pSvc.SetNotifier(topo) + + // Register the discovery protocol with the p2p service + p2pSvc.AddProtocol(disc.Protocol()) + + debugapi.RegisterAPI(srv, topo, p2pSvc, opts.Logger) + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", opts.HTTPPort), + Handler: srv.Router(), + } + + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + opts.Logger.Error("failed to start server", "err", err) + } + }() + closers = append(closers, server) + + return &Node{closers: closers}, nil +} + +func (n *Node) Close() error { + var err error + for _, c := range n.closers { + err = errors.Join(err, c.Close()) + } + + return err +} diff --git a/pkg/p2p/libp2p/bootstrapper.go b/pkg/p2p/libp2p/bootstrapper.go new file mode 100644 index 00000000..cdeddc59 --- /dev/null +++ b/pkg/p2p/libp2p/bootstrapper.go @@ -0,0 +1,40 @@ +package libp2p + +import ( + "context" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +func (p *Service) startBootstrapper(addrs []string) { + for { + for _, addr := range addrs { + addrInfo, err := peer.AddrInfoFromString(addr) + if err != nil { + p.logger.Error("failed to parse bootstrap address", "addr", addr, "err", err) + continue + } + + if _, connected := p.peers.isConnected(addrInfo.ID); connected { + p.logger.Debug("already connected to bootstrap peer", "peer", addrInfo.ID) + continue + } + + addrInfoBytes, err := addrInfo.MarshalJSON() + if err != nil { + p.logger.Error("failed to marshal bootstrap peer", "addr", addr, "err", err) + continue + } + + peer, err := p.Connect(context.Background(), addrInfoBytes) + if err != nil { + p.logger.Error("failed to connect to bootstrap peer", "addr", addr, "err", err) + continue + } + + p.logger.Info("connected to bootstrap peer", "peer", peer) + } + time.Sleep(1 * time.Minute) + } +} diff --git a/pkg/p2p/libp2p/export_test.go b/pkg/p2p/libp2p/export_test.go index ae8bdc55..078c6b7a 100644 --- a/pkg/p2p/libp2p/export_test.go +++ b/pkg/p2p/libp2p/export_test.go @@ -1,6 +1,8 @@ package libp2p import ( + "fmt" + "github.com/libp2p/go-libp2p/core/peer" "github.com/primevprotocol/mev-commit/pkg/p2p" ) @@ -20,3 +22,12 @@ func (s *Service) Peer() p2p.Peer { func (s *Service) HostID() peer.ID { return s.host.ID() } + +func (s *Service) AddrString() string { + fmt.Println(s.host.Addrs()) + return s.host.Addrs()[0].String() + "/p2p/" + s.host.ID().String() +} + +func (s *Service) PeerCount() int { + return len(s.host.Network().Peers()) +} diff --git a/pkg/p2p/libp2p/internal/handshake/handshake_test.go b/pkg/p2p/libp2p/internal/handshake/handshake_test.go index 5164d068..65e73dd5 100644 --- a/pkg/p2p/libp2p/internal/handshake/handshake_test.go +++ b/pkg/p2p/libp2p/internal/handshake/handshake_test.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "errors" "math/big" "testing" @@ -21,6 +22,10 @@ func (t *testRegister) GetStake(_ common.Address) (*big.Int, error) { return big.NewInt(5), nil } +func (t *testRegister) GetMinimumStake() (*big.Int, error) { + return nil, errors.New("not implemented") +} + type testSigner struct { address common.Address } diff --git a/pkg/p2p/libp2p/libp2p.go b/pkg/p2p/libp2p/libp2p.go index 4568041b..f020053e 100644 --- a/pkg/p2p/libp2p/libp2p.go +++ b/pkg/p2p/libp2p/libp2p.go @@ -39,14 +39,15 @@ type Service struct { } type Options struct { - PrivKey *ecdsa.PrivateKey - Secret string - PeerType p2p.PeerType - Register register.Register - MinimumStake *big.Int - ListenPort int - Logger *slog.Logger - MetricsReg *prometheus.Registry + PrivKey *ecdsa.PrivateKey + Secret string + PeerType p2p.PeerType + Register register.Register + MinimumStake *big.Int + ListenPort int + Logger *slog.Logger + MetricsReg *prometheus.Registry + BootstrapAddrs []string } func New(opts *Options) (*Service, error) { @@ -133,6 +134,10 @@ func New(opts *Options) (*Service, error) { host.Network().Notify(s.peers) s.host.SetStreamHandler(handshake.ProtocolID(), s.handleConnectReq) + + if len(opts.BootstrapAddrs) > 0 { + go s.startBootstrapper(opts.BootstrapAddrs) + } return s, nil } @@ -141,6 +146,10 @@ func (s *Service) Close() error { return s.host.Close() } +func (s *Service) SetNotifier(n p2p.Notifier) { + s.notifier = n +} + func (s *Service) handleConnectReq(streamlibp2p network.Stream) { peerID := streamlibp2p.Conn().RemotePeer() @@ -172,6 +181,15 @@ func (s *Service) disconnected(p p2p.Peer) { } } +func (s *Service) Self() map[string]interface{} { + return map[string]interface{}{ + "Ethereum Address": s.ethAddress.Hex(), + "Peer Type": s.peerType.String(), + "Underlay": s.host.ID().String(), + "Addresses": s.host.Addrs(), + } +} + func (s *Service) AddProtocol(spec p2p.ProtocolSpec) { for _, streamSpec := range spec.StreamSpecs { ss := streamSpec @@ -249,6 +267,7 @@ func (s *Service) Connect(ctx context.Context, info []byte) (p2p.Peer, error) { p, err := s.hsSvc.Handshake(ctx, addrInfo.ID, stream) if err != nil { + _ = s.host.Network().ClosePeer(addrInfo.ID) return p2p.Peer{}, err } diff --git a/pkg/p2p/libp2p/libp2p_test.go b/pkg/p2p/libp2p/libp2p_test.go index 055c6e14..d04abff3 100644 --- a/pkg/p2p/libp2p/libp2p_test.go +++ b/pkg/p2p/libp2p/libp2p_test.go @@ -7,7 +7,9 @@ import ( "log/slog" "math/big" "os" + "sync" "testing" + "time" "github.com/ethereum/go-ethereum/crypto" "github.com/libp2p/go-libp2p/core/peer" @@ -179,3 +181,89 @@ func TestP2PService(t *testing.T) { } }) } + +type testNotifier struct { + mu sync.Mutex + peers []p2p.Peer +} + +func (t *testNotifier) Connected(p p2p.Peer) { + t.mu.Lock() + defer t.mu.Unlock() + t.peers = append(t.peers, p) +} + +func (t *testNotifier) Disconnected(p p2p.Peer) { + t.mu.Lock() + defer t.mu.Unlock() + for i, peer := range t.peers { + if peer.EthAddress == p.EthAddress { + t.peers = append(t.peers[:i], t.peers[i+1:]...) + return + } + } +} + +func TestBootstrap(t *testing.T) { + testDefaultOptions := libp2p.Options{ + Secret: "test", + ListenPort: 0, + PeerType: p2p.PeerTypeBuilder, + Register: registermock.New(10), + MinimumStake: big.NewInt(5), + Logger: newTestLogger(t, os.Stdout), + } + + privKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + + bnOpts := testDefaultOptions + bnOpts.PrivKey = privKey + bnOpts.PeerType = p2p.PeerTypeBootnode + + bootnode, err := libp2p.New(&bnOpts) + if err != nil { + t.Fatal(err) + } + + notifier := &testNotifier{} + bootnode.SetNotifier(notifier) + + privKey, err = crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + + n1Opts := testDefaultOptions + n1Opts.BootstrapAddrs = []string{bootnode.AddrString()} + n1Opts.PrivKey = privKey + + p1, err := libp2p.New(&n1Opts) + if err != nil { + t.Fatal(err) + } + + start := time.Now() + for { + if time.Since(start) > 10*time.Second { + t.Fatal("timed out waiting for peers to connect") + } + + if p1.PeerCount() == 1 { + if len(notifier.peers) != 1 { + t.Fatalf("expected 1 peer, got %d", len(notifier.peers)) + } + if notifier.peers[0].Type != p2p.PeerTypeBuilder { + t.Fatalf( + "expected peer type %s, got %s", + p2p.PeerTypeBuilder, notifier.peers[0].Type, + ) + } + break + } + + time.Sleep(100 * time.Millisecond) + } +} diff --git a/pkg/p2p/p2p.go b/pkg/p2p/p2p.go index 2353fa65..ef02d24c 100644 --- a/pkg/p2p/p2p.go +++ b/pkg/p2p/p2p.go @@ -56,6 +56,11 @@ type Peer struct { Type PeerType } +type PeerInfo struct { + EthAddress common.Address + Underlay []byte +} + type Stream interface { ReadMsg() ([]byte, error) WriteMsg([]byte) error diff --git a/pkg/register/mock/mock.go b/pkg/register/mock/mock.go index 85c589e9..380a03a7 100644 --- a/pkg/register/mock/mock.go +++ b/pkg/register/mock/mock.go @@ -18,3 +18,7 @@ func New(stake int64) register.Register { func (t *mockRegister) GetStake(_ common.Address) (*big.Int, error) { return big.NewInt(t.stake), nil } + +func (t *mockRegister) GetMinimumStake() (*big.Int, error) { + return big.NewInt(0), nil +} diff --git a/pkg/register/register.go b/pkg/register/register.go index ea955a05..710432fc 100644 --- a/pkg/register/register.go +++ b/pkg/register/register.go @@ -8,6 +8,22 @@ import ( // Register is the provider register used to query the contract type Register interface { + // GetMinimumStake returns the minimum stake required to be a provider + GetMinimumStake() (*big.Int, error) // GetStake returns stake of specified provider GetStake(provider common.Address) (*big.Int, error) } + +type register struct{} + +func New() Register { + return ®ister{} +} + +func (r *register) GetMinimumStake() (*big.Int, error) { + return big.NewInt(0), nil +} + +func (r *register) GetStake(provider common.Address) (*big.Int, error) { + return big.NewInt(0), nil +} diff --git a/pkg/topology/topology.go b/pkg/topology/topology.go index dc7c7e1e..8feeb1c3 100644 --- a/pkg/topology/topology.go +++ b/pkg/topology/topology.go @@ -1,6 +1,13 @@ package topology -import "github.com/primevprotocol/mev-commit/pkg/p2p" +import ( + "context" + "log/slog" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/primevprotocol/mev-commit/pkg/p2p" +) type Query struct { Type p2p.PeerType @@ -8,6 +15,132 @@ type Query struct { type Topology interface { p2p.Notifier + SetAnnouncer(Announcer) GetPeers(Query) []p2p.Peer AddPeers(...p2p.Peer) + IsConnected(common.Address) bool +} + +type Announcer interface { + BroadcastPeers(context.Context, p2p.Peer, []p2p.PeerInfo) error +} + +type topology struct { + mu sync.RWMutex + builders map[common.Address]p2p.Peer + searchers map[common.Address]p2p.Peer + logger *slog.Logger + addressbook p2p.Addressbook + announcer Announcer +} + +func New(a p2p.Addressbook, logger *slog.Logger) *topology { + return &topology{ + builders: make(map[common.Address]p2p.Peer), + searchers: make(map[common.Address]p2p.Peer), + addressbook: a, + logger: logger, + } +} + +func (t *topology) SetAnnouncer(a Announcer) { + t.announcer = a +} + +func (t *topology) Connected(p p2p.Peer) { + t.add(p) + + if t.announcer != nil { + // Whether its a builder or searcher, we want to broadcast the builder peers + peersToBroadcast := t.GetPeers(Query{Type: p2p.PeerTypeBuilder}) + var underlays []p2p.PeerInfo + for _, peer := range peersToBroadcast { + if peer.EthAddress == p.EthAddress { + continue + } + u, err := t.addressbook.GetPeerInfo(peer) + if err != nil { + t.logger.Error("failed to get peer info", "err", err, "peer", peer) + continue + } + underlays = append(underlays, p2p.PeerInfo{ + EthAddress: peer.EthAddress, + Underlay: u, + }) + } + + if len(underlays) == 0 { + t.logger.Warn("no underlays to broadcast", "peer", p) + return + } + err := t.announcer.BroadcastPeers(context.Background(), p, underlays) + if err != nil { + t.logger.Error("failed to broadcast peers", "err", err, "peer", p) + } + } +} + +func (t *topology) add(p p2p.Peer) { + t.mu.Lock() + defer t.mu.Unlock() + + switch p.Type { + case p2p.PeerTypeBuilder: + t.builders[p.EthAddress] = p + case p2p.PeerTypeSearcher: + t.searchers[p.EthAddress] = p + } +} + +func (t *topology) Disconnected(p p2p.Peer) { + t.mu.Lock() + defer t.mu.Unlock() + + switch p.Type { + case p2p.PeerTypeBuilder: + delete(t.builders, p.EthAddress) + case p2p.PeerTypeSearcher: + delete(t.searchers, p.EthAddress) + } +} + +func (t *topology) AddPeers(peers ...p2p.Peer) { + for _, p := range peers { + t.add(p) + } +} + +func (t *topology) GetPeers(q Query) []p2p.Peer { + t.mu.RLock() + defer t.mu.RUnlock() + + var peers []p2p.Peer + + switch q.Type { + case p2p.PeerTypeBuilder: + for _, p := range t.builders { + peers = append(peers, p) + } + case p2p.PeerTypeSearcher: + for _, p := range t.searchers { + peers = append(peers, p) + } + } + + return peers +} + +func (t *topology) IsConnected(addr common.Address) bool { + t.mu.RLock() + defer t.mu.RUnlock() + + if _, ok := t.builders[addr]; ok { + return true + } + + if _, ok := t.searchers[addr]; ok { + return true + } + + return false } diff --git a/pkg/topology/topology_test.go b/pkg/topology/topology_test.go new file mode 100644 index 00000000..def81099 --- /dev/null +++ b/pkg/topology/topology_test.go @@ -0,0 +1,130 @@ +package topology_test + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/primevprotocol/mev-commit/pkg/p2p" + "github.com/primevprotocol/mev-commit/pkg/topology" +) + +type testAddressbook struct{} + +func (t *testAddressbook) GetPeerInfo(p p2p.Peer) ([]byte, error) { + return []byte("test"), nil +} + +type announcer struct { + mu sync.Mutex + broadcasts []p2p.Peer +} + +func (a *announcer) BroadcastPeers(_ context.Context, p p2p.Peer, peers []p2p.PeerInfo) error { + a.mu.Lock() + defer a.mu.Unlock() + + fmt.Println("broadcasting peers", p) + a.broadcasts = append(a.broadcasts, p) + + if len(peers) != 1 { + return errors.New("wrong number of peers") + } + + if string(peers[0].Underlay) != "test" { + return errors.New("wrong peer underlay") + } + + return nil +} + +func newTestLogger(w io.Writer) *slog.Logger { + testLogger := slog.NewTextHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + return slog.New(testLogger) +} + +func TestTopology(t *testing.T) { + t.Parallel() + + t.Run("ok", func(t *testing.T) { + topo := topology.New(&testAddressbook{}, newTestLogger(os.Stdout)) + announcer := &announcer{} + topo.SetAnnouncer(announcer) + + p1 := p2p.Peer{ + EthAddress: common.HexToAddress("0x1"), + Type: p2p.PeerTypeBuilder, + } + + s1 := p2p.Peer{ + EthAddress: common.HexToAddress("0x2"), + Type: p2p.PeerTypeSearcher, + } + + topo.Connected(p1) + + topo.Connected(s1) + + if len(announcer.broadcasts) != 1 { + t.Fatal("expected one broadcast") + } + + if announcer.broadcasts[0].EthAddress != s1.EthAddress { + t.Fatal("wrong peer") + } + + p2 := p2p.Peer{ + EthAddress: common.HexToAddress("0x3"), + Type: p2p.PeerTypeBuilder, + } + + topo.AddPeers(p2) + + for _, p := range []p2p.Peer{p1, s1, p2} { + if !topo.IsConnected(p.EthAddress) { + t.Fatal("peer not connected") + } + } + + peers := topo.GetPeers(topology.Query{Type: p2p.PeerTypeBuilder}) + if len(peers) != 2 { + t.Fatal("wrong number of peers") + } + + for _, p := range peers { + if p.Type != p2p.PeerTypeBuilder { + t.Fatal("wrong peer type") + } + if p.EthAddress != p1.EthAddress && p.EthAddress != p2.EthAddress { + t.Fatal("wrong peer") + } + } + + peers = topo.GetPeers(topology.Query{Type: p2p.PeerTypeSearcher}) + if len(peers) != 1 { + t.Fatal("wrong number of peers") + } + + if peers[0].Type != p2p.PeerTypeSearcher { + t.Fatal("wrong peer type") + } + + if peers[0].EthAddress != s1.EthAddress { + t.Fatal("wrong peer") + } + + topo.Disconnected(p1) + + if topo.IsConnected(p1.EthAddress) { + t.Fatal("peer still connected") + } + }) +}