aboutsummaryrefslogtreecommitdiffstats
path: root/syz-cluster/pkg/reporter/api.go
blob: 310dd57d1da420937ce2065044fa37fcfa9e2b1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package reporter

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/google/syzkaller/syz-cluster/pkg/api"
	"github.com/google/syzkaller/syz-cluster/pkg/app"
	"github.com/google/syzkaller/syz-cluster/pkg/service"
)

// APIServer serves the reporter-facing HTTP API. Each handler delegates the
// actual work to one of the two services below and translates the outcome
// into an HTTP response.
type APIServer struct {
	// reportService backs the report endpoints (upstream, confirm,
	// invalidate, next).
	reportService *service.ReportService
	// discussionService backs the reply endpoints (record_reply, last_reply).
	discussionService *service.DiscussionService
}

// NewAPIServer creates an APIServer whose report and discussion services are
// both constructed from env.
func NewAPIServer(env *app.AppEnvironment) *APIServer {
	srv := new(APIServer)
	srv.reportService = service.NewReportService(env)
	srv.discussionService = service.NewDiscussionService(env)
	return srv
}

// Mux returns a new ServeMux with every reporter API route registered to the
// corresponding APIServer handler. Patterns containing {report_id} expose
// that segment to handlers via r.PathValue("report_id").
func (s *APIServer) Mux() *http.ServeMux {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/reports/{report_id}/upstream", s.upstreamReport},
		{"/reports/{report_id}/confirm", s.confirmReport},
		{"/reports/{report_id}/invalidate", s.invalidateReport},
		{"/reports/record_reply", s.recordReply},
		{"/reports/last_reply", s.lastReply},
		{"/reports", s.nextReports},
	}
	mux := http.NewServeMux()
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handler)
	}
	return mux
}

// upstreamReport handles /reports/{report_id}/upstream. It decodes an
// api.UpstreamReportReq from the request and passes it to the report
// service's Upstream for the report named in the path. The outcome is
// written by reply with no payload.
// nolint: dupl
func (s *APIServer) upstreamReport(w http.ResponseWriter, r *http.Request) {
	req := api.ParseJSON[api.UpstreamReportReq](w, r)
	if req == nil {
		// Presumably ParseJSON has already answered the client on failure.
		return
	}
	reportID := r.PathValue("report_id")
	// TODO: journal the action.
	reply[any](w, nil, s.reportService.Upstream(r.Context(), reportID, req))
}

// invalidateReport handles /reports/{report_id}/invalidate by calling the
// report service's Invalidate for the report named in the path. The outcome
// is written by reply with no payload.
func (s *APIServer) invalidateReport(w http.ResponseWriter, r *http.Request) {
	ctx, reportID := r.Context(), r.PathValue("report_id")
	// TODO: journal the action.
	reply[any](w, nil, s.reportService.Invalidate(ctx, reportID))
}

// nextReports handles /reports. It reads the "reporter" query/form value and
// replies with the report service's Next result for that reporter.
func (s *APIServer) nextReports(w http.ResponseWriter, r *http.Request) {
	reporter := r.FormValue("reporter")
	resp, err := s.reportService.Next(r.Context(), reporter)
	reply(w, resp, err)
}

// confirmReport handles /reports/{report_id}/confirm by calling the report
// service's Confirm for the report named in the path. The outcome is written
// by reply with no payload.
func (s *APIServer) confirmReport(w http.ResponseWriter, r *http.Request) {
	ctx, reportID := r.Context(), r.PathValue("report_id")
	reply[any](w, nil, s.reportService.Confirm(ctx, reportID))
}

// recordReply handles /reports/record_reply. It decodes an api.RecordReplyReq
// from the request, hands it to the discussion service's RecordReply, and
// replies with the result. When decoding yields nil the handler writes
// nothing itself (presumably ParseJSON has already responded).
func (s *APIServer) recordReply(w http.ResponseWriter, r *http.Request) {
	if req := api.ParseJSON[api.RecordReplyReq](w, r); req != nil {
		resp, err := s.discussionService.RecordReply(r.Context(), req)
		reply(w, resp, err)
	}
}

// lastReply handles /reports/last_reply. It reads the "reporter" query/form
// value and replies with the discussion service's LastReply result for that
// reporter.
//
// The route pattern registered in Mux has no {reporter} wildcard, so
// r.PathValue("reporter") would always be "". The reporter is therefore read
// with FormValue, the same way nextReports does.
func (s *APIServer) lastReply(w http.ResponseWriter, r *http.Request) {
	reporter := r.FormValue("reporter")
	resp, err := s.discussionService.LastReply(r.Context(), reporter)
	reply(w, resp, err)
}

func reply[T any](w http.ResponseWriter, obj T, err error) {
	if errors.Is(err, service.ErrReportNotFound) {
		http.Error(w, fmt.Sprint(err), http.StatusNotFound)
		return
	} else if errors.Is(err, service.ErrNotOnModeration) {
		http.Error(w, fmt.Sprint(err), http.StatusBadRequest)
		return
	} else if err != nil {
		http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
		return
	}
	api.ReplyJSON[T](w, obj)
}

// TestServer starts an httptest server that routes through a fresh APIServer
// built from env. It registers the server's shutdown with t.Cleanup and
// returns a ReporterClient pointed at the server's URL.
func TestServer(t *testing.T, env *app.AppEnvironment) *api.ReporterClient {
	srv := httptest.NewServer(NewAPIServer(env).Mux())
	t.Cleanup(srv.Close)
	return api.NewReporterClient(srv.URL)
}