diff options
| -rw-r--r-- | pkg/gce/gce.go | 21 | ||||
| -rw-r--r-- | pkg/gce/gce_test.go | 21 | ||||
| -rw-r--r-- | vm/gce/gce.go | 4 |
3 files changed, 44 insertions, 2 deletions
diff --git a/pkg/gce/gce.go b/pkg/gce/gce.go index 5e45cbce2..0072bd21f 100644 --- a/pkg/gce/gce.go +++ b/pkg/gce/gce.go @@ -17,6 +17,7 @@ import ( "io" "math/rand" "net/http" + "regexp" "strings" "time" @@ -30,6 +31,7 @@ import ( type Context struct { ProjectID string ZoneID string + RegionID string Instance string InternalIP string ExternalIP string @@ -82,6 +84,13 @@ func NewContext(customZoneID string) (*Context, error) { } else { ctx.ZoneID = myZoneID } + if !validateZone(ctx.ZoneID) { + return nil, fmt.Errorf("%q is not a valid zone name", ctx.ZoneID) + } + ctx.RegionID = zoneToRegion(ctx.ZoneID) + if ctx.RegionID == "" { + return nil, fmt.Errorf("failed to extract region id from %s", ctx.ZoneID) + } ctx.Instance, err = ctx.getMeta("instance/name") if err != nil { return nil, fmt.Errorf("failed to query gce instance name: %w", err) @@ -373,3 +382,15 @@ func (ctx *Context) apiCall(fn func() error) error { return err } } + +var zoneNameRe = regexp.MustCompile("^[a-zA-Z0-9]*-[a-zA-Z0-9]*[-][a-zA-Z0-9]*$") + +func validateZone(zone string) bool { + return zoneNameRe.MatchString(zone) +} + +var regionNameRe = regexp.MustCompile("^[a-zA-Z0-9]*-[a-zA-Z0-9]*") + +func zoneToRegion(zone string) string { + return regionNameRe.FindString(zone) +} diff --git a/pkg/gce/gce_test.go b/pkg/gce/gce_test.go new file mode 100644 index 000000000..39f086bdc --- /dev/null +++ b/pkg/gce/gce_test.go @@ -0,0 +1,21 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package gce + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateZone(t *testing.T) { + assert.True(t, validateZone("us-west1-b")) + assert.True(t, validateZone("us-central1-a")) + assert.False(t, validateZone("us-central1")) +} + +func TestZoneToRegion(t *testing.T) { + assert.Equal(t, "us-west1", zoneToRegion("us-west1-b")) + assert.Equal(t, "northamerica-northeast2", zoneToRegion("northamerica-northeast2-a")) +} diff --git a/vm/gce/gce.go b/vm/gce/gce.go index e8eb35f06..f775f43b5 100644 --- a/vm/gce/gce.go +++ b/vm/gce/gce.go @@ -453,8 +453,8 @@ func (inst *instance) serialPortArgs(replay bool) []string { if replay { replayArg = ".replay-lines=10000" } - conAddr := fmt.Sprintf("%v.%v.%v.%s.port=1%s@ssh-serialport.googleapis.com", - inst.GCE.ProjectID, inst.GCE.ZoneID, inst.name, user, replayArg) + conAddr := fmt.Sprintf("%v.%v.%v.%s.port=1%s@%v-ssh-serialport.googleapis.com", + inst.GCE.ProjectID, inst.GCE.ZoneID, inst.name, user, replayArg, inst.GCE.RegionID) conArgs := append(vmimpl.SSHArgs(inst.debug, key, 9600), conAddr) // TODO(blackgnezdo): Remove this once ssh-serialport.googleapis.com stops using // host key algorithm: ssh-rsa. |
