diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2023-11-02 19:06:05 +0100 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2023-11-03 13:05:13 +0000 |
| commit | 500bfdc41735bc8d617cbfd4f1ab6b5980c8f1e5 (patch) | |
| tree | 531185e72a0efd70cc4c0a7566712ec9c8807f44 | |
| parent | 7e9533b74d64ae2782ac03bfdc01f7f84840be9e (diff) | |
dashboard: throttle incoming requests
To ensure service stability, let's rate limit incoming requests to our
web endpoints.
| -rw-r--r-- | dashboard/app/cache.go | 67 | ||||
| -rw-r--r-- | dashboard/app/config.go | 20 | ||||
| -rw-r--r-- | dashboard/app/handler.go | 39 | ||||
| -rw-r--r-- | dashboard/app/main_test.go | 51 | ||||
| -rw-r--r-- | dashboard/app/util_test.go | 1 |
5 files changed, 178 insertions, 0 deletions
diff --git a/dashboard/app/cache.go b/dashboard/app/cache.go index f54f95c61..14c8169c8 100644 --- a/dashboard/app/cache.go +++ b/dashboard/app/cache.go @@ -7,8 +7,10 @@ import ( "encoding/json" "fmt" "net/http" + "sort" "time" + "github.com/google/syzkaller/pkg/hash" "github.com/google/syzkaller/pkg/image" "golang.org/x/net/context" "google.golang.org/appengine/v2" @@ -254,3 +256,68 @@ func cachedObjectList[T any](c context.Context, key string, period time.Duration } return obj, nil } + +type RequesterInfo struct { + Requests []time.Time +} + +func (ri *RequesterInfo) Record(now time.Time, cfg ThrottleConfig) bool { + var newRequests []time.Time + for _, req := range ri.Requests { + if now.Sub(req) >= cfg.Window { + continue + } + newRequests = append(newRequests, req) + } + newRequests = append(newRequests, now) + sort.Slice(ri.Requests, func(i, j int) bool { return ri.Requests[i].Before(ri.Requests[j]) }) + // Don't store more than needed. + if len(newRequests) > cfg.Limit+1 { + newRequests = newRequests[len(newRequests)-(cfg.Limit+1):] + } + ri.Requests = newRequests + // Check that we satisfy the conditions. + return len(newRequests) <= cfg.Limit +} + +func ThrottleRequest(c context.Context, requesterID string) (bool, error) { + cfg := getConfig(c).Throttle + if cfg.Empty() || requesterID == "" { + // No sense to query memcached. + return true, nil + } + key := fmt.Sprintf("requester-%s", hash.String([]byte(requesterID))) + const attempts = 5 + for i := 0; i < attempts; i++ { + var obj RequesterInfo + item, err := memcache.Gob.Get(c, key, &obj) + if err == memcache.ErrCacheMiss { + ok := obj.Record(timeNow(c), cfg) + err = memcache.Gob.Add(c, &memcache.Item{ + Key: key, + Object: obj, + Expiration: cfg.Window, + }) + if err == memcache.ErrNotStored { + // Conflict with another instance. Retry. + continue + } + return ok, err + } else if err != nil { + return false, err + } + // Update the existing object. + ok := obj.Record(timeNow(c), cfg) + item.Expiration = cfg.Window + item.Object = obj + err = memcache.Gob.CompareAndSwap(c, item) + if err == memcache.ErrCASConflict { + // Update conflict. Retry. + continue + } else if err != nil { + return false, err + } + return ok, nil + } + return false, fmt.Errorf("all attempts to record request failed") +} diff --git a/dashboard/app/config.go b/dashboard/app/config.go index 7a901a181..0fd47d2d0 100644 --- a/dashboard/app/config.go +++ b/dashboard/app/config.go @@ -62,6 +62,8 @@ type GlobalConfig struct { // Emails received via the addresses below will be attributed to the corresponding // kind of Discussion. DiscussionEmails []DiscussionEmailConfig + // Incoming request throttling. + Throttle ThrottleConfig } // Per-namespace config. @@ -306,6 +308,18 @@ type KcidbConfig struct { Credentials []byte } +// ThrottleConfig determines how many requests a single client can make in a period of time. +type ThrottleConfig struct { + // The time period to be considered. + Window time.Duration + // No more than Limit requests are allowed within the time window. + Limit int +} + +func (t ThrottleConfig) Empty() bool { + return t.Window == 0 || t.Limit == 0 +} + var ( namespaceNameRe = regexp.MustCompile("^[a-zA-Z0-9-_.]{4,32}$") clientNameRe = regexp.MustCompile("^[a-zA-Z0-9-_.]{4,100}$") @@ -409,6 +423,12 @@ func checkConfig(cfg *GlobalConfig) { for i := range cfg.EmailBlocklist { cfg.EmailBlocklist[i] = email.CanonicalEmail(cfg.EmailBlocklist[i]) } + if cfg.Throttle.Limit < 0 { + panic("throttle limit cannot be negative") + } + if (cfg.Throttle.Limit != 0) != (cfg.Throttle.Window != 0) { + panic("throttling window and limit must be both set") + } namespaces := make(map[string]bool) clientNames := make(map[string]bool) checkClients(clientNames, cfg.Clients) diff --git a/dashboard/app/handler.go b/dashboard/app/handler.go index 98827f8ca..7162bd85c 100644 --- a/dashboard/app/handler.go +++ b/dashboard/app/handler.go @@ -33,6 +33,9 @@ func handleContext(fn contextHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c := appengine.NewContext(r) c = context.WithValue(c, ¤tURLKey, r.URL.RequestURI()) + if !throttleRequest(c, w, r) { + return + } if err := fn(c, w, r); err != nil { hdr := commonHeaderRaw(c, r) data := &struct { @@ -75,6 +78,42 @@ func handleContext(fn contextHandler) http.Handler { }) } +func throttleRequest(c context.Context, w http.ResponseWriter, r *http.Request) bool { + // AppEngine removes all App Engine-specific headers, which include + // X-Appengine-User-IP and X-Forwarded-For. + // https://cloud.google.com/appengine/docs/standard/reference/request-headers?tab=python#removed_headers + ip := r.Header.Get("X-Appengine-User-IP") + if ip == "" { + ip = r.Header.Get("X-Forwarded-For") + ip, _, _ = strings.Cut(ip, ",") // X-Forwarded-For is a comma-delimited list. + ip = strings.TrimSpace(ip) + } + cron := r.Header.Get("X-Appengine-Cron") != "" + if ip == "" || cron { + log.Infof(c, "cannot throttle request from %q, cron %t", ip, cron) + return true + } + accept, err := ThrottleRequest(c, ip) + if err != nil { + log.Errorf(c, "failed to throttle: %v", err) + } + log.Infof(c, "throttling for %q: %t", ip, accept) + if !accept { + http.Error(w, throttlingErrorMessage(c), http.StatusTooManyRequests) + return false + } + return true +} + +func throttlingErrorMessage(c context.Context) string { + ret := "429 Too Many Requests" + email := getConfig(c).ContactEmail + if email == "" { + return ret + } + return fmt.Sprintf("%s\nPlease contact us at %s if you need access to our data.", ret, email) +} + var currentURLKey = "the URL of the HTTP request in context" func getCurrentURL(c context.Context) string { diff --git a/dashboard/app/main_test.go b/dashboard/app/main_test.go index 84cf888dc..bb9895f90 100644 --- a/dashboard/app/main_test.go +++ b/dashboard/app/main_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/syzkaller/dashboard/dashapi" "github.com/stretchr/testify/assert" + "golang.org/x/net/context" ) func TestOnlyManagerFilter(t *testing.T) { @@ -354,3 +355,53 @@ func TestSubsystemsPageRedirect(t *testing.T) { c.expectEQ(httpErr.Code, http.StatusMovedPermanently) c.expectEQ(httpErr.Headers["Location"], []string{"/access-public-email/s/subsystemA"}) } + +func TestNoThrottle(t *testing.T) { + c := NewCtx(t) + defer c.Close() + + assert.True(t, c.config().Throttle.Empty()) + for i := 0; i < 10; i++ { + c.advanceTime(time.Millisecond) + _, err := c.AuthGET(AccessPublic, "/access-public-email") + c.expectOK(err) + } +} + +func TestThrottle(t *testing.T) { + c := NewCtx(t) + defer c.Close() + + c.transformContext = func(c context.Context) context.Context { + newConfig := *getConfig(c) + newConfig.Throttle = ThrottleConfig{ + Window: 10 * time.Second, + Limit: 10, + } + return contextWithConfig(c, &newConfig) + } + + // Adhere to the limit. + for i := 0; i < 15; i++ { + c.advanceTime(time.Second) + _, err := c.AuthGET(AccessPublic, "/access-public-email") + c.expectOK(err) + } + + // Break the limit. + c.advanceTime(time.Millisecond) + _, err := c.AuthGET(AccessPublic, "/access-public-email") + var httpErr *HTTPError + c.expectTrue(errors.As(err, &httpErr)) + c.expectEQ(httpErr.Code, http.StatusTooManyRequests) + + // Still too frequent requests. + c.advanceTime(time.Millisecond) + _, err = c.AuthGET(AccessPublic, "/access-public-email") + c.expectTrue(err != nil) + + // Wait a bit. + c.advanceTime(3 * time.Second) + _, err = c.AuthGET(AccessPublic, "/access-public-email") + c.expectOK(err) +} diff --git a/dashboard/app/util_test.go b/dashboard/app/util_test.go index ba4a953d3..41b2032f0 100644 --- a/dashboard/app/util_test.go +++ b/dashboard/app/util_test.go @@ -316,6 +316,7 @@ func (c *Ctx) httpRequest(method, url, body string, access AccessLevel) (*httpte if err != nil { c.t.Fatal(err) } + r.Header.Add("X-Appengine-User-IP", "127.0.0.1") r = registerRequest(r, c) r = r.WithContext(c.transformContext(r.Context())) if access == AccessAdmin || access == AccessUser { |
