From 03e0c857f29b3f4d5e083f9774ead49d3d215221 Mon Sep 17 00:00:00 2001 From: Mark van der Velden Date: Fri, 10 Apr 2020 10:54:55 +0200 Subject: [PATCH 1/2] Admitting you have a problem, is always the first step --- cmd/web/handlers.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index 3a09522..e26175b 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -56,6 +56,8 @@ func NewAutoCompleteHandler(logger logrus.FieldLogger, myFinder *finder.Finder) return } + // @todo this currently leaks information, even for the most obscure domains we might have learned. We should come + // up with a way to threshold that only the top performing domains are returned here. list, err := myFinder.GetMatchingPrefix(ctx, req.Domain, 10) if err != nil { log.WithError(err).Errorf("Error during lookup %s", err) From e87baf6e8b06565f8c65f2f3e37912f72f00996e Mon Sep 17 00:00:00 2001 From: Mark van der Velden Date: Mon, 13 Apr 2020 21:35:05 +0200 Subject: [PATCH 2/2] Adding a lower threshold setting for autocomplete --- cmd/web/config.toml | 6 + cmd/web/config/config.go | 6 +- cmd/web/erihttp/handlers/logger.go | 4 +- cmd/web/erihttp/server.go | 7 -- cmd/web/erihttp/types.go | 46 +++---- cmd/web/erihttp/util.go | 17 ++- cmd/web/erihttp/util_test.go | 87 +++++++++++-- cmd/web/handlers.go | 67 +++++++--- cmd/web/handlers_test.go | 194 +++++++++++++++++++++++++++++ cmd/web/hitlist/hitlist.go | 11 ++ cmd/web/hitlist/hitlist_test.go | 52 ++++++++ cmd/web/main.go | 2 +- cmd/web/util.go | 23 ++++ 13 files changed, 463 insertions(+), 59 deletions(-) create mode 100644 cmd/web/handlers_test.go diff --git a/cmd/web/config.toml b/cmd/web/config.toml index a725920..3234162 100644 --- a/cmd/web/config.toml +++ b/cmd/web/config.toml @@ -87,3 +87,9 @@ # amount of time a rate-limited request is allowed to wait for. Anything above this is aborted, to help protect # against connection draining. Requests are delayed parkedTTL = "100ms" + + [server.services.autocomplete] + + # Domains need at least (inclusive) this amount of recipients to be considered for the autocomplete API + # This is mostly to prevent leaking possibly private information + recipientThreshold = 1000 \ No newline at end of file diff --git a/cmd/web/config/config.go b/cmd/web/config/config.go index 04a6208..85285d6 100644 --- a/cmd/web/config/config.go +++ b/cmd/web/config/config.go @@ -53,7 +53,6 @@ type Config struct { } `toml:"log"` Hash struct { Key string `toml:"key"` - //Enable bool `toml:"enable"` } `toml:"hash"` Finder struct { UseBuckets bool `toml:"useBuckets" usage:"Buckets speedup matching, but assumes no mistakes are made at the start"` @@ -62,6 +61,11 @@ type Config struct { Resolver string `toml:"resolver" usage:"The resolver to use for DNS lookups"` SuggestValidator ValidatorType `toml:"suggest"` } `toml:"validator" flag:",inline" env:",inline"` + Services struct { + Autocomplete struct { + RecipientThreshold uint64 `usage:"Define the minimum amount of recipients a domain needs before allowed in the autocomplete"` + } `toml:"autocomplete"` + } `toml:"services"` Profiler struct { Enable bool `toml:"enable" default:"false"` Prefix string `toml:"prefix"` diff --git a/cmd/web/erihttp/handlers/logger.go b/cmd/web/erihttp/handlers/logger.go index cea747a..4799869 100644 --- a/cmd/web/erihttp/handlers/logger.go +++ b/cmd/web/erihttp/handlers/logger.go @@ -45,7 +45,9 @@ func WithRequestLogger(logger logrus.FieldLogger) HandlerWrapper { r = r.WithContext(context.WithValue(r.Context(), RequestID, rid)) - logger.Debug("Request start") + logger.WithFields(logrus.Fields{ + "content_length": r.ContentLength, + }).Debug("Request start") defer func(w *CustomResponseWriter) { diff --git a/cmd/web/erihttp/server.go b/cmd/web/erihttp/server.go index b76e592..2a08135 100644 --- a/cmd/web/erihttp/server.go +++ b/cmd/web/erihttp/server.go @@ -1,7 +1,6 @@ package erihttp import ( - "errors" "io" "log" "net/http" @@ -10,12 +9,6 @@ import ( "github.com/Dynom/ERI/cmd/web/config" ) -var ( - ErrMissingBody = errors.New("missing body") - ErrInvalidRequest = errors.New("request is invalid") - ErrBodyTooLarge = errors.New("request body too large") -) - func BuildHTTPServer(mux http.Handler, config config.Config, logWriter io.Writer, handlers ...func(h http.Handler) http.Handler) *http.Server { for _, h := range handlers { mux = h(mux) diff --git a/cmd/web/erihttp/types.go b/cmd/web/erihttp/types.go index 8912608..aa06848 100644 --- a/cmd/web/erihttp/types.go +++ b/cmd/web/erihttp/types.go @@ -1,22 +1,41 @@ package erihttp +import "errors" + +var ( + ErrMissingBody = errors.New("missing body") + ErrInvalidRequest = errors.New("request is invalid") + ErrBodyTooLarge = errors.New("request body too large") + ErrUnsupportedContentType = errors.New("unsupported content-type") +) + +type ERIResponse interface { + + // Hacking around Generics, like it's 1999. + PrepareResponse() +} + type AutoCompleteResponse struct { Suggestions []string `json:"suggestions"` + Error string `json:",omitempty"` } -type CheckResponse struct { - Valid bool `json:"valid"` - Reason string `json:"reason,omitempty"` - Alternative string `json:"alternative,omitempty"` +func (r *AutoCompleteResponse) PrepareResponse() { + if r.Suggestions == nil { + r.Suggestions = []string{} + } } type SuggestResponse struct { Alternatives []string `json:"alternatives"` MalformedSyntax bool `json:"malformed_syntax"` + Error string `json:",omitempty"` } -type ErrorResponse struct { - Error string `json:"error"` +func (r *SuggestResponse) PrepareResponse() { + if r.Alternatives == nil { + r.Alternatives = []string{} + } } type AutoCompleteRequest struct { @@ -26,18 +45,3 @@ type AutoCompleteRequest struct { type SuggestRequest struct { Email string `json:"email"` } - -type CheckRequest struct { - Email string `json:"email"` - Alternatives bool `json:"with_alternatives"` -} - -type LearnRequest struct { - Emails []ToLearn `json:"emails"` - Domains []ToLearn `json:"domains"` -} - -type ToLearn struct { - Value string `json:"value"` - Valid bool `json:"valid"` -} diff --git a/cmd/web/erihttp/util.go b/cmd/web/erihttp/util.go index 559003d..cbba4b4 100644 --- a/cmd/web/erihttp/util.go +++ b/cmd/web/erihttp/util.go @@ -6,15 +6,26 @@ import ( "net/http" ) +const ( + MaxBodySize int64 = 1 << 20 +) + func GetBodyFromHTTPRequest(r *http.Request) ([]byte, error) { var empty []byte - const maxSizePlusOne int64 = 1<<20 + 1 if r.Body == nil { return empty, ErrMissingBody } - b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxSizePlusOne)) + if r.ContentLength > MaxBodySize { + return empty, ErrBodyTooLarge + } + + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + return empty, ErrUnsupportedContentType + } + + b, err := ioutil.ReadAll(io.LimitReader(r.Body, MaxBodySize+1)) if err != nil { if err == io.EOF { return empty, ErrMissingBody @@ -22,7 +33,7 @@ func GetBodyFromHTTPRequest(r *http.Request) ([]byte, error) { return empty, ErrInvalidRequest } - if int64(len(b)) == maxSizePlusOne { + if int64(len(b)) > MaxBodySize { return empty, ErrBodyTooLarge } diff --git a/cmd/web/erihttp/util_test.go b/cmd/web/erihttp/util_test.go index bd1b750..6f45396 100644 --- a/cmd/web/erihttp/util_test.go +++ b/cmd/web/erihttp/util_test.go @@ -1,28 +1,97 @@ package erihttp import ( + "bytes" + "math" "net/http" + "net/http/httptest" "reflect" + "strings" "testing" ) func TestGetBodyFromHTTPRequest(t *testing.T) { - type args struct { - r *http.Request - } tests := []struct { name string - args args + req func(body []byte) *http.Request want []byte - wantErr bool + wantErr error }{ - // TODO: Add test cases. + { + wantErr: nil, + name: "All good", + req: func(body []byte) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + return req + }, + want: []byte("{}"), + }, + { + wantErr: ErrMissingBody, + name: "Nil body", + req: func(_ []byte) *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Content-Type", "application/json") + req.Body = nil + + return req + }, + want: nil, + }, + { + wantErr: ErrBodyTooLarge, + name: "Too lengthy/Content-Length", + req: func(_ []byte) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) + req.Header.Set("Content-Type", "application/json") + req.ContentLength = math.MaxInt64 + return req + }, + want: nil, + }, + { + wantErr: ErrBodyTooLarge, + name: "Too lengthy/Body", + req: func(_ []byte) *http.Request { + body := strings.Repeat("a", int(MaxBodySize+1)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.ContentLength = int64(len(body) - 1) + + return req + }, + want: nil, + }, + { + wantErr: ErrUnsupportedContentType, + name: "Content-Type/Missing", + req: func(_ []byte) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) + req.Header.Del("Content-Type") + return req + }, + want: nil, + }, + { + wantErr: ErrUnsupportedContentType, + name: "Content-Type/Wrong", + req: func(_ []byte) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("")) + req.Header.Set("Content-Type", "plain/text") + return req + }, + want: nil, + }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetBodyFromHTTPRequest(tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("GetBodyFromHTTPRequest() error = %v, wantErr %v", err, tt.wantErr) + req := tt.req(tt.want) + got, err := GetBodyFromHTTPRequest(req) + + if err != tt.wantErr { + t.Errorf("GetBodyFromHTTPRequest() error = %v, wantErr %q", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index e26175b..8ada1e7 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/Dynom/ERI/cmd/web/hitlist" "github.com/Dynom/ERI/validator" "github.com/Dynom/ERI/cmd/web/erihttp/handlers" @@ -19,56 +20,90 @@ import ( "github.com/sirupsen/logrus" ) -func NewAutoCompleteHandler(logger logrus.FieldLogger, myFinder *finder.Finder) http.HandlerFunc { +func NewAutoCompleteHandler(logger logrus.FieldLogger, myFinder *finder.Finder, hitList *hitlist.HitList, recipientThreshold uint64) http.HandlerFunc { + + const ( + maxSuggestions = 5 + + FailedRequestError = "Request failed, unable to parse request body. Expected JSON." + DomainLookupFailedError = "Request failed, unable to lookup by domain." + FailedResponseError = "Generating response failed." + ) log := logger.WithField("handler", "auto complete") return func(w http.ResponseWriter, r *http.Request) { var err error var req erihttp.AutoCompleteRequest - log := log.WithField(handlers.RequestID.String(), r.Context().Value(handlers.RequestID)) + log = log.WithField(handlers.RequestID.String(), r.Context().Value(handlers.RequestID)) defer deferClose(r.Body, log) body, err := erihttp.GetBodyFromHTTPRequest(r) if err != nil { - log.WithError(err).Errorf("Error handling request %s", err) + log.WithFields(logrus.Fields{ + "error": err, + "content_length": r.ContentLength, + }).Errorf("Error handling request %s", err) + w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte("Request failed")) + + // err is expected to be safe to expose to the client + writeErrorJSONResponse(logger, w, &erihttp.AutoCompleteResponse{Error: err.Error()}) return } err = json.Unmarshal(body, &req) if err != nil { log.WithError(err).Errorf("Error handling request body %s", err) + w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte("Request failed, unable to parse request body. Did you send JSON?")) + writeErrorJSONResponse(log, w, &erihttp.AutoCompleteResponse{Error: FailedRequestError}) return } ctx, cancel := context.WithTimeout(r.Context(), time.Millisecond*500) defer cancel() - if len(req.Domain) == 0 { - log.Error("Empty argument") + if req.Domain == "" { + log.Debug("Empty argument") w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte("Request failed, unable to lookup by domain")) + writeErrorJSONResponse(log, w, &erihttp.AutoCompleteResponse{Error: DomainLookupFailedError}) return } - // @todo this currently leaks information, even for the most obscure domains we might have learned. We should come - // up with a way to threshold that only the top performing domains are returned here. - list, err := myFinder.GetMatchingPrefix(ctx, req.Domain, 10) + list, err := myFinder.GetMatchingPrefix(ctx, req.Domain, maxSuggestions*2) if err != nil { - log.WithError(err).Errorf("Error during lookup %s", err) - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("Request failed, unable to lookup by domain")) + log.WithError(err).Warn("Error during lookup") + w.WriteHeader(http.StatusBadRequest) + writeErrorJSONResponse(log, w, &erihttp.AutoCompleteResponse{Error: DomainLookupFailedError}) return } + // Filter the list, so that we don't leak rarely used domain names. This might lead to privacy problems with personal + // domain names for example + var filteredList = make([]string, 0, maxSuggestions) + for _, domain := range list { + if ctx.Err() != nil { + w.WriteHeader(http.StatusBadRequest) + + // @todo Is this a safe error to "leak" ? + writeErrorJSONResponse(log, w, &erihttp.AutoCompleteResponse{Error: ctx.Err().Error()}) + return + } + + if cnt := hitList.GetRecipientCount(hitlist.Domain(domain)); cnt >= recipientThreshold { + filteredList = append(filteredList, domain) + if len(filteredList) >= maxSuggestions { + break + } + } + } + response, err := json.Marshal(erihttp.AutoCompleteResponse{ - Suggestions: list, + Suggestions: filteredList, }) + if err != nil { log.WithFields(logrus.Fields{ "response": response, @@ -76,7 +111,7 @@ func NewAutoCompleteHandler(logger logrus.FieldLogger, myFinder *finder.Finder) }).Error("Failed to marshal the response") w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("Unable to produce a response")) + writeErrorJSONResponse(log, w, &erihttp.AutoCompleteResponse{Error: FailedResponseError}) return } diff --git a/cmd/web/handlers_test.go b/cmd/web/handlers_test.go new file mode 100644 index 0000000..13d15c2 --- /dev/null +++ b/cmd/web/handlers_test.go @@ -0,0 +1,194 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "github.com/Dynom/ERI/cmd/web/erihttp" + "github.com/Dynom/ERI/cmd/web/hitlist" + "github.com/Dynom/ERI/cmd/web/services" + "github.com/Dynom/TySug/finder" + "github.com/sirupsen/logrus" + testLog "github.com/sirupsen/logrus/hooks/test" +) + +func TestNewAutoCompleteHandler(t *testing.T) { + logger, hook := testLog.NewNullLogger() + _ = hook + + refs := []string{ + "a", "b", "c", "d", + // Testing for > 5 matches + "exam", "example", "examination", "excalibur", "exceptional", "extra", + } + + myFinder, err := finder.New(refs, finder.WithAlgorithm(finder.NewJaroWinklerDefaults())) + if err != nil { + t.Errorf("Test setup failed, %s", err) + t.FailNow() + } + + hitList := hitlist.New(nil, time.Minute*1) + + validRequest := erihttp.AutoCompleteRequest{ + Domain: "ex", + } + + validRequestBody, err := json.Marshal(&validRequest) + if err != nil { + t.Errorf("Test setup failed, %s", err) + t.FailNow() + } + + emptyArgumentValidStructureRequest := erihttp.AutoCompleteRequest{} + emptyArgumentValidStructureRequestBody, err := json.Marshal(&emptyArgumentValidStructureRequest) + if err != nil { + t.Errorf("Test setup failed, %s", err) + t.FailNow() + } + + expiredContext, c := context.WithCancel(context.Background()) + c() + + type wants struct { + statusCode int + } + tests := []struct { + name string + requestBody io.Reader + ctx context.Context + want wants + }{ + { + name: "correct POST body", + requestBody: bytes.NewReader(validRequestBody), + ctx: context.Background(), + want: wants{ + statusCode: 200, + }, + }, + { + name: "malformed POST body", + requestBody: strings.NewReader("burp"), + ctx: context.Background(), + want: wants{ + statusCode: 400, + }, + }, + { + name: "nil POST body", + requestBody: nil, + ctx: context.Background(), + want: wants{ + statusCode: 400, + }, + }, + { + name: "Too large POST body", + requestBody: strings.NewReader(strings.Repeat(".", int(erihttp.MaxBodySize)+1)), + ctx: context.Background(), + want: wants{ + statusCode: 400, + }, + }, + { + name: "Bad JSON", + requestBody: bytes.NewReader(validRequestBody[0 : len(validRequestBody)-1]), // stripping off the '}' + ctx: context.Background(), + want: wants{ + statusCode: 400, + }, + }, + { + name: "Empty input", + requestBody: bytes.NewReader(emptyArgumentValidStructureRequestBody), + ctx: context.Background(), + want: wants{ + statusCode: 400, + }, + }, + { + name: "Bad context", + requestBody: bytes.NewReader(validRequestBody), + ctx: expiredContext, + want: wants{ + statusCode: 400, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hook.Reset() + handlerFunc := NewAutoCompleteHandler(logger, myFinder, hitList, 0) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", tt.requestBody) + req = req.WithContext(tt.ctx) + req.Header.Set("Content-Type", "application/json") + + handlerFunc.ServeHTTP(rec, req) + + b, _ := ioutil.ReadAll(rec.Result().Body) + t.Logf("Body: %s", b) + + if tt.want.statusCode != rec.Code { + t.Errorf("NewAutoCompleteHandler() = %+v, want %+v", rec, tt.want) + for _, l := range hook.AllEntries() { + t.Logf("Logs: %s", l.Message) + t.Logf("Meta: %v", l.Data) + } + } + }) + } +} + +func TestNewHealthHandler(t *testing.T) { + type args struct { + logger logrus.FieldLogger + } + tests := []struct { + name string + args args + want http.HandlerFunc + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewHealthHandler(tt.args.logger); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewHealthHandler() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewSuggestHandler(t *testing.T) { + type args struct { + logger logrus.FieldLogger + svc services.SuggestSvc + } + tests := []struct { + name string + args args + want http.HandlerFunc + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewSuggestHandler(tt.args.logger, tt.args.svc); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSuggestHandler() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/web/hitlist/hitlist.go b/cmd/web/hitlist/hitlist.go index 0ce0dbb..57c94b6 100644 --- a/cmd/web/hitlist/hitlist.go +++ b/cmd/web/hitlist/hitlist.go @@ -102,6 +102,17 @@ func (hl *HitList) GetValidAndUsageSortedDomains() []string { return domains } +// GetRecipientCount returns the amount of recipients known for a domain +func (hl *HitList) GetRecipientCount(d Domain) (amount uint64) { + hl.lock.RLock() + if hit, exists := hl.hits[d]; exists { + amount = uint64(len(hit.Recipients)) + } + hl.lock.RUnlock() + + return +} + // AddInternalParts adds values considered "safe". Typically you would only use this on provisioning HitList from a storage layer func (hl *HitList) AddInternalParts(domain Domain, recipient Recipient, vr validator.Result, duration time.Duration) error { diff --git a/cmd/web/hitlist/hitlist_test.go b/cmd/web/hitlist/hitlist_test.go index 776a76a..fb1fa98 100644 --- a/cmd/web/hitlist/hitlist_test.go +++ b/cmd/web/hitlist/hitlist_test.go @@ -726,3 +726,55 @@ func TestHitList_GetInternalTypes(t *testing.T) { }) } } + +func TestHitList_GetRecipientCount(t *testing.T) { + + tests := []struct { + name string + toAdd []types.EmailParts + domain Domain + wantAmount uint64 + }{ + { + name: "basics", + toAdd: []types.EmailParts{ + types.NewEmailFromParts("john", "example.org"), + }, + domain: "example.org", + wantAmount: 1, + }, + { + name: "multiple", + toAdd: []types.EmailParts{ + types.NewEmailFromParts("john", "example.org"), + types.NewEmailFromParts("jane", "example.org"), + }, + domain: "example.org", + wantAmount: 2, + }, + { + name: "no domain match", + toAdd: []types.EmailParts{ + types.NewEmailFromParts("john", "example.org"), + }, + domain: "a", + wantAmount: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hl := New(mockHasher{}, time.Second*1) + for _, a := range tt.toAdd { + err := hl.Add(a, validator.Result{}) + if err != nil { + t.Errorf("Preparing test failed %s", err) + t.FailNow() + } + } + + if gotAmount := hl.GetRecipientCount(tt.domain); gotAmount != tt.wantAmount { + t.Errorf("GetRecipientCount() = %v, want %v", gotAmount, tt.wantAmount) + } + }) + } +} diff --git a/cmd/web/main.go b/cmd/web/main.go index bf7a3bb..74a2108 100644 --- a/cmd/web/main.go +++ b/cmd/web/main.go @@ -119,7 +119,7 @@ func main() { registerHealthHandler(mux, logger) mux.HandleFunc("/suggest", NewSuggestHandler(logger, suggestSvc)) - mux.HandleFunc("/autocomplete", NewAutoCompleteHandler(logger, myFinder)) + mux.HandleFunc("/autocomplete", NewAutoCompleteHandler(logger, myFinder, hitList, conf.Server.Services.Autocomplete.RecipientThreshold)) schema, err := NewGraphQLSchema(&suggestSvc) if err != nil { diff --git a/cmd/web/util.go b/cmd/web/util.go index abb5624..87af3a7 100644 --- a/cmd/web/util.go +++ b/cmd/web/util.go @@ -3,6 +3,7 @@ package main import ( "context" "database/sql" + "encoding/json" "fmt" "io" "net" @@ -13,6 +14,7 @@ import ( "time" gcppubsub "cloud.google.com/go/pubsub" + "github.com/Dynom/ERI/cmd/web/erihttp" "github.com/Dynom/ERI/cmd/web/hitlist" "github.com/Dynom/ERI/cmd/web/persister" "github.com/Dynom/ERI/cmd/web/pubsub" @@ -270,3 +272,24 @@ func createPubSubSvc(conf config.Config, logger logrus.FieldLogger, rt *runtimer return pubSubSvc, nil } + +// writeErrorJSONResponse Sets the error on a response and writes it with the corresponding Content-Type +func writeErrorJSONResponse(logger logrus.FieldLogger, w http.ResponseWriter, responseType erihttp.ERIResponse) { + + responseType.PrepareResponse() + response, err := json.Marshal(responseType) + if err != nil { + logger.WithError(err).Error("Failed to marshal the response") + response = []byte(`{"error":""}`) + } + + w.Header().Set("Content-Type", "application/json") + c, err := w.Write(response) + if err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + "bytes_written": c, + }).Error("Failed to write response") + return + } +}