aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/google/s2a-go/s2a.go
diff options
context:
space:
mode:
authordependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>2023-07-25 07:50:36 +0000
committerTaras Madan <tarasmadan@google.com>2023-07-25 08:27:48 +0000
commitb423bd03401d00e754d5e5c0236feda4dfb02e28 (patch)
treeb4c5cb53485b00adb2877b7fa27b9b6e5f02552a /vendor/github.com/google/s2a-go/s2a.go
parente06c669f49a06146914b04a1fbbdd21a0bf1d7b1 (diff)
mod: do: bump golang.org/x/oauth2 from 0.5.0 to 0.10.0
Bumps [golang.org/x/oauth2](https://github.com/golang/oauth2) from 0.5.0 to 0.10.0. - [Commits](https://github.com/golang/oauth2/compare/v0.5.0...v0.10.0) --- updated-dependencies: - dependency-name: golang.org/x/oauth2 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com>
Diffstat (limited to 'vendor/github.com/google/s2a-go/s2a.go')
-rw-r--r--vendor/github.com/google/s2a-go/s2a.go412
1 files changed, 412 insertions, 0 deletions
diff --git a/vendor/github.com/google/s2a-go/s2a.go b/vendor/github.com/google/s2a-go/s2a.go
new file mode 100644
index 000000000..1c1349de4
--- /dev/null
+++ b/vendor/github.com/google/s2a-go/s2a.go
@@ -0,0 +1,412 @@
+/*
+ *
+ * Copyright 2021 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package s2a provides the S2A transport credentials used by a gRPC
+// application.
+package s2a
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/google/s2a-go/fallback"
+ "github.com/google/s2a-go/internal/handshaker"
+ "github.com/google/s2a-go/internal/handshaker/service"
+ "github.com/google/s2a-go/internal/tokenmanager"
+ "github.com/google/s2a-go/internal/v2"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/grpclog"
+
+ commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
+ s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
+)
+
+const (
+ s2aSecurityProtocol = "tls"
+ // defaultTimeout specifies the default server handshake timeout.
+ defaultTimeout = 30.0 * time.Second
+)
+
+// s2aTransportCreds are the transport credentials required for establishing
+// a secure connection using the S2A. They implement the
+// credentials.TransportCredentials interface.
+type s2aTransportCreds struct {
+ info *credentials.ProtocolInfo
+ minTLSVersion commonpb.TLSVersion
+ maxTLSVersion commonpb.TLSVersion
+ // tlsCiphersuites contains the ciphersuites used in the S2A connection.
+ // Note that these are currently unconfigurable.
+ tlsCiphersuites []commonpb.Ciphersuite
+ // localIdentity should only be used by the client.
+ localIdentity *commonpb.Identity
+ // localIdentities should only be used by the server.
+ localIdentities []*commonpb.Identity
+ // targetIdentities should only be used by the client.
+ targetIdentities []*commonpb.Identity
+ isClient bool
+ s2aAddr string
+ ensureProcessSessionTickets *sync.WaitGroup
+}
+
+// NewClientCreds returns a client-side transport credentials object that uses
+// the S2A to establish a secure connection with a server.
+func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
+ if opts == nil {
+ return nil, errors.New("nil client options")
+ }
+ var targetIdentities []*commonpb.Identity
+ for _, targetIdentity := range opts.TargetIdentities {
+ protoTargetIdentity, err := toProtoIdentity(targetIdentity)
+ if err != nil {
+ return nil, err
+ }
+ targetIdentities = append(targetIdentities, protoTargetIdentity)
+ }
+ localIdentity, err := toProtoIdentity(opts.LocalIdentity)
+ if err != nil {
+ return nil, err
+ }
+ if opts.EnableLegacyMode {
+ return &s2aTransportCreds{
+ info: &credentials.ProtocolInfo{
+ SecurityProtocol: s2aSecurityProtocol,
+ },
+ minTLSVersion: commonpb.TLSVersion_TLS1_3,
+ maxTLSVersion: commonpb.TLSVersion_TLS1_3,
+ tlsCiphersuites: []commonpb.Ciphersuite{
+ commonpb.Ciphersuite_AES_128_GCM_SHA256,
+ commonpb.Ciphersuite_AES_256_GCM_SHA384,
+ commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
+ },
+ localIdentity: localIdentity,
+ targetIdentities: targetIdentities,
+ isClient: true,
+ s2aAddr: opts.S2AAddress,
+ ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
+ }, nil
+ }
+ verificationMode := getVerificationMode(opts.VerificationMode)
+ var fallbackFunc fallback.ClientHandshake
+ if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
+ fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
+ }
+ return v2.NewClientCreds(opts.S2AAddress, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
+}
+
+// NewServerCreds returns a server-side transport credentials object that uses
+// the S2A to establish a secure connection with a client.
+func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
+ if opts == nil {
+ return nil, errors.New("nil server options")
+ }
+ var localIdentities []*commonpb.Identity
+ for _, localIdentity := range opts.LocalIdentities {
+ protoLocalIdentity, err := toProtoIdentity(localIdentity)
+ if err != nil {
+ return nil, err
+ }
+ localIdentities = append(localIdentities, protoLocalIdentity)
+ }
+ if opts.EnableLegacyMode {
+ return &s2aTransportCreds{
+ info: &credentials.ProtocolInfo{
+ SecurityProtocol: s2aSecurityProtocol,
+ },
+ minTLSVersion: commonpb.TLSVersion_TLS1_3,
+ maxTLSVersion: commonpb.TLSVersion_TLS1_3,
+ tlsCiphersuites: []commonpb.Ciphersuite{
+ commonpb.Ciphersuite_AES_128_GCM_SHA256,
+ commonpb.Ciphersuite_AES_256_GCM_SHA384,
+ commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
+ },
+ localIdentities: localIdentities,
+ isClient: false,
+ s2aAddr: opts.S2AAddress,
+ }, nil
+ }
+ verificationMode := getVerificationMode(opts.VerificationMode)
+ return v2.NewServerCreds(opts.S2AAddress, localIdentities, verificationMode, opts.getS2AStream)
+}
+
+// ClientHandshake initiates a client-side TLS handshake using the S2A.
+func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ if !c.isClient {
+ return nil, nil, errors.New("client handshake called using server transport credentials")
+ }
+
+ // Connect to the S2A.
+ hsConn, err := service.Dial(c.s2aAddr)
+ if err != nil {
+ grpclog.Infof("Failed to connect to S2A: %v", err)
+ return nil, nil, err
+ }
+
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+
+ opts := &handshaker.ClientHandshakerOptions{
+ MinTLSVersion: c.minTLSVersion,
+ MaxTLSVersion: c.maxTLSVersion,
+ TLSCiphersuites: c.tlsCiphersuites,
+ TargetIdentities: c.targetIdentities,
+ LocalIdentity: c.localIdentity,
+ TargetName: serverAuthority,
+ EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
+ }
+ chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
+ if err != nil {
+ grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
+ return nil, nil, err
+ }
+ defer func() {
+ if err != nil {
+ if closeErr := chs.Close(); closeErr != nil {
+ grpclog.Infof("Close failed unexpectedly: %v", err)
+ err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
+ }
+ }
+ }()
+
+ secConn, authInfo, err := chs.ClientHandshake(context.Background())
+ if err != nil {
+ grpclog.Infof("Handshake failed: %v", err)
+ return nil, nil, err
+ }
+ return secConn, authInfo, nil
+}
+
+// ServerHandshake initiates a server-side TLS handshake using the S2A.
+func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ if c.isClient {
+ return nil, nil, errors.New("server handshake called using client transport credentials")
+ }
+
+ // Connect to the S2A.
+ hsConn, err := service.Dial(c.s2aAddr)
+ if err != nil {
+ grpclog.Infof("Failed to connect to S2A: %v", err)
+ return nil, nil, err
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
+ defer cancel()
+
+ opts := &handshaker.ServerHandshakerOptions{
+ MinTLSVersion: c.minTLSVersion,
+ MaxTLSVersion: c.maxTLSVersion,
+ TLSCiphersuites: c.tlsCiphersuites,
+ LocalIdentities: c.localIdentities,
+ }
+ shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
+ if err != nil {
+ grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
+ return nil, nil, err
+ }
+ defer func() {
+ if err != nil {
+ if closeErr := shs.Close(); closeErr != nil {
+ grpclog.Infof("Close failed unexpectedly: %v", err)
+ err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
+ }
+ }
+ }()
+
+ secConn, authInfo, err := shs.ServerHandshake(context.Background())
+ if err != nil {
+ grpclog.Infof("Handshake failed: %v", err)
+ return nil, nil, err
+ }
+ return secConn, authInfo, nil
+}
+
+func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
+ return *c.info
+}
+
+func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
+ info := *c.info
+ var localIdentity *commonpb.Identity
+ if c.localIdentity != nil {
+ localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
+ }
+ var localIdentities []*commonpb.Identity
+ if c.localIdentities != nil {
+ localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
+ for i, localIdentity := range c.localIdentities {
+ localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
+ }
+ }
+ var targetIdentities []*commonpb.Identity
+ if c.targetIdentities != nil {
+ targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
+ for i, targetIdentity := range c.targetIdentities {
+ targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
+ }
+ }
+ return &s2aTransportCreds{
+ info: &info,
+ minTLSVersion: c.minTLSVersion,
+ maxTLSVersion: c.maxTLSVersion,
+ tlsCiphersuites: c.tlsCiphersuites,
+ localIdentity: localIdentity,
+ localIdentities: localIdentities,
+ targetIdentities: targetIdentities,
+ isClient: c.isClient,
+ s2aAddr: c.s2aAddr,
+ }
+}
+
+func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
+ c.info.ServerName = serverNameOverride
+ return nil
+}
+
+// TLSClientConfigOptions specifies parameters for creating client TLS config.
+type TLSClientConfigOptions struct {
+ // ServerName is required by s2a as the expected name when verifying the hostname found in server's certificate.
+ // tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
+ // ServerName: "example.com",
+ // })
+ ServerName string
+}
+
+// TLSClientConfigFactory defines the interface for a client TLS config factory.
+type TLSClientConfigFactory interface {
+ Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
+}
+
+// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
+func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
+ if opts == nil {
+ return nil, fmt.Errorf("opts must be non-nil")
+ }
+ if opts.EnableLegacyMode {
+ return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
+ }
+ tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
+ if err != nil {
+ // The only possible error is: access token not set in the environment,
+ // which is okay in environments other than serverless.
+ grpclog.Infof("Access token manager not initialized: %v", err)
+ return &s2aTLSClientConfigFactory{
+ s2av2Address: opts.S2AAddress,
+ tokenManager: nil,
+ verificationMode: getVerificationMode(opts.VerificationMode),
+ serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
+ }, nil
+ }
+ return &s2aTLSClientConfigFactory{
+ s2av2Address: opts.S2AAddress,
+ tokenManager: tokenManager,
+ verificationMode: getVerificationMode(opts.VerificationMode),
+ serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
+ }, nil
+}
+
+type s2aTLSClientConfigFactory struct {
+ s2av2Address string
+ tokenManager tokenmanager.AccessTokenManager
+ verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
+ serverAuthorizationPolicy []byte
+}
+
+func (f *s2aTLSClientConfigFactory) Build(
+ ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
+ serverName := ""
+ if opts != nil && opts.ServerName != "" {
+ serverName = opts.ServerName
+ }
+ return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
+}
+
+func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
+ switch verificationMode {
+ case ConnectToGoogle:
+ return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
+ case Spiffe:
+ return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
+ default:
+ return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
+ }
+}
+
+// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
+// Example use with http.RoundTripper:
+//
+// dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
+// S2AAddress: s2aAddress, // required
+// })
+// transport := http.DefaultTransport
+// transport.DialTLSContext = dialTLSContext
+func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
+
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+
+ fallback := func(err error) (net.Conn, error) {
+ if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
+ opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
+ fbDialer := opts.FallbackOpts.FallbackDialer
+ grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
+ fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
+ if fbErr != nil {
+ return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
+ }
+ return fbConn, nil
+ }
+ return nil, err
+ }
+
+ factory, err := NewTLSClientConfigFactory(opts)
+ if err != nil {
+ grpclog.Infof("error creating S2A client config factory: %v", err)
+ return fallback(err)
+ }
+
+ serverName, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ serverName = addr
+ }
+ timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
+ defer cancel()
+ s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
+ ServerName: serverName,
+ })
+ if err != nil {
+ grpclog.Infof("error building S2A TLS config: %v", err)
+ return fallback(err)
+ }
+
+ s2aDialer := &tls.Dialer{
+ Config: s2aTLSConfig,
+ }
+ c, err := s2aDialer.DialContext(ctx, network, addr)
+ if err != nil {
+ grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
+ return fallback(err)
+ }
+ grpclog.Infof("success dialing MTLS to %s with S2A", addr)
+ return c, nil
+ }
+}