// Copyright 2017 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // The test uses aetest package that starts local dev_appserver and handles all requests locally: // https://cloud.google.com/appengine/docs/standard/go/tools/localunittesting/reference package main import ( "bytes" "context" "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "os" "os/exec" "path/filepath" "reflect" "runtime" "slices" "strings" "sync" "sync/atomic" "testing" "time" "cloud.google.com/go/spanner/admin/database/apiv1" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "github.com/google/go-cmp/cmp" "github.com/google/syzkaller/dashboard/api" "github.com/google/syzkaller/dashboard/app/aidb" "github.com/google/syzkaller/dashboard/dashapi" "github.com/google/syzkaller/pkg/coveragedb/spannerclient" "github.com/google/syzkaller/pkg/covermerger" "github.com/google/syzkaller/pkg/email" "github.com/google/syzkaller/pkg/subsystem" spannertest "github.com/google/syzkaller/syz-cluster/pkg/db" "github.com/stretchr/testify/require" "google.golang.org/appengine/v2" "google.golang.org/appengine/v2/aetest" db "google.golang.org/appengine/v2/datastore" "google.golang.org/appengine/v2/log" aemail "google.golang.org/appengine/v2/mail" ) type Ctx struct { t *testing.T inst aetest.Instance ctx context.Context mockedTime time.Time emailSink chan *aemail.Message transformContext func(context.Context) context.Context globalClient *apiClient agentClient *apiClient client *apiClient client2 *apiClient publicClient *apiClient aiClient *apiClient checkAI bool } var skipDevAppserverTests = func() bool { _, err := exec.LookPath("dev_appserver.py") // Don't silently skip tests on CI, we should have gcloud sdk installed there. return err != nil && os.Getenv("SYZ_ENV") == "" || os.Getenv("SYZ_SKIP_DEV_APPSERVER_TESTS") != "" }() func NewCtx(t *testing.T) *Ctx { return newCtx(t, "") } func newCtx(t *testing.T, appID string) *Ctx { if skipDevAppserverTests { t.Skip("skipping test (no dev_appserver.py)") } t.Parallel() inst, err := aetest.NewInstance(&aetest.Options{ AppID: appID, StartupTimeout: 120 * time.Second, // Without this option datastore queries return data with slight delay, // which fails reporting tests. StronglyConsistentDatastore: true, }) if err != nil { t.Fatal(err) } r, err := inst.NewRequest("GET", "", nil) if err != nil { t.Fatal(err) } ctx := &Ctx{ t: t, inst: inst, mockedTime: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), emailSink: make(chan *aemail.Message, 100), transformContext: func(ctx context.Context) context.Context { return ctx }, checkAI: appID != "", } ctx.globalClient = ctx.makeClient(reportingClient, reportingKey, true) ctx.agentClient = ctx.makeClient(agentClient, agentKey, true) ctx.client = ctx.makeClient(client1, password1, true) ctx.client2 = ctx.makeClient(client2, password2, true) ctx.publicClient = ctx.makeClient(clientPublicEmail, keyPublicEmail, true) ctx.aiClient = ctx.makeClient(clientAI, keyAI, true) ctx.ctx = registerRequest(r, ctx).Context() return ctx } var appIDSeq = uint32(0) func NewSpannerCtx(t *testing.T) *Ctx { ddlStatements, err := loadUpDDLStatements() if err != nil { t.Fatal(err) } // The code uses AppID as the spanner database URI project. // So to give each test a private isolated instance of the spanner database, // we give each test that uses spanner an unique AppID. appID := fmt.Sprintf("testapp-%v", atomic.AddUint32(&appIDSeq, 1)) uri := fmt.Sprintf("projects/%s/instances/%v/databases/%v", appID, aidb.Instance, aidb.Database) spannertest.NewTestDB(t, uri, ddlStatements) return newCtx(t, appID) } func executeSpannerDDL(ctx context.Context, statements []string) error { dbAdmin, err := database.NewDatabaseAdminClient(ctx) if err != nil { return fmt.Errorf("failed NewDatabaseAdminClient: %w", err) } defer dbAdmin.Close() dbOp, err := dbAdmin.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ Database: fmt.Sprintf("projects/%s/instances/%v/databases/%v", appengine.AppID(ctx), aidb.Instance, aidb.Database), Statements: statements, }) if err != nil { return fmt.Errorf("failed UpdateDatabaseDdl: %w", err) } if err := dbOp.Wait(ctx); err != nil { return fmt.Errorf("failed UpdateDatabaseDdl: %w", err) } return nil } func loadUpDDLStatements() ([]string, error) { return loadDDLStatements("*.up.sql", 1) } func loadDownDDLStatements() ([]string, error) { return loadDDLStatements("*.down.sql", -1) } func loadDDLStatements(wildcard string, sortOrder int) ([]string, error) { files, err := filepath.Glob(filepath.Join("aidb", "migrations", wildcard)) if err != nil { return nil, err } if len(files) == 0 { return nil, fmt.Errorf("loadDDLStatements: wildcard did not match any files: %q", wildcard) } // We prefix DDL file names with sequence numbers. slices.SortFunc(files, func(a, b string) int { return strings.Compare(a, b) * sortOrder }) var all []string for _, file := range files { data, err := os.ReadFile(file) if err != nil { return nil, err } // We need individual statements. Assume semicolon is not used in other places than statements end. statements := strings.Split(string(data), ";") statements = statements[:len(statements)-1] all = append(all, statements...) } return all, nil } func (ctx *Ctx) config() *GlobalConfig { return getConfig(ctx.ctx) } func (ctx *Ctx) expectOK(err error) { if err != nil { ctx.t.Helper() ctx.t.Fatalf("expected OK, got error: %v", err) } } func (ctx *Ctx) expectFail(msg string, err error) { ctx.t.Helper() if err == nil { ctx.t.Fatalf("expected to fail, but it does not") } if !strings.Contains(err.Error(), msg) { ctx.t.Fatalf("expected to fail with %q, but failed with %q", msg, err) } } func expectFailureStatus(t *testing.T, err error, code int) { t.Helper() if err == nil { t.Fatalf("expected to fail as %d, but it does not", code) } var httpErr *HTTPError if !errors.As(err, &httpErr) || httpErr.Code != code { t.Fatalf("expected to fail as %d, but it failed as %v", code, err) } } func (ctx *Ctx) expectBadReqest(err error) { expectFailureStatus(ctx.t, err, http.StatusBadRequest) } func (ctx *Ctx) expectEQ(got, want any) { if diff := cmp.Diff(got, want); diff != "" { ctx.t.Helper() ctx.t.Fatal(diff) } } func (ctx *Ctx) expectNE(got, want any) { if reflect.DeepEqual(got, want) { ctx.t.Helper() ctx.t.Fatalf("equal: %#v", got) } } func (ctx *Ctx) expectTrue(v bool) { if !v { ctx.t.Helper() ctx.t.Fatal("failed") } } func caller(skip int) string { pcs := make([]uintptr, 10) n := runtime.Callers(skip+3, pcs) pcs = pcs[:n] frames := runtime.CallersFrames(pcs) stack := "" for { frame, more := frames.Next() if strings.HasPrefix(frame.Function, "testing.") { break } stack = fmt.Sprintf("%v:%v\n", filepath.Base(frame.File), frame.Line) + stack if !more { break } } if stack != "" { stack = stack[:len(stack)-1] } return stack } func (ctx *Ctx) Close() { defer ctx.inst.Close() // transformContext may substitute config. if ctx.transformContext == nil && !ctx.t.Failed() { // To avoid per-day reporting limits for left-over emails. ctx.advanceTime(25 * time.Hour) // Ensure that we can render main page and all bugs in the final test state. _, err := ctx.GET("/test1") ctx.expectOK(err) _, err = ctx.GET("/test2") ctx.expectOK(err) _, err = ctx.GET("/test1/fixed") ctx.expectOK(err) _, err = ctx.GET("/test2/fixed") ctx.expectOK(err) _, err = ctx.GET("/admin") ctx.expectOK(err) var bugs []*Bug keys, err := db.NewQuery("Bug").GetAll(ctx.ctx, &bugs) if err != nil { ctx.t.Errorf("ERROR: failed to query bugs: %v", err) } for _, key := range keys { _, err = ctx.GET(fmt.Sprintf("/bug?id=%v", key.StringID())) ctx.expectOK(err) } // No pending emails (tests need to consume them). _, err = ctx.GET("/cron/email_poll") ctx.expectOK(err) for len(ctx.emailSink) != 0 { ctx.t.Errorf("ERROR: leftover email: %v", (<-ctx.emailSink).Body) } // No pending external reports (tests need to consume them). resp, _ := ctx.globalClient.ReportingPollBugs("test") for _, rep := range resp.Reports { ctx.t.Errorf("ERROR: leftover external report:\n%#v", rep) } if ctx.checkAI { _, err = ctx.GET("/ains/ai/") ctx.expectOK(err) jobs, err := aidb.LoadNamespaceJobs(ctx.ctx, "ains") ctx.expectOK(err) for _, job := range jobs { _, err = ctx.GET(fmt.Sprintf("/ai_job?id=%v", job.ID)) ctx.expectOK(err) } } } aidb.CloseClient(ctx.ctx) unregisterContext(ctx) validateGlobalConfig() } func (ctx *Ctx) advanceTime(d time.Duration) { ctx.mockedTime = ctx.mockedTime.Add(d) } func (ctx *Ctx) setSubsystems(ns string, list []*subsystem.Subsystem, rev int) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg ret.Subsystems.Service = subsystem.MustMakeService(list, rev) return &ret }) return contextWithConfig(ctx, newConfig) } } func (ctx *Ctx) setCoverageMocks(ns string, dbClientMock spannerclient.SpannerClient, fileProvMock covermerger.FileVersProvider) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg ret.Coverage = &CoverageConfig{WebGitURI: "test-git"} return &ret }) ctxWithSpanner := setCoverageDBClient(ctx, dbClientMock) ctxWithSpannerAndFileProvider := setWebGit(ctxWithSpanner, fileProvMock) return contextWithConfig(ctxWithSpannerAndFileProvider, newConfig) } } func (ctx *Ctx) setKernelRepos(ns string, list []KernelRepo) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg ret.Repos = list return &ret }) return contextWithConfig(ctx, newConfig) } } func (ctx *Ctx) setNoObsoletions() { ctx.transformContext = func(ctx context.Context) context.Context { return contextWithNoObsoletions(ctx) } } func (ctx *Ctx) updateReporting(ns, name string, f func(Reporting) Reporting) { ctx.transformContext = func(ctx context.Context) context.Context { return contextWithConfig(ctx, replaceReporting(ctx, ns, name, f)) } } func (ctx *Ctx) decommissionManager(ns, oldManager, newManager string) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceManagerConfig(ctx, ns, oldManager, func(cfg ConfigManager) ConfigManager { cfg.Decommissioned = true cfg.DelegatedTo = newManager return cfg }) return contextWithConfig(ctx, newConfig) } } func (ctx *Ctx) decommission(ns string) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg ret.Decommissioned = true return &ret }) return contextWithConfig(ctx, newConfig) } } func (ctx *Ctx) setWaitForRepro(ns string, d time.Duration) { ctx.transformContext = func(ctx context.Context) context.Context { newConfig := replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg ret.WaitForRepro = d return &ret }) return contextWithConfig(ctx, newConfig) } } // GET sends admin-authorized HTTP GET request to the app. func (ctx *Ctx) GET(url string) ([]byte, error) { return ctx.AuthGET(AccessAdmin, url) } // AuthGET sends HTTP GET request to the app with the specified authorization. func (ctx *Ctx) AuthGET(access AccessLevel, url string) ([]byte, error) { w, err := ctx.httpRequest("GET", url, "", "", access) if err != nil { return nil, err } return w.Body.Bytes(), nil } // POST sends admin-authorized HTTP POST requestd to the app. func (ctx *Ctx) POST(url, body string) ([]byte, error) { w, err := ctx.httpRequest("POST", url, body, "", AccessAdmin) if err != nil { return nil, err } return w.Body.Bytes(), nil } // POST sends an admin-authorized HTTP POST form to the app. func (ctx *Ctx) POSTForm(url string, form url.Values) ([]byte, error) { w, err := ctx.httpRequest("POST", url, form.Encode(), "application/x-www-form-urlencoded", AccessAdmin) if err != nil { return nil, err } return w.Body.Bytes(), nil } // ContentType returns the response Content-Type header value. func (ctx *Ctx) ContentType(url string) (string, error) { w, err := ctx.httpRequest("HEAD", url, "", "", AccessAdmin) if err != nil { return "", err } values := w.Header()["Content-Type"] if len(values) == 0 { return "", fmt.Errorf("no Content-Type") } return values[0], nil } func (ctx *Ctx) httpRequest(method, url, body, contentType string, access AccessLevel) (*httptest.ResponseRecorder, error) { ctx.t.Logf("%v: %v", method, url) r, err := ctx.inst.NewRequest(method, url, strings.NewReader(body)) if err != nil { ctx.t.Fatal(err) } r.Header.Add("X-Appengine-User-IP", "127.0.0.1") if contentType != "" { r.Header.Add("Content-Type", contentType) } r = registerRequest(r, ctx) r = r.WithContext(ctx.transformContext(r.Context())) switch access { case AccessAdmin: aetest.Login(makeUser(AuthorizedAdmin), r) case AccessUser: aetest.Login(makeUser(AuthorizedUser), r) } w := httptest.NewRecorder() http.DefaultServeMux.ServeHTTP(w, r) ctx.t.Logf("REPLY: %v", w.Code) if w.Code != http.StatusOK { return nil, &HTTPError{w.Code, w.Body.String(), w.Result().Header} } return w, nil } type HTTPError struct { Code int Body string Headers http.Header } func (err *HTTPError) Error() string { return fmt.Sprintf("%v: %v", err.Code, err.Body) } func (ctx *Ctx) loadBug(extID string) (*Bug, *Crash, *Build) { bug, _, err := findBugByReportingID(ctx.ctx, extID) if err != nil { ctx.t.Fatalf("failed to load bug: %v", err) } return ctx.loadBugInfo(bug) } func (ctx *Ctx) loadBugByHash(hash string) (*Bug, *Crash, *Build) { bug := new(Bug) bugKey := db.NewKey(ctx.ctx, "Bug", hash, 0, nil) ctx.expectOK(db.Get(ctx.ctx, bugKey, bug)) return ctx.loadBugInfo(bug) } func (ctx *Ctx) loadBugInfo(bug *Bug) (*Bug, *Crash, *Build) { crash, _, err := findCrashForBug(ctx.ctx, bug) if err != nil { ctx.t.Fatalf("failed to load crash: %v", err) } build := ctx.loadBuild(bug.Namespace, crash.BuildID) return bug, crash, build } func (ctx *Ctx) loadJob(extID string) (*Job, *Build, *Crash) { jobKey, err := jobID2Key(ctx.ctx, extID) if err != nil { ctx.t.Fatalf("failed to create job key: %v", err) } job := new(Job) if err := db.Get(ctx.ctx, jobKey, job); err != nil { ctx.t.Fatalf("failed to get job %v: %v", extID, err) } build := ctx.loadBuild(job.Namespace, job.BuildID) crash := new(Crash) crashKey := db.NewKey(ctx.ctx, "Crash", "", job.CrashID, jobKey.Parent()) if err := db.Get(ctx.ctx, crashKey, crash); err != nil { ctx.t.Fatalf("failed to load crash for job: %v", err) } return job, build, crash } func (ctx *Ctx) loadBuild(ns, id string) *Build { build, err := loadBuild(ctx.ctx, ns, id) ctx.expectOK(err) return build } func (ctx *Ctx) loadManager(ns, name string) (*Manager, *Build) { mgr, err := loadManager(ctx.ctx, ns, name) ctx.expectOK(err) build := ctx.loadBuild(ns, mgr.CurrentBuild) return mgr, build } func (ctx *Ctx) loadSingleBug() (*Bug, *db.Key) { var bugs []*Bug keys, err := db.NewQuery("Bug").GetAll(ctx.ctx, &bugs) ctx.expectEQ(err, nil) ctx.expectEQ(len(bugs), 1) return bugs[0], keys[0] } func (ctx *Ctx) loadSingleJob() (*Job, *db.Key) { var jobs []*Job keys, err := db.NewQuery("Job").GetAll(ctx.ctx, &jobs) ctx.expectEQ(err, nil) ctx.expectEQ(len(jobs), 1) return jobs[0], keys[0] } func (ctx *Ctx) checkURLContents(url string, want []byte) { ctx.t.Helper() got, err := ctx.AuthGET(AccessAdmin, url) if err != nil { ctx.t.Fatalf("%v request failed: %v", url, err) } if !bytes.Equal(got, want) { ctx.t.Fatalf("url %v: got:\n%s\nwant:\n%s\n", url, got, want) } } func (ctx *Ctx) pollEmailBug() *aemail.Message { _, err := ctx.GET("/cron/email_poll") ctx.expectOK(err) if len(ctx.emailSink) == 0 { ctx.t.Helper() ctx.t.Fatal("got no emails") } return <-ctx.emailSink } func (ctx *Ctx) pollEmailExtID() string { ctx.t.Helper() _, extBugID := ctx.pollEmailAndExtID() return extBugID } func (ctx *Ctx) pollEmailAndExtID() (string, string) { ctx.t.Helper() msg := ctx.pollEmailBug() _, extBugID, err := email.RemoveAddrContext(msg.Sender) if err != nil { ctx.t.Fatalf("failed to remove addr context: %v", err) } return msg.Sender, extBugID } func (ctx *Ctx) expectNoEmail() { _, err := ctx.GET("/cron/email_poll") ctx.expectOK(err) if len(ctx.emailSink) != 0 { msg := <-ctx.emailSink ctx.t.Helper() ctx.t.Fatalf("got unexpected email: %v\n%s", msg.Subject, msg.Body) } } type apiClient struct { *Ctx *dashapi.Dashboard } func (ctx *Ctx) makeClient(client, key string, failOnErrors bool) *apiClient { logger := func(msg string, args ...any) { ctx.t.Logf("%v: "+msg, append([]any{caller(3)}, args...)...) } errorHandler := func(err error) { if failOnErrors { ctx.t.Fatalf("\n%v: %v", caller(2), err) } } dash, err := dashapi.NewCustom(client, "", key, ctx.inst.NewRequest, ctx.httpDoer(), logger, errorHandler) if err != nil { panic(fmt.Sprintf("Impossible error: %v", err)) } return &apiClient{ Ctx: ctx, Dashboard: dash, } } func (ctx *Ctx) makeAPIClient() *api.Client { return api.NewTestClient(ctx.inst.NewRequest, ctx.httpDoer()) } func (ctx *Ctx) httpDoer() func(*http.Request) (*http.Response, error) { return func(r *http.Request) (*http.Response, error) { r = registerRequest(r, ctx) r = r.WithContext(ctx.transformContext(r.Context())) w := httptest.NewRecorder() http.DefaultServeMux.ServeHTTP(w, r) res := &http.Response{ StatusCode: w.Code, Status: http.StatusText(w.Code), Body: io.NopCloser(w.Result().Body), } return res, nil } } func (client *apiClient) pollBugs(expect int) []*dashapi.BugReport { resp, _ := client.ReportingPollBugs("test") if len(resp.Reports) != expect { client.t.Helper() client.t.Fatalf("want %v reports, got %v", expect, len(resp.Reports)) } for _, rep := range resp.Reports { reproLevel := dashapi.ReproLevelNone if len(rep.ReproC) != 0 { reproLevel = dashapi.ReproLevelC } else if len(rep.ReproSyz) != 0 { reproLevel = dashapi.ReproLevelSyz } reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ ID: rep.ID, JobID: rep.JobID, Status: dashapi.BugStatusOpen, ReproLevel: reproLevel, CrashID: rep.CrashID, }) client.expectEQ(reply.Error, false) client.expectEQ(reply.OK, true) } return resp.Reports } func (client *apiClient) pollBug() *dashapi.BugReport { return client.pollBugs(1)[0] } func (client *apiClient) pollNotifs(expect int) []*dashapi.BugNotification { resp, _ := client.ReportingPollNotifications("test") if len(resp.Notifications) != expect { client.t.Helper() client.t.Fatalf("want %v notifs, got %v", expect, len(resp.Notifications)) } return resp.Notifications } func (client *apiClient) updateBug(extID string, status dashapi.BugStatus, dup string) { reply, _ := client.ReportingUpdate(&dashapi.BugUpdate{ ID: extID, Status: status, DupOf: dup, }) client.expectTrue(reply.OK) } func (client *apiClient) pollSpecificJobs(manager string, jobs dashapi.ManagerJobs) *dashapi.JobPollResp { req := &dashapi.JobPollReq{ Managers: map[string]dashapi.ManagerJobs{ manager: jobs, }, } resp, err := client.JobPoll(req) client.expectOK(err) return resp } func (client *apiClient) pollJobs(manager string) *dashapi.JobPollResp { return client.pollSpecificJobs(manager, dashapi.ManagerJobs{ TestPatches: true, BisectCause: true, BisectFix: true, }) } func (client *apiClient) pollAndFailBisectJob(manager string) { resp := client.pollJobs(manager) client.expectNE(resp.ID, "") client.expectEQ(resp.Type, dashapi.JobBisectCause) done := &dashapi.JobDoneReq{ ID: resp.ID, Error: []byte("pollAndFailBisectJob"), } client.expectOK(client.JobDone(done)) } type ( EmailOptMessageID int EmailOptSubject string EmailOptFrom string EmailOptOrigFrom string EmailOptCC []string EmailOptSender string ) func (ctx *Ctx) incomingEmail(to, body string, opts ...any) { id := 0 subject := "crash1" from := "default@sender.com" cc := []string{"test@syzkaller.com", "bugs@syzkaller.com", "bugs2@syzkaller.com"} sender := "" origFrom := "" for _, o := range opts { switch opt := o.(type) { case EmailOptMessageID: id = int(opt) case EmailOptSubject: subject = string(opt) case EmailOptFrom: from = string(opt) case EmailOptSender: sender = string(opt) case EmailOptCC: cc = []string(opt) case EmailOptOrigFrom: origFrom = fmt.Sprintf("\nX-Original-From: %v", string(opt)) } } if sender == "" { sender = from } email := fmt.Sprintf(`Sender: %v Date: Tue, 15 Aug 2017 14:59:00 -0700 Message-ID: <%v> Subject: %v From: %v Cc: %v To: %v%v Content-Type: text/plain %v `, sender, id, subject, from, strings.Join(cc, ","), to, origFrom, body) log.Infof(ctx.ctx, "sending %s", email) _, err := ctx.POST("/_ah/mail/"+to, email) ctx.expectOK(err) } func (ctx *Ctx) createAIJob(bugExitID, workflow, baseCommit string) string { bug, _, _ := ctx.loadBug(bugExitID) args := map[string]any{} if baseCommit != "" { args["FixedBaseCommit"] = baseCommit } id, err := aiBugJobCreate(ctx.ctx, workflow, bug, args) require.NoError(ctx.t, err) return id } func initMocks() { // Mock time as some functionality relies on real time. timeNow = func(ctx context.Context) time.Time { return getRequestContext(ctx).mockedTime } aidb.TimeNow = timeNow sendEmail = func(ctx context.Context, msg *aemail.Message) error { getRequestContext(ctx).emailSink <- msg return nil } maxCrashes = func() int { // dev_appserver is very slow, so let's make tests smaller. const maxCrashesDuringTest = 20 return maxCrashesDuringTest } } // Machinery to associate mocked time with requests. type RequestMapping struct { id int ctx *Ctx } var ( requestMu sync.Mutex requestNum int requestContexts []RequestMapping ) func registerRequest(r *http.Request, ctx *Ctx) *http.Request { requestMu.Lock() defer requestMu.Unlock() requestNum++ newContext := context.WithValue(r.Context(), requestIDKey{}, requestNum) newRequest := r.WithContext(newContext) requestContexts = append(requestContexts, RequestMapping{requestNum, ctx}) return newRequest } func getRequestContext(ctx context.Context) *Ctx { requestMu.Lock() defer requestMu.Unlock() reqID := getRequestID(ctx) for _, m := range requestContexts { if m.id == reqID { return m.ctx } } panic(fmt.Sprintf("no context for: %#v", ctx)) } func unregisterContext(ctx *Ctx) { requestMu.Lock() defer requestMu.Unlock() n := 0 for _, m := range requestContexts { if m.ctx == ctx { continue } requestContexts[n] = m n++ } requestContexts = requestContexts[:n] } type requestIDKey struct{} func getRequestID(ctx context.Context) int { val, ok := ctx.Value(requestIDKey{}).(int) if !ok { panic("the context did not come from a test") } return val } // Create a shallow copy of GlobalConfig with a replaced namespace config. func replaceNamespaceConfig(ctx context.Context, ns string, f func(*Config) *Config) *GlobalConfig { ret := *getConfig(ctx) newNsMap := map[string]*Config{} for name, nsCfg := range ret.Namespaces { if name == ns { nsCfg = f(nsCfg) } newNsMap[name] = nsCfg } ret.Namespaces = newNsMap return &ret } func replaceManagerConfig(ctx context.Context, ns, mgr string, f func(ConfigManager) ConfigManager) *GlobalConfig { return replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg newMgrMap := map[string]ConfigManager{} for name, mgrCfg := range ret.Managers { if name == mgr { mgrCfg = f(mgrCfg) } newMgrMap[name] = mgrCfg } ret.Managers = newMgrMap return &ret }) } func replaceReporting(ctx context.Context, ns, name string, f func(Reporting) Reporting) *GlobalConfig { return replaceNamespaceConfig(ctx, ns, func(cfg *Config) *Config { ret := *cfg var newReporting []Reporting for _, cfg := range ret.Reporting { if cfg.Name == name { cfg = f(cfg) } newReporting = append(newReporting, cfg) } ret.Reporting = newReporting return &ret }) }