diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2017-06-13 19:31:19 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-06-13 19:31:19 +0200 |
| commit | 5b060131006494cbc077f08b9b2fbf172f3eb239 (patch) | |
| tree | 04f8586899db96f7fd8e7bc6a010fc10f1e2bb3b /vendor/cloud.google.com/go/spanner | |
| parent | cd8e13f826ff24f5f8e0b8de1b9d3373aaf93d2f (diff) | |
| parent | 612b82714b3e6660bf702f801ab96aacb3432e1f (diff) | |
Merge pull request #226 from google/dvyukov-vendor
vendor: vendor dependencies
Diffstat (limited to 'vendor/cloud.google.com/go/spanner')
38 files changed, 18468 insertions, 0 deletions
diff --git a/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client.go b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client.go new file mode 100644 index 000000000..95d623059 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client.go @@ -0,0 +1,556 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package database + +import ( + "math" + "time" + + "cloud.google.com/go/internal/version" + "cloud.google.com/go/longrunning" + lroauto "cloud.google.com/go/longrunning/autogen" + gax "github.com/googleapis/gax-go" + "golang.org/x/net/context" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/api/transport" + iampb "google.golang.org/genproto/googleapis/iam/v1" + longrunningpb "google.golang.org/genproto/googleapis/longrunning" + databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +var ( + databaseAdminInstancePathTemplate = gax.MustCompilePathTemplate("projects/{project}/instances/{instance}") + databaseAdminDatabasePathTemplate = gax.MustCompilePathTemplate("projects/{project}/instances/{instance}/databases/{database}") +) + +// DatabaseAdminCallOptions contains the retry settings for each method of DatabaseAdminClient. +type DatabaseAdminCallOptions struct { + ListDatabases []gax.CallOption + CreateDatabase []gax.CallOption + GetDatabase []gax.CallOption + UpdateDatabaseDdl []gax.CallOption + DropDatabase []gax.CallOption + GetDatabaseDdl []gax.CallOption + SetIamPolicy []gax.CallOption + GetIamPolicy []gax.CallOption + TestIamPermissions []gax.CallOption +} + +func defaultDatabaseAdminClientOptions() []option.ClientOption { + return []option.ClientOption{ + option.WithEndpoint("spanner.googleapis.com:443"), + option.WithScopes(DefaultAuthScopes()...), + } +} + +func defaultDatabaseAdminCallOptions() *DatabaseAdminCallOptions { + retry := map[[2]string][]gax.CallOption{ + {"default", "idempotent"}: { + gax.WithRetry(func() gax.Retryer { + return gax.OnCodes([]codes.Code{ + codes.DeadlineExceeded, + codes.Unavailable, + }, gax.Backoff{ + Initial: 1000 * time.Millisecond, + Max: 32000 * time.Millisecond, + Multiplier: 1.3, + }) + }), + }, + {"default", "non_idempotent"}: { + gax.WithRetry(func() gax.Retryer { + return gax.OnCodes([]codes.Code{ + codes.Unavailable, + }, gax.Backoff{ + Initial: 1000 * time.Millisecond, + Max: 32000 * time.Millisecond, + Multiplier: 1.3, + }) + }), + }, + } + return &DatabaseAdminCallOptions{ + ListDatabases: retry[[2]string{"default", "idempotent"}], + CreateDatabase: retry[[2]string{"default", "non_idempotent"}], + GetDatabase: retry[[2]string{"default", "idempotent"}], + UpdateDatabaseDdl: retry[[2]string{"default", "idempotent"}], + DropDatabase: retry[[2]string{"default", "idempotent"}], + GetDatabaseDdl: retry[[2]string{"default", "idempotent"}], + SetIamPolicy: retry[[2]string{"default", "non_idempotent"}], + GetIamPolicy: retry[[2]string{"default", "idempotent"}], + TestIamPermissions: retry[[2]string{"default", "non_idempotent"}], + } +} + +// DatabaseAdminClient is a client for interacting with Cloud Spanner Database Admin API. +type DatabaseAdminClient struct { + // The connection to the service. + conn *grpc.ClientConn + + // The gRPC API client. + databaseAdminClient databasepb.DatabaseAdminClient + + // LROClient is used internally to handle longrunning operations. + // It is exposed so that its CallOptions can be modified if required. + // Users should not Close this client. + LROClient *lroauto.OperationsClient + + // The call options for this service. + CallOptions *DatabaseAdminCallOptions + + // The metadata to be sent with each request. + xGoogHeader []string +} + +// NewDatabaseAdminClient creates a new database admin client. +// +// Cloud Spanner Database Admin API +// +// The Cloud Spanner Database Admin API can be used to create, drop, and +// list databases. It also enables updating the schema of pre-existing +// databases. +func NewDatabaseAdminClient(ctx context.Context, opts ...option.ClientOption) (*DatabaseAdminClient, error) { + conn, err := transport.DialGRPC(ctx, append(defaultDatabaseAdminClientOptions(), opts...)...) + if err != nil { + return nil, err + } + c := &DatabaseAdminClient{ + conn: conn, + CallOptions: defaultDatabaseAdminCallOptions(), + + databaseAdminClient: databasepb.NewDatabaseAdminClient(conn), + } + c.SetGoogleClientInfo() + + c.LROClient, err = lroauto.NewOperationsClient(ctx, option.WithGRPCConn(conn)) + if err != nil { + // This error "should not happen", since we are just reusing old connection + // and never actually need to dial. + // If this does happen, we could leak conn. However, we cannot close conn: + // If the user invoked the function with option.WithGRPCConn, + // we would close a connection that's still in use. + // TODO(pongad): investigate error conditions. + return nil, err + } + return c, nil +} + +// Connection returns the client's connection to the API service. +func (c *DatabaseAdminClient) Connection() *grpc.ClientConn { + return c.conn +} + +// Close closes the connection to the API service. The user should invoke this when +// the client is no longer required. +func (c *DatabaseAdminClient) Close() error { + return c.conn.Close() +} + +// SetGoogleClientInfo sets the name and version of the application in +// the `x-goog-api-client` header passed on each request. Intended for +// use by Google-written clients. +func (c *DatabaseAdminClient) SetGoogleClientInfo(keyval ...string) { + kv := append([]string{"gl-go", version.Go()}, keyval...) + kv = append(kv, "gapic", version.Repo, "gax", gax.Version, "grpc", grpc.Version) + c.xGoogHeader = []string{gax.XGoogHeader(kv...)} +} + +// DatabaseAdminInstancePath returns the path for the instance resource. +func DatabaseAdminInstancePath(project, instance string) string { + path, err := databaseAdminInstancePathTemplate.Render(map[string]string{ + "project": project, + "instance": instance, + }) + if err != nil { + panic(err) + } + return path +} + +// DatabaseAdminDatabasePath returns the path for the database resource. +func DatabaseAdminDatabasePath(project, instance, database string) string { + path, err := databaseAdminDatabasePathTemplate.Render(map[string]string{ + "project": project, + "instance": instance, + "database": database, + }) + if err != nil { + panic(err) + } + return path +} + +// ListDatabases lists Cloud Spanner databases. +func (c *DatabaseAdminClient) ListDatabases(ctx context.Context, req *databasepb.ListDatabasesRequest, opts ...gax.CallOption) *DatabaseIterator { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.ListDatabases[0:len(c.CallOptions.ListDatabases):len(c.CallOptions.ListDatabases)], opts...) + it := &DatabaseIterator{} + it.InternalFetch = func(pageSize int, pageToken string) ([]*databasepb.Database, string, error) { + var resp *databasepb.ListDatabasesResponse + req.PageToken = pageToken + if pageSize > math.MaxInt32 { + req.PageSize = math.MaxInt32 + } else { + req.PageSize = int32(pageSize) + } + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.ListDatabases(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, "", err + } + return resp.Databases, resp.NextPageToken, nil + } + fetch := func(pageSize int, pageToken string) (string, error) { + items, nextPageToken, err := it.InternalFetch(pageSize, pageToken) + if err != nil { + return "", err + } + it.items = append(it.items, items...) + return nextPageToken, nil + } + it.pageInfo, it.nextFunc = iterator.NewPageInfo(fetch, it.bufLen, it.takeBuf) + return it +} + +// CreateDatabase creates a new Cloud Spanner database and starts to prepare it for serving. +// The returned [long-running operation][google.longrunning.Operation] will +// have a name of the format `<database_name>/operations/<operation_id>` and +// can be used to track preparation of the database. The +// [metadata][google.longrunning.Operation.metadata] field type is +// [CreateDatabaseMetadata][google.spanner.admin.database.v1.CreateDatabaseMetadata]. The +// [response][google.longrunning.Operation.response] field type is +// [Database][google.spanner.admin.database.v1.Database], if successful. +func (c *DatabaseAdminClient) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (*CreateDatabaseOperation, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.CreateDatabase[0:len(c.CallOptions.CreateDatabase):len(c.CallOptions.CreateDatabase)], opts...) + var resp *longrunningpb.Operation + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.CreateDatabase(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return &CreateDatabaseOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, resp), + }, nil +} + +// GetDatabase gets the state of a Cloud Spanner database. +func (c *DatabaseAdminClient) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetDatabase[0:len(c.CallOptions.GetDatabase):len(c.CallOptions.GetDatabase)], opts...) + var resp *databasepb.Database + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.GetDatabase(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// UpdateDatabaseDdl updates the schema of a Cloud Spanner database by +// creating/altering/dropping tables, columns, indexes, etc. The returned +// [long-running operation][google.longrunning.Operation] will have a name of +// the format `<database_name>/operations/<operation_id>` and can be used to +// track execution of the schema change(s). The +// [metadata][google.longrunning.Operation.metadata] field type is +// [UpdateDatabaseDdlMetadata][google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata]. The operation has no response. +func (c *DatabaseAdminClient) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (*UpdateDatabaseDdlOperation, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.UpdateDatabaseDdl[0:len(c.CallOptions.UpdateDatabaseDdl):len(c.CallOptions.UpdateDatabaseDdl)], opts...) + var resp *longrunningpb.Operation + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.UpdateDatabaseDdl(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return &UpdateDatabaseDdlOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, resp), + }, nil +} + +// DropDatabase drops (aka deletes) a Cloud Spanner database. +func (c *DatabaseAdminClient) DropDatabase(ctx context.Context, req *databasepb.DropDatabaseRequest, opts ...gax.CallOption) error { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.DropDatabase[0:len(c.CallOptions.DropDatabase):len(c.CallOptions.DropDatabase)], opts...) + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + _, err = c.databaseAdminClient.DropDatabase(ctx, req, settings.GRPC...) + return err + }, opts...) + return err +} + +// GetDatabaseDdl returns the schema of a Cloud Spanner database as a list of formatted +// DDL statements. This method does not show pending schema updates, those may +// be queried using the [Operations][google.longrunning.Operations] API. +func (c *DatabaseAdminClient) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetDatabaseDdl[0:len(c.CallOptions.GetDatabaseDdl):len(c.CallOptions.GetDatabaseDdl)], opts...) + var resp *databasepb.GetDatabaseDdlResponse + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.GetDatabaseDdl(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// SetIamPolicy sets the access control policy on a database resource. Replaces any +// existing policy. +// +// Authorization requires `spanner.databases.setIamPolicy` permission on +// [resource][google.iam.v1.SetIamPolicyRequest.resource]. +func (c *DatabaseAdminClient) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.SetIamPolicy[0:len(c.CallOptions.SetIamPolicy):len(c.CallOptions.SetIamPolicy)], opts...) + var resp *iampb.Policy + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.SetIamPolicy(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// GetIamPolicy gets the access control policy for a database resource. Returns an empty +// policy if a database exists but does not have a policy set. +// +// Authorization requires `spanner.databases.getIamPolicy` permission on +// [resource][google.iam.v1.GetIamPolicyRequest.resource]. +func (c *DatabaseAdminClient) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetIamPolicy[0:len(c.CallOptions.GetIamPolicy):len(c.CallOptions.GetIamPolicy)], opts...) + var resp *iampb.Policy + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.GetIamPolicy(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// TestIamPermissions returns permissions that the caller has on the specified database resource. +// +// Attempting this RPC on a non-existent Cloud Spanner database will result in +// a NOT_FOUND error if the user has `spanner.databases.list` permission on +// the containing Cloud Spanner instance. Otherwise returns an empty set of +// permissions. +func (c *DatabaseAdminClient) TestIamPermissions(ctx context.Context, req *iampb.TestIamPermissionsRequest, opts ...gax.CallOption) (*iampb.TestIamPermissionsResponse, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.TestIamPermissions[0:len(c.CallOptions.TestIamPermissions):len(c.CallOptions.TestIamPermissions)], opts...) + var resp *iampb.TestIamPermissionsResponse + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.databaseAdminClient.TestIamPermissions(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// DatabaseIterator manages a stream of *databasepb.Database. +type DatabaseIterator struct { + items []*databasepb.Database + pageInfo *iterator.PageInfo + nextFunc func() error + + // InternalFetch is for use by the Google Cloud Libraries only. + // It is not part of the stable interface of this package. + // + // InternalFetch returns results from a single call to the underlying RPC. + // The number of results is no greater than pageSize. + // If there are no more results, nextPageToken is empty and err is nil. + InternalFetch func(pageSize int, pageToken string) (results []*databasepb.Database, nextPageToken string, err error) +} + +// PageInfo supports pagination. See the google.golang.org/api/iterator package for details. +func (it *DatabaseIterator) PageInfo() *iterator.PageInfo { + return it.pageInfo +} + +// Next returns the next result. Its second return value is iterator.Done if there are no more +// results. Once Next returns Done, all subsequent calls will return Done. +func (it *DatabaseIterator) Next() (*databasepb.Database, error) { + var item *databasepb.Database + if err := it.nextFunc(); err != nil { + return item, err + } + item = it.items[0] + it.items = it.items[1:] + return item, nil +} + +func (it *DatabaseIterator) bufLen() int { + return len(it.items) +} + +func (it *DatabaseIterator) takeBuf() interface{} { + b := it.items + it.items = nil + return b +} + +// CreateDatabaseOperation manages a long-running operation from CreateDatabase. +type CreateDatabaseOperation struct { + lro *longrunning.Operation +} + +// CreateDatabaseOperation returns a new CreateDatabaseOperation from a given name. +// The name must be that of a previously created CreateDatabaseOperation, possibly from a different process. +func (c *DatabaseAdminClient) CreateDatabaseOperation(name string) *CreateDatabaseOperation { + return &CreateDatabaseOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, &longrunningpb.Operation{Name: name}), + } +} + +// Wait blocks until the long-running operation is completed, returning the response and any errors encountered. +// +// See documentation of Poll for error-handling information. +func (op *CreateDatabaseOperation) Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + var resp databasepb.Database + if err := op.lro.Wait(ctx, &resp, opts...); err != nil { + return nil, err + } + return &resp, nil +} + +// Poll fetches the latest state of the long-running operation. +// +// Poll also fetches the latest metadata, which can be retrieved by Metadata. +// +// If Poll fails, the error is returned and op is unmodified. If Poll succeeds and +// the operation has completed with failure, the error is returned and op.Done will return true. +// If Poll succeeds and the operation has completed successfully, +// op.Done will return true, and the response of the operation is returned. +// If Poll succeeds and the operation has not completed, the returned response and error are both nil. +func (op *CreateDatabaseOperation) Poll(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + var resp databasepb.Database + if err := op.lro.Poll(ctx, &resp, opts...); err != nil { + return nil, err + } + if !op.Done() { + return nil, nil + } + return &resp, nil +} + +// Metadata returns metadata associated with the long-running operation. +// Metadata itself does not contact the server, but Poll does. +// To get the latest metadata, call this method after a successful call to Poll. +// If the metadata is not available, the returned metadata and error are both nil. +func (op *CreateDatabaseOperation) Metadata() (*databasepb.CreateDatabaseMetadata, error) { + var meta databasepb.CreateDatabaseMetadata + if err := op.lro.Metadata(&meta); err == longrunning.ErrNoMetadata { + return nil, nil + } else if err != nil { + return nil, err + } + return &meta, nil +} + +// Done reports whether the long-running operation has completed. +func (op *CreateDatabaseOperation) Done() bool { + return op.lro.Done() +} + +// Name returns the name of the long-running operation. +// The name is assigned by the server and is unique within the service from which the operation is created. +func (op *CreateDatabaseOperation) Name() string { + return op.lro.Name() +} + +// UpdateDatabaseDdlOperation manages a long-running operation from UpdateDatabaseDdl. +type UpdateDatabaseDdlOperation struct { + lro *longrunning.Operation +} + +// UpdateDatabaseDdlOperation returns a new UpdateDatabaseDdlOperation from a given name. +// The name must be that of a previously created UpdateDatabaseDdlOperation, possibly from a different process. +func (c *DatabaseAdminClient) UpdateDatabaseDdlOperation(name string) *UpdateDatabaseDdlOperation { + return &UpdateDatabaseDdlOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, &longrunningpb.Operation{Name: name}), + } +} + +// Wait blocks until the long-running operation is completed, returning any error encountered. +// +// See documentation of Poll for error-handling information. +func (op *UpdateDatabaseDdlOperation) Wait(ctx context.Context, opts ...gax.CallOption) error { + return op.lro.Wait(ctx, nil, opts...) +} + +// Poll fetches the latest state of the long-running operation. +// +// Poll also fetches the latest metadata, which can be retrieved by Metadata. +// +// If Poll fails, the error is returned and op is unmodified. If Poll succeeds and +// the operation has completed with failure, the error is returned and op.Done will return true. +// If Poll succeeds and the operation has completed successfully, op.Done will return true. +func (op *UpdateDatabaseDdlOperation) Poll(ctx context.Context, opts ...gax.CallOption) error { + return op.lro.Poll(ctx, nil, opts...) +} + +// Metadata returns metadata associated with the long-running operation. +// Metadata itself does not contact the server, but Poll does. +// To get the latest metadata, call this method after a successful call to Poll. +// If the metadata is not available, the returned metadata and error are both nil. +func (op *UpdateDatabaseDdlOperation) Metadata() (*databasepb.UpdateDatabaseDdlMetadata, error) { + var meta databasepb.UpdateDatabaseDdlMetadata + if err := op.lro.Metadata(&meta); err == longrunning.ErrNoMetadata { + return nil, nil + } else if err != nil { + return nil, err + } + return &meta, nil +} + +// Done reports whether the long-running operation has completed. +func (op *UpdateDatabaseDdlOperation) Done() bool { + return op.lro.Done() +} + +// Name returns the name of the long-running operation. +// The name is assigned by the server and is unique within the service from which the operation is created. +func (op *UpdateDatabaseDdlOperation) Name() string { + return op.lro.Name() +} diff --git a/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client_example_test.go b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client_example_test.go new file mode 100644 index 000000000..0769d1193 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/database_admin_client_example_test.go @@ -0,0 +1,204 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package database_test + +import ( + "cloud.google.com/go/spanner/admin/database/apiv1" + "golang.org/x/net/context" + iampb "google.golang.org/genproto/googleapis/iam/v1" + databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" +) + +func ExampleNewDatabaseAdminClient() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + // TODO: Use client. + _ = c +} + +func ExampleDatabaseAdminClient_ListDatabases() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.ListDatabasesRequest{ + // TODO: Fill request struct fields. + } + it := c.ListDatabases(ctx, req) + for { + resp, err := it.Next() + if err != nil { + // TODO: Handle error. + break + } + // TODO: Use resp. + _ = resp + } +} + +func ExampleDatabaseAdminClient_CreateDatabase() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.CreateDatabaseRequest{ + // TODO: Fill request struct fields. + } + op, err := c.CreateDatabase(ctx, req) + if err != nil { + // TODO: Handle error. + } + + resp, err := op.Wait(ctx) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleDatabaseAdminClient_GetDatabase() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.GetDatabaseRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetDatabase(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleDatabaseAdminClient_UpdateDatabaseDdl() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.UpdateDatabaseDdlRequest{ + // TODO: Fill request struct fields. + } + op, err := c.UpdateDatabaseDdl(ctx, req) + if err != nil { + // TODO: Handle error. + } + + err = op.Wait(ctx) + // TODO: Handle error. +} + +func ExampleDatabaseAdminClient_DropDatabase() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.DropDatabaseRequest{ + // TODO: Fill request struct fields. + } + err = c.DropDatabase(ctx, req) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleDatabaseAdminClient_GetDatabaseDdl() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &databasepb.GetDatabaseDdlRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetDatabaseDdl(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleDatabaseAdminClient_SetIamPolicy() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.SetIamPolicyRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.SetIamPolicy(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleDatabaseAdminClient_GetIamPolicy() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.GetIamPolicyRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetIamPolicy(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleDatabaseAdminClient_TestIamPermissions() { + ctx := context.Background() + c, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.TestIamPermissionsRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.TestIamPermissions(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} diff --git a/vendor/cloud.google.com/go/spanner/admin/database/apiv1/doc.go b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/doc.go new file mode 100644 index 000000000..46eaaae73 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/doc.go @@ -0,0 +1,39 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +// Package database is an experimental, auto-generated package for the +// database API. +// +package database // import "cloud.google.com/go/spanner/admin/database/apiv1" + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc/metadata" +) + +func insertXGoog(ctx context.Context, val []string) context.Context { + md, _ := metadata.FromOutgoingContext(ctx) + md = md.Copy() + md["x-goog-api-client"] = val + return metadata.NewOutgoingContext(ctx, md) +} + +func DefaultAuthScopes() []string { + return []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/spanner.admin", + } +} diff --git a/vendor/cloud.google.com/go/spanner/admin/database/apiv1/mock_test.go b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/mock_test.go new file mode 100644 index 000000000..accd30cb3 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/database/apiv1/mock_test.go @@ -0,0 +1,779 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package database + +import ( + emptypb "github.com/golang/protobuf/ptypes/empty" + iampb "google.golang.org/genproto/googleapis/iam/v1" + longrunningpb "google.golang.org/genproto/googleapis/longrunning" + databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" +) + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "golang.org/x/net/context" + "google.golang.org/api/option" + status "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +var _ = io.EOF +var _ = ptypes.MarshalAny +var _ status.Status + +type mockDatabaseAdminServer struct { + // Embed for forward compatibility. + // Tests will keep working if more methods are added + // in the future. + databasepb.DatabaseAdminServer + + reqs []proto.Message + + // If set, all calls return this error. + err error + + // responses to return if err == nil + resps []proto.Message +} + +func (s *mockDatabaseAdminServer) ListDatabases(ctx context.Context, req *databasepb.ListDatabasesRequest) (*databasepb.ListDatabasesResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*databasepb.ListDatabasesResponse), nil +} + +func (s *mockDatabaseAdminServer) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest) (*longrunningpb.Operation, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*longrunningpb.Operation), nil +} + +func (s *mockDatabaseAdminServer) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest) (*databasepb.Database, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*databasepb.Database), nil +} + +func (s *mockDatabaseAdminServer) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest) (*longrunningpb.Operation, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*longrunningpb.Operation), nil +} + +func (s *mockDatabaseAdminServer) DropDatabase(ctx context.Context, req *databasepb.DropDatabaseRequest) (*emptypb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*emptypb.Empty), nil +} + +func (s *mockDatabaseAdminServer) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest) (*databasepb.GetDatabaseDdlResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*databasepb.GetDatabaseDdlResponse), nil +} + +func (s *mockDatabaseAdminServer) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest) (*iampb.Policy, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.Policy), nil +} + +func (s *mockDatabaseAdminServer) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest) (*iampb.Policy, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.Policy), nil +} + +func (s *mockDatabaseAdminServer) TestIamPermissions(ctx context.Context, req *iampb.TestIamPermissionsRequest) (*iampb.TestIamPermissionsResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.TestIamPermissionsResponse), nil +} + +// clientOpt is the option tests should use to connect to the test server. +// It is initialized by TestMain. +var clientOpt option.ClientOption + +var ( + mockDatabaseAdmin mockDatabaseAdminServer +) + +func TestMain(m *testing.M) { + flag.Parse() + + serv := grpc.NewServer() + databasepb.RegisterDatabaseAdminServer(serv, &mockDatabaseAdmin) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + go serv.Serve(lis) + + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + log.Fatal(err) + } + clientOpt = option.WithGRPCConn(conn) + + os.Exit(m.Run()) +} + +func TestDatabaseAdminListDatabases(t *testing.T) { + var nextPageToken string = "" + var databasesElement *databasepb.Database = &databasepb.Database{} + var databases = []*databasepb.Database{databasesElement} + var expectedResponse = &databasepb.ListDatabasesResponse{ + NextPageToken: nextPageToken, + Databases: databases, + } + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedParent string = DatabaseAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &databasepb.ListDatabasesRequest{ + Parent: formattedParent, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListDatabases(context.Background(), request).Next() + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + want := (interface{})(expectedResponse.Databases[0]) + got := (interface{})(resp) + var ok bool + + switch want := (want).(type) { + case proto.Message: + ok = proto.Equal(want, got.(proto.Message)) + default: + ok = want == got + } + if !ok { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminListDatabasesError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedParent string = DatabaseAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &databasepb.ListDatabasesRequest{ + Parent: formattedParent, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListDatabases(context.Background(), request).Next() + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminCreateDatabase(t *testing.T) { + var name string = "name3373707" + var expectedResponse = &databasepb.Database{ + Name: name, + } + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + any, err := ptypes.MarshalAny(expectedResponse) + if err != nil { + t.Fatal(err) + } + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Response{Response: any}, + }) + + var formattedParent string = DatabaseAdminInstancePath("[PROJECT]", "[INSTANCE]") + var createStatement string = "createStatement552974828" + var request = &databasepb.CreateDatabaseRequest{ + Parent: formattedParent, + CreateStatement: createStatement, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.CreateDatabase(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminCreateDatabaseError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &status.Status{ + Code: int32(errCode), + Message: "test error", + }, + }, + }) + + var formattedParent string = DatabaseAdminInstancePath("[PROJECT]", "[INSTANCE]") + var createStatement string = "createStatement552974828" + var request = &databasepb.CreateDatabaseRequest{ + Parent: formattedParent, + CreateStatement: createStatement, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.CreateDatabase(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminGetDatabase(t *testing.T) { + var name2 string = "name2-1052831874" + var expectedResponse = &databasepb.Database{ + Name: name2, + } + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedName string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.GetDatabaseRequest{ + Name: formattedName, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetDatabase(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminGetDatabaseError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedName string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.GetDatabaseRequest{ + Name: formattedName, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetDatabase(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminUpdateDatabaseDdl(t *testing.T) { + var expectedResponse *emptypb.Empty = &emptypb.Empty{} + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + any, err := ptypes.MarshalAny(expectedResponse) + if err != nil { + t.Fatal(err) + } + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Response{Response: any}, + }) + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var statements []string = nil + var request = &databasepb.UpdateDatabaseDdlRequest{ + Database: formattedDatabase, + Statements: statements, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.UpdateDatabaseDdl(context.Background(), request) + if err != nil { + t.Fatal(err) + } + err = respLRO.Wait(context.Background()) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + +} + +func TestDatabaseAdminUpdateDatabaseDdlError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &status.Status{ + Code: int32(errCode), + Message: "test error", + }, + }, + }) + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var statements []string = nil + var request = &databasepb.UpdateDatabaseDdlRequest{ + Database: formattedDatabase, + Statements: statements, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.UpdateDatabaseDdl(context.Background(), request) + if err != nil { + t.Fatal(err) + } + err = respLRO.Wait(context.Background()) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } +} +func TestDatabaseAdminDropDatabase(t *testing.T) { + var expectedResponse *emptypb.Empty = &emptypb.Empty{} + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.DropDatabaseRequest{ + Database: formattedDatabase, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + err = c.DropDatabase(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + +} + +func TestDatabaseAdminDropDatabaseError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.DropDatabaseRequest{ + Database: formattedDatabase, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + err = c.DropDatabase(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } +} +func TestDatabaseAdminGetDatabaseDdl(t *testing.T) { + var expectedResponse *databasepb.GetDatabaseDdlResponse = &databasepb.GetDatabaseDdlResponse{} + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.GetDatabaseDdlRequest{ + Database: formattedDatabase, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetDatabaseDdl(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminGetDatabaseDdlError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedDatabase string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &databasepb.GetDatabaseDdlRequest{ + Database: formattedDatabase, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetDatabaseDdl(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminSetIamPolicy(t *testing.T) { + var version int32 = 351608024 + var etag []byte = []byte("21") + var expectedResponse = &iampb.Policy{ + Version: version, + Etag: etag, + } + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var policy *iampb.Policy = &iampb.Policy{} + var request = &iampb.SetIamPolicyRequest{ + Resource: formattedResource, + Policy: policy, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.SetIamPolicy(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminSetIamPolicyError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var policy *iampb.Policy = &iampb.Policy{} + var request = &iampb.SetIamPolicyRequest{ + Resource: formattedResource, + Policy: policy, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.SetIamPolicy(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminGetIamPolicy(t *testing.T) { + var version int32 = 351608024 + var etag []byte = []byte("21") + var expectedResponse = &iampb.Policy{ + Version: version, + Etag: etag, + } + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &iampb.GetIamPolicyRequest{ + Resource: formattedResource, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetIamPolicy(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminGetIamPolicyError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &iampb.GetIamPolicyRequest{ + Resource: formattedResource, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetIamPolicy(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestDatabaseAdminTestIamPermissions(t *testing.T) { + var expectedResponse *iampb.TestIamPermissionsResponse = &iampb.TestIamPermissionsResponse{} + + mockDatabaseAdmin.err = nil + mockDatabaseAdmin.reqs = nil + + mockDatabaseAdmin.resps = append(mockDatabaseAdmin.resps[:0], expectedResponse) + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var permissions []string = nil + var request = &iampb.TestIamPermissionsRequest{ + Resource: formattedResource, + Permissions: permissions, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.TestIamPermissions(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockDatabaseAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestDatabaseAdminTestIamPermissionsError(t *testing.T) { + errCode := codes.PermissionDenied + mockDatabaseAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = DatabaseAdminDatabasePath("[PROJECT]", "[INSTANCE]", "[DATABASE]") + var permissions []string = nil + var request = &iampb.TestIamPermissionsRequest{ + Resource: formattedResource, + Permissions: permissions, + } + + c, err := NewDatabaseAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.TestIamPermissions(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} diff --git a/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/doc.go b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/doc.go new file mode 100644 index 000000000..b0c982123 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/doc.go @@ -0,0 +1,39 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +// Package instance is an experimental, auto-generated package for the +// instance API. +// +package instance // import "cloud.google.com/go/spanner/admin/instance/apiv1" + +import ( + "golang.org/x/net/context" + "google.golang.org/grpc/metadata" +) + +func insertXGoog(ctx context.Context, val []string) context.Context { + md, _ := metadata.FromOutgoingContext(ctx) + md = md.Copy() + md["x-goog-api-client"] = val + return metadata.NewOutgoingContext(ctx, md) +} + +func DefaultAuthScopes() []string { + return []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/spanner.admin", + } +} diff --git a/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client.go b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client.go new file mode 100644 index 000000000..138e813f2 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client.go @@ -0,0 +1,742 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package instance + +import ( + "math" + "time" + + "cloud.google.com/go/internal/version" + "cloud.google.com/go/longrunning" + lroauto "cloud.google.com/go/longrunning/autogen" + gax "github.com/googleapis/gax-go" + "golang.org/x/net/context" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/api/transport" + iampb "google.golang.org/genproto/googleapis/iam/v1" + longrunningpb "google.golang.org/genproto/googleapis/longrunning" + instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +var ( + instanceAdminProjectPathTemplate = gax.MustCompilePathTemplate("projects/{project}") + instanceAdminInstanceConfigPathTemplate = gax.MustCompilePathTemplate("projects/{project}/instanceConfigs/{instance_config}") + instanceAdminInstancePathTemplate = gax.MustCompilePathTemplate("projects/{project}/instances/{instance}") +) + +// InstanceAdminCallOptions contains the retry settings for each method of InstanceAdminClient. +type InstanceAdminCallOptions struct { + ListInstanceConfigs []gax.CallOption + GetInstanceConfig []gax.CallOption + ListInstances []gax.CallOption + GetInstance []gax.CallOption + CreateInstance []gax.CallOption + UpdateInstance []gax.CallOption + DeleteInstance []gax.CallOption + SetIamPolicy []gax.CallOption + GetIamPolicy []gax.CallOption + TestIamPermissions []gax.CallOption +} + +func defaultInstanceAdminClientOptions() []option.ClientOption { + return []option.ClientOption{ + option.WithEndpoint("spanner.googleapis.com:443"), + option.WithScopes(DefaultAuthScopes()...), + } +} + +func defaultInstanceAdminCallOptions() *InstanceAdminCallOptions { + retry := map[[2]string][]gax.CallOption{ + {"default", "idempotent"}: { + gax.WithRetry(func() gax.Retryer { + return gax.OnCodes([]codes.Code{ + codes.DeadlineExceeded, + codes.Unavailable, + }, gax.Backoff{ + Initial: 1000 * time.Millisecond, + Max: 32000 * time.Millisecond, + Multiplier: 1.3, + }) + }), + }, + {"default", "non_idempotent"}: { + gax.WithRetry(func() gax.Retryer { + return gax.OnCodes([]codes.Code{ + codes.Unavailable, + }, gax.Backoff{ + Initial: 1000 * time.Millisecond, + Max: 32000 * time.Millisecond, + Multiplier: 1.3, + }) + }), + }, + } + return &InstanceAdminCallOptions{ + ListInstanceConfigs: retry[[2]string{"default", "idempotent"}], + GetInstanceConfig: retry[[2]string{"default", "idempotent"}], + ListInstances: retry[[2]string{"default", "idempotent"}], + GetInstance: retry[[2]string{"default", "idempotent"}], + CreateInstance: retry[[2]string{"default", "non_idempotent"}], + UpdateInstance: retry[[2]string{"default", "non_idempotent"}], + DeleteInstance: retry[[2]string{"default", "idempotent"}], + SetIamPolicy: retry[[2]string{"default", "non_idempotent"}], + GetIamPolicy: retry[[2]string{"default", "idempotent"}], + TestIamPermissions: retry[[2]string{"default", "non_idempotent"}], + } +} + +// InstanceAdminClient is a client for interacting with Cloud Spanner Instance Admin API. +type InstanceAdminClient struct { + // The connection to the service. + conn *grpc.ClientConn + + // The gRPC API client. + instanceAdminClient instancepb.InstanceAdminClient + + // LROClient is used internally to handle longrunning operations. + // It is exposed so that its CallOptions can be modified if required. + // Users should not Close this client. + LROClient *lroauto.OperationsClient + + // The call options for this service. + CallOptions *InstanceAdminCallOptions + + // The metadata to be sent with each request. + xGoogHeader []string +} + +// NewInstanceAdminClient creates a new instance admin client. +// +// Cloud Spanner Instance Admin API +// +// The Cloud Spanner Instance Admin API can be used to create, delete, +// modify and list instances. Instances are dedicated Cloud Spanner serving +// and storage resources to be used by Cloud Spanner databases. +// +// Each instance has a "configuration", which dictates where the +// serving resources for the Cloud Spanner instance are located (e.g., +// US-central, Europe). Configurations are created by Google based on +// resource availability. +// +// Cloud Spanner billing is based on the instances that exist and their +// sizes. After an instance exists, there are no additional +// per-database or per-operation charges for use of the instance +// (though there may be additional network bandwidth charges). +// Instances offer isolation: problems with databases in one instance +// will not affect other instances. However, within an instance +// databases can affect each other. For example, if one database in an +// instance receives a lot of requests and consumes most of the +// instance resources, fewer resources are available for other +// databases in that instance, and their performance may suffer. +func NewInstanceAdminClient(ctx context.Context, opts ...option.ClientOption) (*InstanceAdminClient, error) { + conn, err := transport.DialGRPC(ctx, append(defaultInstanceAdminClientOptions(), opts...)...) + if err != nil { + return nil, err + } + c := &InstanceAdminClient{ + conn: conn, + CallOptions: defaultInstanceAdminCallOptions(), + + instanceAdminClient: instancepb.NewInstanceAdminClient(conn), + } + c.SetGoogleClientInfo() + + c.LROClient, err = lroauto.NewOperationsClient(ctx, option.WithGRPCConn(conn)) + if err != nil { + // This error "should not happen", since we are just reusing old connection + // and never actually need to dial. + // If this does happen, we could leak conn. However, we cannot close conn: + // If the user invoked the function with option.WithGRPCConn, + // we would close a connection that's still in use. + // TODO(pongad): investigate error conditions. + return nil, err + } + return c, nil +} + +// Connection returns the client's connection to the API service. +func (c *InstanceAdminClient) Connection() *grpc.ClientConn { + return c.conn +} + +// Close closes the connection to the API service. The user should invoke this when +// the client is no longer required. +func (c *InstanceAdminClient) Close() error { + return c.conn.Close() +} + +// SetGoogleClientInfo sets the name and version of the application in +// the `x-goog-api-client` header passed on each request. Intended for +// use by Google-written clients. +func (c *InstanceAdminClient) SetGoogleClientInfo(keyval ...string) { + kv := append([]string{"gl-go", version.Go()}, keyval...) + kv = append(kv, "gapic", version.Repo, "gax", gax.Version, "grpc", grpc.Version) + c.xGoogHeader = []string{gax.XGoogHeader(kv...)} +} + +// InstanceAdminProjectPath returns the path for the project resource. +func InstanceAdminProjectPath(project string) string { + path, err := instanceAdminProjectPathTemplate.Render(map[string]string{ + "project": project, + }) + if err != nil { + panic(err) + } + return path +} + +// InstanceAdminInstanceConfigPath returns the path for the instance config resource. +func InstanceAdminInstanceConfigPath(project, instanceConfig string) string { + path, err := instanceAdminInstanceConfigPathTemplate.Render(map[string]string{ + "project": project, + "instance_config": instanceConfig, + }) + if err != nil { + panic(err) + } + return path +} + +// InstanceAdminInstancePath returns the path for the instance resource. +func InstanceAdminInstancePath(project, instance string) string { + path, err := instanceAdminInstancePathTemplate.Render(map[string]string{ + "project": project, + "instance": instance, + }) + if err != nil { + panic(err) + } + return path +} + +// ListInstanceConfigs lists the supported instance configurations for a given project. +func (c *InstanceAdminClient) ListInstanceConfigs(ctx context.Context, req *instancepb.ListInstanceConfigsRequest, opts ...gax.CallOption) *InstanceConfigIterator { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.ListInstanceConfigs[0:len(c.CallOptions.ListInstanceConfigs):len(c.CallOptions.ListInstanceConfigs)], opts...) + it := &InstanceConfigIterator{} + it.InternalFetch = func(pageSize int, pageToken string) ([]*instancepb.InstanceConfig, string, error) { + var resp *instancepb.ListInstanceConfigsResponse + req.PageToken = pageToken + if pageSize > math.MaxInt32 { + req.PageSize = math.MaxInt32 + } else { + req.PageSize = int32(pageSize) + } + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.ListInstanceConfigs(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, "", err + } + return resp.InstanceConfigs, resp.NextPageToken, nil + } + fetch := func(pageSize int, pageToken string) (string, error) { + items, nextPageToken, err := it.InternalFetch(pageSize, pageToken) + if err != nil { + return "", err + } + it.items = append(it.items, items...) + return nextPageToken, nil + } + it.pageInfo, it.nextFunc = iterator.NewPageInfo(fetch, it.bufLen, it.takeBuf) + return it +} + +// GetInstanceConfig gets information about a particular instance configuration. +func (c *InstanceAdminClient) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetInstanceConfig[0:len(c.CallOptions.GetInstanceConfig):len(c.CallOptions.GetInstanceConfig)], opts...) + var resp *instancepb.InstanceConfig + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.GetInstanceConfig(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// ListInstances lists all instances in the given project. +func (c *InstanceAdminClient) ListInstances(ctx context.Context, req *instancepb.ListInstancesRequest, opts ...gax.CallOption) *InstanceIterator { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.ListInstances[0:len(c.CallOptions.ListInstances):len(c.CallOptions.ListInstances)], opts...) + it := &InstanceIterator{} + it.InternalFetch = func(pageSize int, pageToken string) ([]*instancepb.Instance, string, error) { + var resp *instancepb.ListInstancesResponse + req.PageToken = pageToken + if pageSize > math.MaxInt32 { + req.PageSize = math.MaxInt32 + } else { + req.PageSize = int32(pageSize) + } + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.ListInstances(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, "", err + } + return resp.Instances, resp.NextPageToken, nil + } + fetch := func(pageSize int, pageToken string) (string, error) { + items, nextPageToken, err := it.InternalFetch(pageSize, pageToken) + if err != nil { + return "", err + } + it.items = append(it.items, items...) + return nextPageToken, nil + } + it.pageInfo, it.nextFunc = iterator.NewPageInfo(fetch, it.bufLen, it.takeBuf) + return it +} + +// GetInstance gets information about a particular instance. +func (c *InstanceAdminClient) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetInstance[0:len(c.CallOptions.GetInstance):len(c.CallOptions.GetInstance)], opts...) + var resp *instancepb.Instance + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.GetInstance(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// CreateInstance creates an instance and begins preparing it to begin serving. The +// returned [long-running operation][google.longrunning.Operation] +// can be used to track the progress of preparing the new +// instance. The instance name is assigned by the caller. If the +// named instance already exists, `CreateInstance` returns +// `ALREADY_EXISTS`. +// +// Immediately upon completion of this request: +// +// * The instance is readable via the API, with all requested attributes +// but no allocated resources. Its state is `CREATING`. +// +// Until completion of the returned operation: +// +// * Cancelling the operation renders the instance immediately unreadable +// via the API. +// * The instance can be deleted. +// * All other attempts to modify the instance are rejected. +// +// Upon completion of the returned operation: +// +// * Billing for all successfully-allocated resources begins (some types +// may have lower than the requested levels). +// * Databases can be created in the instance. +// * The instance's allocated resource levels are readable via the API. +// * The instance's state becomes `READY`. +// +// The returned [long-running operation][google.longrunning.Operation] will +// have a name of the format `<instance_name>/operations/<operation_id>` and +// can be used to track creation of the instance. The +// [metadata][google.longrunning.Operation.metadata] field type is +// [CreateInstanceMetadata][google.spanner.admin.instance.v1.CreateInstanceMetadata]. +// The [response][google.longrunning.Operation.response] field type is +// [Instance][google.spanner.admin.instance.v1.Instance], if successful. +func (c *InstanceAdminClient) CreateInstance(ctx context.Context, req *instancepb.CreateInstanceRequest, opts ...gax.CallOption) (*CreateInstanceOperation, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.CreateInstance[0:len(c.CallOptions.CreateInstance):len(c.CallOptions.CreateInstance)], opts...) + var resp *longrunningpb.Operation + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.CreateInstance(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return &CreateInstanceOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, resp), + }, nil +} + +// UpdateInstance updates an instance, and begins allocating or releasing resources +// as requested. The returned [long-running +// operation][google.longrunning.Operation] can be used to track the +// progress of updating the instance. If the named instance does not +// exist, returns `NOT_FOUND`. +// +// Immediately upon completion of this request: +// +// * For resource types for which a decrease in the instance's allocation +// has been requested, billing is based on the newly-requested level. +// +// Until completion of the returned operation: +// +// * Cancelling the operation sets its metadata's +// [cancel_time][google.spanner.admin.instance.v1.UpdateInstanceMetadata.cancel_time], and begins +// restoring resources to their pre-request values. The operation +// is guaranteed to succeed at undoing all resource changes, +// after which point it terminates with a `CANCELLED` status. +// * All other attempts to modify the instance are rejected. +// * Reading the instance via the API continues to give the pre-request +// resource levels. +// +// Upon completion of the returned operation: +// +// * Billing begins for all successfully-allocated resources (some types +// may have lower than the requested levels). +// * All newly-reserved resources are available for serving the instance's +// tables. +// * The instance's new resource levels are readable via the API. +// +// The returned [long-running operation][google.longrunning.Operation] will +// have a name of the format `<instance_name>/operations/<operation_id>` and +// can be used to track the instance modification. The +// [metadata][google.longrunning.Operation.metadata] field type is +// [UpdateInstanceMetadata][google.spanner.admin.instance.v1.UpdateInstanceMetadata]. +// The [response][google.longrunning.Operation.response] field type is +// [Instance][google.spanner.admin.instance.v1.Instance], if successful. +// +// Authorization requires `spanner.instances.update` permission on +// resource [name][google.spanner.admin.instance.v1.Instance.name]. +func (c *InstanceAdminClient) UpdateInstance(ctx context.Context, req *instancepb.UpdateInstanceRequest, opts ...gax.CallOption) (*UpdateInstanceOperation, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.UpdateInstance[0:len(c.CallOptions.UpdateInstance):len(c.CallOptions.UpdateInstance)], opts...) + var resp *longrunningpb.Operation + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.UpdateInstance(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return &UpdateInstanceOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, resp), + }, nil +} + +// DeleteInstance deletes an instance. +// +// Immediately upon completion of the request: +// +// * Billing ceases for all of the instance's reserved resources. +// +// Soon afterward: +// +// * The instance and *all of its databases* immediately and +// irrevocably disappear from the API. All data in the databases +// is permanently deleted. +func (c *InstanceAdminClient) DeleteInstance(ctx context.Context, req *instancepb.DeleteInstanceRequest, opts ...gax.CallOption) error { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.DeleteInstance[0:len(c.CallOptions.DeleteInstance):len(c.CallOptions.DeleteInstance)], opts...) + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + _, err = c.instanceAdminClient.DeleteInstance(ctx, req, settings.GRPC...) + return err + }, opts...) + return err +} + +// SetIamPolicy sets the access control policy on an instance resource. Replaces any +// existing policy. +// +// Authorization requires `spanner.instances.setIamPolicy` on +// [resource][google.iam.v1.SetIamPolicyRequest.resource]. +func (c *InstanceAdminClient) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.SetIamPolicy[0:len(c.CallOptions.SetIamPolicy):len(c.CallOptions.SetIamPolicy)], opts...) + var resp *iampb.Policy + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.SetIamPolicy(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// GetIamPolicy gets the access control policy for an instance resource. Returns an empty +// policy if an instance exists but does not have a policy set. +// +// Authorization requires `spanner.instances.getIamPolicy` on +// [resource][google.iam.v1.GetIamPolicyRequest.resource]. +func (c *InstanceAdminClient) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.GetIamPolicy[0:len(c.CallOptions.GetIamPolicy):len(c.CallOptions.GetIamPolicy)], opts...) + var resp *iampb.Policy + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.GetIamPolicy(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// TestIamPermissions returns permissions that the caller has on the specified instance resource. +// +// Attempting this RPC on a non-existent Cloud Spanner instance resource will +// result in a NOT_FOUND error if the user has `spanner.instances.list` +// permission on the containing Google Cloud Project. Otherwise returns an +// empty set of permissions. +func (c *InstanceAdminClient) TestIamPermissions(ctx context.Context, req *iampb.TestIamPermissionsRequest, opts ...gax.CallOption) (*iampb.TestIamPermissionsResponse, error) { + ctx = insertXGoog(ctx, c.xGoogHeader) + opts = append(c.CallOptions.TestIamPermissions[0:len(c.CallOptions.TestIamPermissions):len(c.CallOptions.TestIamPermissions)], opts...) + var resp *iampb.TestIamPermissionsResponse + err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { + var err error + resp, err = c.instanceAdminClient.TestIamPermissions(ctx, req, settings.GRPC...) + return err + }, opts...) + if err != nil { + return nil, err + } + return resp, nil +} + +// InstanceConfigIterator manages a stream of *instancepb.InstanceConfig. +type InstanceConfigIterator struct { + items []*instancepb.InstanceConfig + pageInfo *iterator.PageInfo + nextFunc func() error + + // InternalFetch is for use by the Google Cloud Libraries only. + // It is not part of the stable interface of this package. + // + // InternalFetch returns results from a single call to the underlying RPC. + // The number of results is no greater than pageSize. + // If there are no more results, nextPageToken is empty and err is nil. + InternalFetch func(pageSize int, pageToken string) (results []*instancepb.InstanceConfig, nextPageToken string, err error) +} + +// PageInfo supports pagination. See the google.golang.org/api/iterator package for details. +func (it *InstanceConfigIterator) PageInfo() *iterator.PageInfo { + return it.pageInfo +} + +// Next returns the next result. Its second return value is iterator.Done if there are no more +// results. Once Next returns Done, all subsequent calls will return Done. +func (it *InstanceConfigIterator) Next() (*instancepb.InstanceConfig, error) { + var item *instancepb.InstanceConfig + if err := it.nextFunc(); err != nil { + return item, err + } + item = it.items[0] + it.items = it.items[1:] + return item, nil +} + +func (it *InstanceConfigIterator) bufLen() int { + return len(it.items) +} + +func (it *InstanceConfigIterator) takeBuf() interface{} { + b := it.items + it.items = nil + return b +} + +// InstanceIterator manages a stream of *instancepb.Instance. +type InstanceIterator struct { + items []*instancepb.Instance + pageInfo *iterator.PageInfo + nextFunc func() error + + // InternalFetch is for use by the Google Cloud Libraries only. + // It is not part of the stable interface of this package. + // + // InternalFetch returns results from a single call to the underlying RPC. + // The number of results is no greater than pageSize. + // If there are no more results, nextPageToken is empty and err is nil. + InternalFetch func(pageSize int, pageToken string) (results []*instancepb.Instance, nextPageToken string, err error) +} + +// PageInfo supports pagination. See the google.golang.org/api/iterator package for details. +func (it *InstanceIterator) PageInfo() *iterator.PageInfo { + return it.pageInfo +} + +// Next returns the next result. Its second return value is iterator.Done if there are no more +// results. Once Next returns Done, all subsequent calls will return Done. +func (it *InstanceIterator) Next() (*instancepb.Instance, error) { + var item *instancepb.Instance + if err := it.nextFunc(); err != nil { + return item, err + } + item = it.items[0] + it.items = it.items[1:] + return item, nil +} + +func (it *InstanceIterator) bufLen() int { + return len(it.items) +} + +func (it *InstanceIterator) takeBuf() interface{} { + b := it.items + it.items = nil + return b +} + +// CreateInstanceOperation manages a long-running operation from CreateInstance. +type CreateInstanceOperation struct { + lro *longrunning.Operation +} + +// CreateInstanceOperation returns a new CreateInstanceOperation from a given name. +// The name must be that of a previously created CreateInstanceOperation, possibly from a different process. +func (c *InstanceAdminClient) CreateInstanceOperation(name string) *CreateInstanceOperation { + return &CreateInstanceOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, &longrunningpb.Operation{Name: name}), + } +} + +// Wait blocks until the long-running operation is completed, returning the response and any errors encountered. +// +// See documentation of Poll for error-handling information. +func (op *CreateInstanceOperation) Wait(ctx context.Context, opts ...gax.CallOption) (*instancepb.Instance, error) { + var resp instancepb.Instance + if err := op.lro.Wait(ctx, &resp, opts...); err != nil { + return nil, err + } + return &resp, nil +} + +// Poll fetches the latest state of the long-running operation. +// +// Poll also fetches the latest metadata, which can be retrieved by Metadata. +// +// If Poll fails, the error is returned and op is unmodified. If Poll succeeds and +// the operation has completed with failure, the error is returned and op.Done will return true. +// If Poll succeeds and the operation has completed successfully, +// op.Done will return true, and the response of the operation is returned. +// If Poll succeeds and the operation has not completed, the returned response and error are both nil. +func (op *CreateInstanceOperation) Poll(ctx context.Context, opts ...gax.CallOption) (*instancepb.Instance, error) { + var resp instancepb.Instance + if err := op.lro.Poll(ctx, &resp, opts...); err != nil { + return nil, err + } + if !op.Done() { + return nil, nil + } + return &resp, nil +} + +// Metadata returns metadata associated with the long-running operation. +// Metadata itself does not contact the server, but Poll does. +// To get the latest metadata, call this method after a successful call to Poll. +// If the metadata is not available, the returned metadata and error are both nil. +func (op *CreateInstanceOperation) Metadata() (*instancepb.CreateInstanceMetadata, error) { + var meta instancepb.CreateInstanceMetadata + if err := op.lro.Metadata(&meta); err == longrunning.ErrNoMetadata { + return nil, nil + } else if err != nil { + return nil, err + } + return &meta, nil +} + +// Done reports whether the long-running operation has completed. +func (op *CreateInstanceOperation) Done() bool { + return op.lro.Done() +} + +// Name returns the name of the long-running operation. +// The name is assigned by the server and is unique within the service from which the operation is created. +func (op *CreateInstanceOperation) Name() string { + return op.lro.Name() +} + +// UpdateInstanceOperation manages a long-running operation from UpdateInstance. +type UpdateInstanceOperation struct { + lro *longrunning.Operation +} + +// UpdateInstanceOperation returns a new UpdateInstanceOperation from a given name. +// The name must be that of a previously created UpdateInstanceOperation, possibly from a different process. +func (c *InstanceAdminClient) UpdateInstanceOperation(name string) *UpdateInstanceOperation { + return &UpdateInstanceOperation{ + lro: longrunning.InternalNewOperation(c.LROClient, &longrunningpb.Operation{Name: name}), + } +} + +// Wait blocks until the long-running operation is completed, returning the response and any errors encountered. +// +// See documentation of Poll for error-handling information. +func (op *UpdateInstanceOperation) Wait(ctx context.Context, opts ...gax.CallOption) (*instancepb.Instance, error) { + var resp instancepb.Instance + if err := op.lro.Wait(ctx, &resp, opts...); err != nil { + return nil, err + } + return &resp, nil +} + +// Poll fetches the latest state of the long-running operation. +// +// Poll also fetches the latest metadata, which can be retrieved by Metadata. +// +// If Poll fails, the error is returned and op is unmodified. If Poll succeeds and +// the operation has completed with failure, the error is returned and op.Done will return true. +// If Poll succeeds and the operation has completed successfully, +// op.Done will return true, and the response of the operation is returned. +// If Poll succeeds and the operation has not completed, the returned response and error are both nil. +func (op *UpdateInstanceOperation) Poll(ctx context.Context, opts ...gax.CallOption) (*instancepb.Instance, error) { + var resp instancepb.Instance + if err := op.lro.Poll(ctx, &resp, opts...); err != nil { + return nil, err + } + if !op.Done() { + return nil, nil + } + return &resp, nil +} + +// Metadata returns metadata associated with the long-running operation. +// Metadata itself does not contact the server, but Poll does. +// To get the latest metadata, call this method after a successful call to Poll. +// If the metadata is not available, the returned metadata and error are both nil. +func (op *UpdateInstanceOperation) Metadata() (*instancepb.UpdateInstanceMetadata, error) { + var meta instancepb.UpdateInstanceMetadata + if err := op.lro.Metadata(&meta); err == longrunning.ErrNoMetadata { + return nil, nil + } else if err != nil { + return nil, err + } + return &meta, nil +} + +// Done reports whether the long-running operation has completed. +func (op *UpdateInstanceOperation) Done() bool { + return op.lro.Done() +} + +// Name returns the name of the long-running operation. +// The name is assigned by the server and is unique within the service from which the operation is created. +func (op *UpdateInstanceOperation) Name() string { + return op.lro.Name() +} diff --git a/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client_example_test.go b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client_example_test.go new file mode 100644 index 000000000..ee807fdbc --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/instance_admin_client_example_test.go @@ -0,0 +1,230 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package instance_test + +import ( + "cloud.google.com/go/spanner/admin/instance/apiv1" + "golang.org/x/net/context" + iampb "google.golang.org/genproto/googleapis/iam/v1" + instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" +) + +func ExampleNewInstanceAdminClient() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + // TODO: Use client. + _ = c +} + +func ExampleInstanceAdminClient_ListInstanceConfigs() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.ListInstanceConfigsRequest{ + // TODO: Fill request struct fields. + } + it := c.ListInstanceConfigs(ctx, req) + for { + resp, err := it.Next() + if err != nil { + // TODO: Handle error. + break + } + // TODO: Use resp. + _ = resp + } +} + +func ExampleInstanceAdminClient_GetInstanceConfig() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.GetInstanceConfigRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetInstanceConfig(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_ListInstances() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.ListInstancesRequest{ + // TODO: Fill request struct fields. + } + it := c.ListInstances(ctx, req) + for { + resp, err := it.Next() + if err != nil { + // TODO: Handle error. + break + } + // TODO: Use resp. + _ = resp + } +} + +func ExampleInstanceAdminClient_GetInstance() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.GetInstanceRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetInstance(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_CreateInstance() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.CreateInstanceRequest{ + // TODO: Fill request struct fields. + } + op, err := c.CreateInstance(ctx, req) + if err != nil { + // TODO: Handle error. + } + + resp, err := op.Wait(ctx) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_UpdateInstance() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.UpdateInstanceRequest{ + // TODO: Fill request struct fields. + } + op, err := c.UpdateInstance(ctx, req) + if err != nil { + // TODO: Handle error. + } + + resp, err := op.Wait(ctx) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_DeleteInstance() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &instancepb.DeleteInstanceRequest{ + // TODO: Fill request struct fields. + } + err = c.DeleteInstance(ctx, req) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleInstanceAdminClient_SetIamPolicy() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.SetIamPolicyRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.SetIamPolicy(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_GetIamPolicy() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.GetIamPolicyRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.GetIamPolicy(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} + +func ExampleInstanceAdminClient_TestIamPermissions() { + ctx := context.Background() + c, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + // TODO: Handle error. + } + + req := &iampb.TestIamPermissionsRequest{ + // TODO: Fill request struct fields. + } + resp, err := c.TestIamPermissions(ctx, req) + if err != nil { + // TODO: Handle error. + } + // TODO: Use resp. + _ = resp +} diff --git a/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/mock_test.go b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/mock_test.go new file mode 100644 index 000000000..8728d0b25 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/admin/instance/apiv1/mock_test.go @@ -0,0 +1,896 @@ +// Copyright 2017, Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AUTO-GENERATED CODE. DO NOT EDIT. + +package instance + +import ( + emptypb "github.com/golang/protobuf/ptypes/empty" + iampb "google.golang.org/genproto/googleapis/iam/v1" + longrunningpb "google.golang.org/genproto/googleapis/longrunning" + instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" + field_maskpb "google.golang.org/genproto/protobuf/field_mask" +) + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "golang.org/x/net/context" + "google.golang.org/api/option" + status "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +var _ = io.EOF +var _ = ptypes.MarshalAny +var _ status.Status + +type mockInstanceAdminServer struct { + // Embed for forward compatibility. + // Tests will keep working if more methods are added + // in the future. + instancepb.InstanceAdminServer + + reqs []proto.Message + + // If set, all calls return this error. + err error + + // responses to return if err == nil + resps []proto.Message +} + +func (s *mockInstanceAdminServer) ListInstanceConfigs(ctx context.Context, req *instancepb.ListInstanceConfigsRequest) (*instancepb.ListInstanceConfigsResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*instancepb.ListInstanceConfigsResponse), nil +} + +func (s *mockInstanceAdminServer) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest) (*instancepb.InstanceConfig, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*instancepb.InstanceConfig), nil +} + +func (s *mockInstanceAdminServer) ListInstances(ctx context.Context, req *instancepb.ListInstancesRequest) (*instancepb.ListInstancesResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*instancepb.ListInstancesResponse), nil +} + +func (s *mockInstanceAdminServer) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest) (*instancepb.Instance, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*instancepb.Instance), nil +} + +func (s *mockInstanceAdminServer) CreateInstance(ctx context.Context, req *instancepb.CreateInstanceRequest) (*longrunningpb.Operation, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*longrunningpb.Operation), nil +} + +func (s *mockInstanceAdminServer) UpdateInstance(ctx context.Context, req *instancepb.UpdateInstanceRequest) (*longrunningpb.Operation, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*longrunningpb.Operation), nil +} + +func (s *mockInstanceAdminServer) DeleteInstance(ctx context.Context, req *instancepb.DeleteInstanceRequest) (*emptypb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*emptypb.Empty), nil +} + +func (s *mockInstanceAdminServer) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest) (*iampb.Policy, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.Policy), nil +} + +func (s *mockInstanceAdminServer) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest) (*iampb.Policy, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.Policy), nil +} + +func (s *mockInstanceAdminServer) TestIamPermissions(ctx context.Context, req *iampb.TestIamPermissionsRequest) (*iampb.TestIamPermissionsResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + s.reqs = append(s.reqs, req) + if s.err != nil { + return nil, s.err + } + return s.resps[0].(*iampb.TestIamPermissionsResponse), nil +} + +// clientOpt is the option tests should use to connect to the test server. +// It is initialized by TestMain. +var clientOpt option.ClientOption + +var ( + mockInstanceAdmin mockInstanceAdminServer +) + +func TestMain(m *testing.M) { + flag.Parse() + + serv := grpc.NewServer() + instancepb.RegisterInstanceAdminServer(serv, &mockInstanceAdmin) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + go serv.Serve(lis) + + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + log.Fatal(err) + } + clientOpt = option.WithGRPCConn(conn) + + os.Exit(m.Run()) +} + +func TestInstanceAdminListInstanceConfigs(t *testing.T) { + var nextPageToken string = "" + var instanceConfigsElement *instancepb.InstanceConfig = &instancepb.InstanceConfig{} + var instanceConfigs = []*instancepb.InstanceConfig{instanceConfigsElement} + var expectedResponse = &instancepb.ListInstanceConfigsResponse{ + NextPageToken: nextPageToken, + InstanceConfigs: instanceConfigs, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var request = &instancepb.ListInstanceConfigsRequest{ + Parent: formattedParent, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListInstanceConfigs(context.Background(), request).Next() + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + want := (interface{})(expectedResponse.InstanceConfigs[0]) + got := (interface{})(resp) + var ok bool + + switch want := (want).(type) { + case proto.Message: + ok = proto.Equal(want, got.(proto.Message)) + default: + ok = want == got + } + if !ok { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminListInstanceConfigsError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var request = &instancepb.ListInstanceConfigsRequest{ + Parent: formattedParent, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListInstanceConfigs(context.Background(), request).Next() + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminGetInstanceConfig(t *testing.T) { + var name2 string = "name2-1052831874" + var displayName string = "displayName1615086568" + var expectedResponse = &instancepb.InstanceConfig{ + Name: name2, + DisplayName: displayName, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedName string = InstanceAdminInstanceConfigPath("[PROJECT]", "[INSTANCE_CONFIG]") + var request = &instancepb.GetInstanceConfigRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetInstanceConfig(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminGetInstanceConfigError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedName string = InstanceAdminInstanceConfigPath("[PROJECT]", "[INSTANCE_CONFIG]") + var request = &instancepb.GetInstanceConfigRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetInstanceConfig(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminListInstances(t *testing.T) { + var nextPageToken string = "" + var instancesElement *instancepb.Instance = &instancepb.Instance{} + var instances = []*instancepb.Instance{instancesElement} + var expectedResponse = &instancepb.ListInstancesResponse{ + NextPageToken: nextPageToken, + Instances: instances, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var request = &instancepb.ListInstancesRequest{ + Parent: formattedParent, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListInstances(context.Background(), request).Next() + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + want := (interface{})(expectedResponse.Instances[0]) + got := (interface{})(resp) + var ok bool + + switch want := (want).(type) { + case proto.Message: + ok = proto.Equal(want, got.(proto.Message)) + default: + ok = want == got + } + if !ok { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminListInstancesError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var request = &instancepb.ListInstancesRequest{ + Parent: formattedParent, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.ListInstances(context.Background(), request).Next() + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminGetInstance(t *testing.T) { + var name2 string = "name2-1052831874" + var config string = "config-1354792126" + var displayName string = "displayName1615086568" + var nodeCount int32 = 1539922066 + var expectedResponse = &instancepb.Instance{ + Name: name2, + Config: config, + DisplayName: displayName, + NodeCount: nodeCount, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedName string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &instancepb.GetInstanceRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetInstance(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminGetInstanceError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedName string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &instancepb.GetInstanceRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetInstance(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminCreateInstance(t *testing.T) { + var name string = "name3373707" + var config string = "config-1354792126" + var displayName string = "displayName1615086568" + var nodeCount int32 = 1539922066 + var expectedResponse = &instancepb.Instance{ + Name: name, + Config: config, + DisplayName: displayName, + NodeCount: nodeCount, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + any, err := ptypes.MarshalAny(expectedResponse) + if err != nil { + t.Fatal(err) + } + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Response{Response: any}, + }) + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var instanceId string = "instanceId-2101995259" + var instance *instancepb.Instance = &instancepb.Instance{} + var request = &instancepb.CreateInstanceRequest{ + Parent: formattedParent, + InstanceId: instanceId, + Instance: instance, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.CreateInstance(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminCreateInstanceError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = nil + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &status.Status{ + Code: int32(errCode), + Message: "test error", + }, + }, + }) + + var formattedParent string = InstanceAdminProjectPath("[PROJECT]") + var instanceId string = "instanceId-2101995259" + var instance *instancepb.Instance = &instancepb.Instance{} + var request = &instancepb.CreateInstanceRequest{ + Parent: formattedParent, + InstanceId: instanceId, + Instance: instance, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.CreateInstance(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminUpdateInstance(t *testing.T) { + var name string = "name3373707" + var config string = "config-1354792126" + var displayName string = "displayName1615086568" + var nodeCount int32 = 1539922066 + var expectedResponse = &instancepb.Instance{ + Name: name, + Config: config, + DisplayName: displayName, + NodeCount: nodeCount, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + any, err := ptypes.MarshalAny(expectedResponse) + if err != nil { + t.Fatal(err) + } + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Response{Response: any}, + }) + + var instance *instancepb.Instance = &instancepb.Instance{} + var fieldMask *field_maskpb.FieldMask = &field_maskpb.FieldMask{} + var request = &instancepb.UpdateInstanceRequest{ + Instance: instance, + FieldMask: fieldMask, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.UpdateInstance(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminUpdateInstanceError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = nil + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], &longrunningpb.Operation{ + Name: "longrunning-test", + Done: true, + Result: &longrunningpb.Operation_Error{ + Error: &status.Status{ + Code: int32(errCode), + Message: "test error", + }, + }, + }) + + var instance *instancepb.Instance = &instancepb.Instance{} + var fieldMask *field_maskpb.FieldMask = &field_maskpb.FieldMask{} + var request = &instancepb.UpdateInstanceRequest{ + Instance: instance, + FieldMask: fieldMask, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + respLRO, err := c.UpdateInstance(context.Background(), request) + if err != nil { + t.Fatal(err) + } + resp, err := respLRO.Wait(context.Background()) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminDeleteInstance(t *testing.T) { + var expectedResponse *emptypb.Empty = &emptypb.Empty{} + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedName string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &instancepb.DeleteInstanceRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + err = c.DeleteInstance(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + +} + +func TestInstanceAdminDeleteInstanceError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedName string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &instancepb.DeleteInstanceRequest{ + Name: formattedName, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + err = c.DeleteInstance(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } +} +func TestInstanceAdminSetIamPolicy(t *testing.T) { + var version int32 = 351608024 + var etag []byte = []byte("21") + var expectedResponse = &iampb.Policy{ + Version: version, + Etag: etag, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var policy *iampb.Policy = &iampb.Policy{} + var request = &iampb.SetIamPolicyRequest{ + Resource: formattedResource, + Policy: policy, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.SetIamPolicy(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminSetIamPolicyError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var policy *iampb.Policy = &iampb.Policy{} + var request = &iampb.SetIamPolicyRequest{ + Resource: formattedResource, + Policy: policy, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.SetIamPolicy(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminGetIamPolicy(t *testing.T) { + var version int32 = 351608024 + var etag []byte = []byte("21") + var expectedResponse = &iampb.Policy{ + Version: version, + Etag: etag, + } + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &iampb.GetIamPolicyRequest{ + Resource: formattedResource, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetIamPolicy(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminGetIamPolicyError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var request = &iampb.GetIamPolicyRequest{ + Resource: formattedResource, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.GetIamPolicy(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} +func TestInstanceAdminTestIamPermissions(t *testing.T) { + var expectedResponse *iampb.TestIamPermissionsResponse = &iampb.TestIamPermissionsResponse{} + + mockInstanceAdmin.err = nil + mockInstanceAdmin.reqs = nil + + mockInstanceAdmin.resps = append(mockInstanceAdmin.resps[:0], expectedResponse) + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var permissions []string = nil + var request = &iampb.TestIamPermissionsRequest{ + Resource: formattedResource, + Permissions: permissions, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.TestIamPermissions(context.Background(), request) + + if err != nil { + t.Fatal(err) + } + + if want, got := request, mockInstanceAdmin.reqs[0]; !proto.Equal(want, got) { + t.Errorf("wrong request %q, want %q", got, want) + } + + if want, got := expectedResponse, resp; !proto.Equal(want, got) { + t.Errorf("wrong response %q, want %q)", got, want) + } +} + +func TestInstanceAdminTestIamPermissionsError(t *testing.T) { + errCode := codes.PermissionDenied + mockInstanceAdmin.err = grpc.Errorf(errCode, "test error") + + var formattedResource string = InstanceAdminInstancePath("[PROJECT]", "[INSTANCE]") + var permissions []string = nil + var request = &iampb.TestIamPermissionsRequest{ + Resource: formattedResource, + Permissions: permissions, + } + + c, err := NewInstanceAdminClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + resp, err := c.TestIamPermissions(context.Background(), request) + + if c := grpc.Code(err); c != errCode { + t.Errorf("got error code %q, want %q", c, errCode) + } + _ = resp +} diff --git a/vendor/cloud.google.com/go/spanner/backoff.go b/vendor/cloud.google.com/go/spanner/backoff.go new file mode 100644 index 000000000..d38723843 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/backoff.go @@ -0,0 +1,58 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "math/rand" + "time" +) + +const ( + // minBackoff is the minimum backoff used by default. + minBackoff = 1 * time.Second + // maxBackoff is the maximum backoff used by default. + maxBackoff = 32 * time.Second + // jitter is the jitter factor. + jitter = 0.4 + // rate is the rate of exponential increase in the backoff. + rate = 1.3 +) + +var defaultBackoff = exponentialBackoff{minBackoff, maxBackoff} + +type exponentialBackoff struct { + min, max time.Duration +} + +// delay calculates the delay that should happen at n-th +// exponential backoff in a series. +func (b exponentialBackoff) delay(retries int) time.Duration { + min, max := float64(b.min), float64(b.max) + delay := min + for delay < max && retries > 0 { + delay *= rate + retries-- + } + if delay > max { + delay = max + } + delay -= delay * jitter * rand.Float64() + if delay < min { + delay = min + } + return time.Duration(delay) +} diff --git a/vendor/cloud.google.com/go/spanner/backoff_test.go b/vendor/cloud.google.com/go/spanner/backoff_test.go new file mode 100644 index 000000000..7a0314e81 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/backoff_test.go @@ -0,0 +1,62 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "math" + "time" + + "testing" +) + +// Test if exponential backoff helper can produce correct series of +// retry delays. +func TestBackoff(t *testing.T) { + b := exponentialBackoff{minBackoff, maxBackoff} + tests := []struct { + retries int + min time.Duration + max time.Duration + }{ + { + retries: 0, + min: minBackoff, + max: minBackoff, + }, + { + retries: 1, + min: minBackoff, + max: time.Duration(rate * float64(minBackoff)), + }, + { + retries: 3, + min: time.Duration(math.Pow(rate, 3) * (1 - jitter) * float64(minBackoff)), + max: time.Duration(math.Pow(rate, 3) * float64(minBackoff)), + }, + { + retries: 1000, + min: time.Duration((1 - jitter) * float64(maxBackoff)), + max: maxBackoff, + }, + } + for _, test := range tests { + got := b.delay(test.retries) + if float64(got) < float64(test.min) || float64(got) > float64(test.max) { + t.Errorf("delay(%v) = %v, want in range [%v, %v]", test.retries, got, test.min, test.max) + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/client.go b/vendor/cloud.google.com/go/spanner/client.go new file mode 100644 index 000000000..4ccc9333a --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/client.go @@ -0,0 +1,322 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + "regexp" + "sync/atomic" + "time" + + "cloud.google.com/go/internal/version" + "golang.org/x/net/context" + "google.golang.org/api/option" + "google.golang.org/api/transport" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +const ( + prodAddr = "spanner.googleapis.com:443" + + // resourcePrefixHeader is the name of the metadata header used to indicate + // the resource being operated on. + resourcePrefixHeader = "google-cloud-resource-prefix" + // apiClientHeader is the name of the metadata header used to indicate client + // information. + apiClientHeader = "x-goog-api-client" + + // numChannels is the default value for NumChannels of client + numChannels = 4 +) + +const ( + // Scope is the scope for Cloud Spanner Data API. + Scope = "https://www.googleapis.com/auth/spanner.data" + + // AdminScope is the scope for Cloud Spanner Admin APIs. + AdminScope = "https://www.googleapis.com/auth/spanner.admin" +) + +var ( + validDBPattern = regexp.MustCompile("^projects/[^/]+/instances/[^/]+/databases/[^/]+$") + clientUserAgent = fmt.Sprintf("gl-go/%s gccl/%s grpc/%s", version.Go(), version.Repo, grpc.Version) +) + +func validDatabaseName(db string) error { + if matched := validDBPattern.MatchString(db); !matched { + return fmt.Errorf("database name %q should conform to pattern %q", + db, validDBPattern.String()) + } + return nil +} + +// Client is a client for reading and writing data to a Cloud Spanner database. A +// client is safe to use concurrently, except for its Close method. +type Client struct { + // rr must be accessed through atomic operations. + rr uint32 + conns []*grpc.ClientConn + clients []sppb.SpannerClient + database string + // Metadata to be sent with each request. + md metadata.MD + idleSessions *sessionPool +} + +// ClientConfig has configurations for the client. +type ClientConfig struct { + // NumChannels is the number of GRPC channels. + // If zero, numChannels is used. + NumChannels int + co []option.ClientOption + // SessionPoolConfig is the configuration for session pool. + SessionPoolConfig +} + +// errDial returns error for dialing to Cloud Spanner. +func errDial(ci int, err error) error { + e := toSpannerError(err).(*Error) + e.decorate(fmt.Sprintf("dialing fails for channel[%v]", ci)) + return e +} + +func contextWithOutgoingMetadata(ctx context.Context, md metadata.MD) context.Context { + existing, ok := metadata.FromOutgoingContext(ctx) + if ok { + md = metadata.Join(existing, md) + } + return metadata.NewOutgoingContext(ctx, md) +} + +// NewClient creates a client to a database. A valid database name has the +// form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID. It uses a default +// configuration. +func NewClient(ctx context.Context, database string, opts ...option.ClientOption) (*Client, error) { + return NewClientWithConfig(ctx, database, ClientConfig{}, opts...) +} + +// NewClientWithConfig creates a client to a database. A valid database name has the +// form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID. +func NewClientWithConfig(ctx context.Context, database string, config ClientConfig, opts ...option.ClientOption) (*Client, error) { + // Validate database path. + if err := validDatabaseName(database); err != nil { + return nil, err + } + c := &Client{ + database: database, + md: metadata.Pairs( + resourcePrefixHeader, database, + apiClientHeader, clientUserAgent), + } + allOpts := []option.ClientOption{option.WithEndpoint(prodAddr), option.WithScopes(Scope), option.WithUserAgent(clientUserAgent), option.WithGRPCDialOption(grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(100<<20), grpc.MaxCallRecvMsgSize(100<<20)))} + allOpts = append(allOpts, opts...) + // Prepare gRPC channels. + if config.NumChannels == 0 { + config.NumChannels = numChannels + } + // Default MaxOpened sessions + if config.MaxOpened == 0 { + config.MaxOpened = uint64(config.NumChannels * 100) + } + if config.MaxBurst == 0 { + config.MaxBurst = 10 + } + for i := 0; i < config.NumChannels; i++ { + conn, err := transport.DialGRPC(ctx, allOpts...) + if err != nil { + return nil, errDial(i, err) + } + c.conns = append(c.conns, conn) + c.clients = append(c.clients, sppb.NewSpannerClient(conn)) + } + // Prepare session pool. + config.SessionPoolConfig.getRPCClient = func() (sppb.SpannerClient, error) { + // TODO: support more loadbalancing options. + return c.rrNext(), nil + } + sp, err := newSessionPool(database, config.SessionPoolConfig, c.md) + if err != nil { + c.Close() + return nil, err + } + c.idleSessions = sp + return c, nil +} + +// rrNext returns the next available Cloud Spanner RPC client in a round-robin manner. +func (c *Client) rrNext() sppb.SpannerClient { + return c.clients[atomic.AddUint32(&c.rr, 1)%uint32(len(c.clients))] +} + +// Close closes the client. +func (c *Client) Close() { + if c.idleSessions != nil { + c.idleSessions.close() + } + for _, conn := range c.conns { + conn.Close() + } +} + +// Single provides a read-only snapshot transaction optimized for the case +// where only a single read or query is needed. This is more efficient than +// using ReadOnlyTransaction() for a single read or query. +// +// Single will use a strong TimestampBound by default. Use +// ReadOnlyTransaction.WithTimestampBound to specify a different +// TimestampBound. A non-strong bound can be used to reduce latency, or +// "time-travel" to prior versions of the database, see the documentation of +// TimestampBound for details. +func (c *Client) Single() *ReadOnlyTransaction { + t := &ReadOnlyTransaction{singleUse: true, sp: c.idleSessions} + t.txReadOnly.txReadEnv = t + return t +} + +// ReadOnlyTransaction returns a ReadOnlyTransaction that can be used for +// multiple reads from the database. You must call Close() when the +// ReadOnlyTransaction is no longer needed to release resources on the server. +// +// ReadOnlyTransaction will use a strong TimestampBound by default. Use +// ReadOnlyTransaction.WithTimestampBound to specify a different +// TimestampBound. A non-strong bound can be used to reduce latency, or +// "time-travel" to prior versions of the database, see the documentation of +// TimestampBound for details. +func (c *Client) ReadOnlyTransaction() *ReadOnlyTransaction { + t := &ReadOnlyTransaction{ + singleUse: false, + sp: c.idleSessions, + txReadyOrClosed: make(chan struct{}), + } + t.txReadOnly.txReadEnv = t + return t +} + +type transactionInProgressKey struct{} + +func checkNestedTxn(ctx context.Context) error { + if ctx.Value(transactionInProgressKey{}) != nil { + return spannerErrorf(codes.FailedPrecondition, "Cloud Spanner does not support nested transactions") + } + return nil +} + +// ReadWriteTransaction executes a read-write transaction, with retries as +// necessary. +// +// The function f will be called one or more times. It must not maintain +// any state between calls. +// +// If the transaction cannot be committed or if f returns an IsAborted error, +// ReadWriteTransaction will call f again. It will continue to call f until the +// transaction can be committed or the Context times out or is cancelled. If f +// returns an error other than IsAborted, ReadWriteTransaction will abort the +// transaction and return the error. +// +// To limit the number of retries, set a deadline on the Context rather than +// using a fixed limit on the number of attempts. ReadWriteTransaction will +// retry as needed until that deadline is met. +func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Context, *ReadWriteTransaction) error) (time.Time, error) { + if err := checkNestedTxn(ctx); err != nil { + return time.Time{}, err + } + var ( + ts time.Time + sh *sessionHandle + ) + err := runRetryable(ctx, func(ctx context.Context) error { + var ( + err error + t *ReadWriteTransaction + ) + if sh == nil || sh.getID() == "" || sh.getClient() == nil { + // Session handle hasn't been allocated or has been destroyed. + sh, err = c.idleSessions.takeWriteSession(ctx) + if err != nil { + // If session retrieval fails, just fail the transaction. + return err + } + t = &ReadWriteTransaction{ + sh: sh, + tx: sh.getTransactionID(), + } + } else { + t = &ReadWriteTransaction{ + sh: sh, + } + } + t.txReadOnly.txReadEnv = t + if err = t.begin(ctx); err != nil { + // Mask error from begin operation as retryable error. + return errRetry(err) + } + ts, err = t.runInTransaction(ctx, f) + if err != nil { + return err + } + return nil + }) + if sh != nil { + sh.recycle() + } + return ts, err +} + +// applyOption controls the behavior of Client.Apply. +type applyOption struct { + // If atLeastOnce == true, Client.Apply will execute the mutations on Cloud Spanner at least once. + atLeastOnce bool +} + +// An ApplyOption is an optional argument to Apply. +type ApplyOption func(*applyOption) + +// ApplyAtLeastOnce returns an ApplyOption that removes replay protection. +// +// With this option, Apply may attempt to apply mutations more than once; if +// the mutations are not idempotent, this may lead to a failure being reported +// when the mutation was applied more than once. For example, an insert may +// fail with ALREADY_EXISTS even though the row did not exist before Apply was +// called. For this reason, most users of the library will prefer not to use +// this option. However, ApplyAtLeastOnce requires only a single RPC, whereas +// Apply's default replay protection may require an additional RPC. So this +// option may be appropriate for latency sensitive and/or high throughput blind +// writing. +func ApplyAtLeastOnce() ApplyOption { + return func(ao *applyOption) { + ao.atLeastOnce = true + } +} + +// Apply applies a list of mutations atomically to the database. +func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) (time.Time, error) { + ao := &applyOption{} + for _, opt := range opts { + opt(ao) + } + if !ao.atLeastOnce { + return c.ReadWriteTransaction(ctx, func(ctx context.Context, t *ReadWriteTransaction) error { + t.BufferWrite(ms) + return nil + }) + } + t := &writeOnlyTransaction{c.idleSessions} + return t.applyAtLeastOnce(ctx, ms...) +} diff --git a/vendor/cloud.google.com/go/spanner/client_test.go b/vendor/cloud.google.com/go/spanner/client_test.go new file mode 100644 index 000000000..643f863b7 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/client_test.go @@ -0,0 +1,50 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "strings" + "testing" +) + +// Test validDatabaseName() +func TestValidDatabaseName(t *testing.T) { + validDbUri := "projects/spanner-cloud-test/instances/foo/databases/foodb" + invalidDbUris := []string{ + // Completely wrong DB URI. + "foobarDB", + // Project ID contains "/". + "projects/spanner-cloud/test/instances/foo/databases/foodb", + // No instance ID. + "projects/spanner-cloud-test/instances//databases/foodb", + } + if err := validDatabaseName(validDbUri); err != nil { + t.Errorf("validateDatabaseName(%q) = %v, want nil", validDbUri, err) + } + for _, d := range invalidDbUris { + if err, wantErr := validDatabaseName(d), "should conform to pattern"; !strings.Contains(err.Error(), wantErr) { + t.Errorf("validateDatabaseName(%q) = %q, want error pattern %q", validDbUri, err, wantErr) + } + } +} + +func TestReadOnlyTransactionClose(t *testing.T) { + // Closing a ReadOnlyTransaction shouldn't panic. + c := &Client{} + tx := c.ReadOnlyTransaction() + tx.Close() +} diff --git a/vendor/cloud.google.com/go/spanner/doc.go b/vendor/cloud.google.com/go/spanner/doc.go new file mode 100644 index 000000000..88f0b3221 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/doc.go @@ -0,0 +1,311 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Package spanner provides a client for reading and writing to Cloud Spanner +databases. See the packages under admin for clients that operate on databases +and instances. + +Note: This package is in alpha. Backwards-incompatible changes may occur +without notice. + +See https://cloud.google.com/spanner/docs/getting-started/go/ for an introduction +to Cloud Spanner and additional help on using this API. + +Creating a Client + +To start working with this package, create a client that refers to the database +of interest: + + ctx := context.Background() + client, err := spanner.NewClient(ctx, "projects/P/instances/I/databases/D") + defer client.Close() + if err != nil { + // TODO: Handle error. + } + +Remember to close the client after use to free up the sessions in the session +pool. + + +Simple Reads and Writes + +Two Client methods, Apply and Single, work well for simple reads and writes. As +a quick introduction, here we write a new row to the database and read it back: + + _, err := client.Apply(ctx, []*spanner.Mutation{ + spanner.Insert("Users", + []string{"name", "email"}, + []interface{}{"alice", "a@example.com"})}) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Users", + spanner.Key{"alice"}, []string{"email"}) + if err != nil { + // TODO: Handle error. + } + +All the methods used above are discussed in more detail below. + + +Keys + +Every Cloud Spanner row has a unique key, composed of one or more columns. +Construct keys with a literal of type Key: + + key1 := spanner.Key{"alice"} + + +KeyRanges + +The keys of a Cloud Spanner table are ordered. You can specify ranges of keys +using the KeyRange type: + + kr1 := spanner.KeyRange{Start: key1, End: key2} + +By default, a KeyRange includes its start key but not its end key. Use +the Kind field to specify other boundary conditions: + + // include both keys + kr2 := spanner.KeyRange{Start: key1, End: key2, Kind: spanner.ClosedClosed} + + +KeySets + +A KeySet represents a set of keys. A single Key or KeyRange can act as a KeySet. Use +the KeySets function to build the union of several KeySets: + + ks1 := spanner.KeySets(key1, key2, kr1, kr2) + +AllKeys returns a KeySet that refers to all the keys in a table: + + ks2 := spanner.AllKeys() + + +Transactions + +All Cloud Spanner reads and writes occur inside transactions. There are two +types of transactions, read-only and read-write. Read-only transactions cannot +change the database, do not acquire locks, and may access either the current +database state or states in the past. Read-write transactions can read the +database before writing to it, and always apply to the most recent database +state. + + +Single Reads + +The simplest and fastest transaction is a ReadOnlyTransaction that supports a +single read operation. Use Client.Single to create such a transaction. You can +chain the call to Single with a call to a Read method. + +When you only want one row whose key you know, use ReadRow. Provide the table +name, key, and the columns you want to read: + + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + +Read multiple rows with the Read method. It takes a table name, KeySet, and list +of columns: + + iter := client.Single().Read(ctx, "Accounts", keyset1, columns) + +Read returns a RowIterator. You can call the Do method on the iterator and pass +a callback: + + err := iter.Do(func(row *Row) error { + // TODO: use row + return nil + }) + +RowIterator also follows the standard pattern for the Google +Cloud Client Libraries: + + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + // TODO: Handle error. + } + // TODO: use row + } + +Always call Stop when you finish using an iterator this way, whether or not you +iterate to the end. (Failing to call Stop could lead you to exhaust the +database's session quota.) + +To read rows with an index, use ReadUsingIndex. + +Statements + +The most general form of reading uses SQL statements. Construct a Statement +with NewStatement, setting any parameters using the Statement's Params map: + + stmt := spanner.NewStatement("SELECT First, Last FROM SINGERS WHERE Last >= @start") + stmt.Params["start"] = "Dylan" + +You can also construct a Statement directly with a struct literal, providing +your own map of parameters. + +Use the Query method to run the statement and obtain an iterator: + + iter := client.Single().Query(ctx, stmt) + + +Rows + +Once you have a Row, via an iterator or a call to ReadRow, you can extract +column values in several ways. Pass in a pointer to a Go variable of the +appropriate type when you extract a value. + +You can extract by column position or name: + + err := row.Column(0, &name) + err = row.ColumnByName("balance", &balance) + +You can extract all the columns at once: + + err = row.Columns(&name, &balance) + +Or you can define a Go struct that corresponds to your columns, and extract +into that: + + var s struct { Name string; Balance int64 } + err = row.ToStruct(&s) + + +For Cloud Spanner columns that may contain NULL, use one of the NullXXX types, +like NullString: + + var ns spanner.NullString + if err =: row.Column(0, &ns); err != nil { + // TODO: Handle error. + } + if ns.Valid { + fmt.Println(ns.StringVal) + } else { + fmt.Println("column is NULL") + } + + +Multiple Reads + +To perform more than one read in a transaction, use ReadOnlyTransaction: + + txn := client.ReadOnlyTransaction() + defer txn.Close() + iter := txn.Query(ctx, stmt1) + // ... + iter = txn.Query(ctx, stmt2) + // ... + +You must call Close when you are done with the transaction. + + +Timestamps and Timestamp Bounds + +Cloud Spanner read-only transactions conceptually perform all their reads at a +single moment in time, called the transaction's read timestamp. Once a read has +started, you can call ReadOnlyTransaction's Timestamp method to obtain the read +timestamp. + +By default, a transaction will pick the most recent time (a time where all +previously committed transactions are visible) for its reads. This provides the +freshest data, but may involve some delay. You can often get a quicker response +if you are willing to tolerate "stale" data. You can control the read timestamp +selected by a transaction by calling the WithTimestampBound method on the +transaction before using it. For example, to perform a query on data that is at +most one minute stale, use + + client.Single(). + WithTimestampBound(spanner.MaxStaleness(1*time.Minute)). + Query(ctx, stmt) + +See the documentation of TimestampBound for more details. + + +Mutations + +To write values to a Cloud Spanner database, construct a Mutation. The spanner +package has functions for inserting, updating and deleting rows. Except for the +Delete methods, which take a Key or KeyRange, each mutation-building function +comes in three varieties. + +One takes lists of columns and values along with the table name: + + m1 := spanner.Insert("Users", + []string{"name", "email"}, + []interface{}{"alice", "a@example.com"}) + +One takes a map from column names to values: + + m2 := spanner.InsertMap("Users", map[string]interface{}{ + "name": "alice", + "email": "a@example.com", + }) + +And the third accepts a struct value, and determines the columns from the +struct field names: + + type User struct { Name, Email string } + u := User{Name: "alice", Email: "a@example.com"} + m3, err := spanner.InsertStruct("Users", u) + + +Writes + +To apply a list of mutations to the database, use Apply: + + _, err := client.Apply(ctx, []*spanner.Mutation{m1, m2, m3}) + +If you need to read before writing in a single transaction, use a +ReadWriteTransaction. ReadWriteTransactions may abort and need to be retried. +You pass in a function to ReadWriteTransaction, and the client will handle the +retries automatically. Use the transaction's BufferWrite method to buffer +mutations, which will all be executed at the end of the transaction: + + _, err := client.ReadWriteTransaction(ctx, func(txn *spanner.ReadWriteTransaction) error { + var balance int64 + row, err := txn.ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + if err != nil { + // This function will be called again if this is an IsAborted error. + return err + } + if err := row.Column(0, &balance); err != nil { + return err + } + + if balance <= 10 { + return errors.New("insufficient funds in account") + } + balance -= 10 + m := spanner.Update("Accounts", []string{"user", "balance"}, []interface{}{"alice", balance}) + txn.BufferWrite([]*spanner.Mutation{m}) + + // The buffered mutation will be committed. If the commit + // fails with an IsAborted error, this function will be called + // again. + return nil + }) + +Authentication + +See examples of authorization and authentication at +https://godoc.org/cloud.google.com/go#pkg-examples. +*/ +package spanner // import "cloud.google.com/go/spanner" diff --git a/vendor/cloud.google.com/go/spanner/errors.go b/vendor/cloud.google.com/go/spanner/errors.go new file mode 100644 index 000000000..13106f2ed --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/errors.go @@ -0,0 +1,108 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +// Error is the structured error returned by Cloud Spanner client. +type Error struct { + // Code is the canonical error code for describing the nature of a + // particular error. + Code codes.Code + // Desc explains more details of the error. + Desc string + // trailers are the trailers returned in the response, if any. + trailers metadata.MD +} + +// Error implements error.Error. +func (e *Error) Error() string { + if e == nil { + return fmt.Sprintf("spanner: OK") + } + return fmt.Sprintf("spanner: code = %q, desc = %q", e.Code, e.Desc) +} + +// decorate decorates an existing spanner.Error with more information. +func (e *Error) decorate(info string) { + e.Desc = fmt.Sprintf("%v, %v", info, e.Desc) +} + +// spannerErrorf generates a *spanner.Error with the given error code and +// description. +func spannerErrorf(ec codes.Code, format string, args ...interface{}) error { + return &Error{ + Code: ec, + Desc: fmt.Sprintf(format, args...), + } +} + +// toSpannerError converts general Go error to *spanner.Error. +func toSpannerError(err error) error { + return toSpannerErrorWithMetadata(err, nil) +} + +// toSpannerErrorWithMetadata converts general Go error and grpc trailers to *spanner.Error. +// Note: modifies original error if trailers aren't nil +func toSpannerErrorWithMetadata(err error, trailers metadata.MD) error { + if err == nil { + return nil + } + if se, ok := err.(*Error); ok { + if trailers != nil { + se.trailers = metadata.Join(se.trailers, trailers) + } + return se + } + if grpc.Code(err) == codes.Unknown { + return &Error{codes.Unknown, err.Error(), trailers} + } + return &Error{grpc.Code(err), grpc.ErrorDesc(err), trailers} +} + +// ErrCode extracts the canonical error code from a Go error. +func ErrCode(err error) codes.Code { + se, ok := toSpannerError(err).(*Error) + if !ok { + return codes.Unknown + } + return se.Code +} + +// ErrDesc extracts the Cloud Spanner error description from a Go error. +func ErrDesc(err error) string { + se, ok := toSpannerError(err).(*Error) + if !ok { + return err.Error() + } + return se.Desc +} + +// errTrailers extracts the grpc trailers if present from a Go error. +func errTrailers(err error) metadata.MD { + se, ok := err.(*Error) + if !ok { + return nil + } + return se.trailers +} diff --git a/vendor/cloud.google.com/go/spanner/examples_test.go b/vendor/cloud.google.com/go/spanner/examples_test.go new file mode 100644 index 000000000..4cd19451c --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/examples_test.go @@ -0,0 +1,536 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner_test + +import ( + "errors" + "fmt" + "time" + + "cloud.google.com/go/spanner" + "golang.org/x/net/context" + "google.golang.org/api/iterator" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +func ExampleNewClient() { + ctx := context.Background() + const myDB = "projects/my-project/instances/my-instance/database/my-db" + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + _ = client // TODO: Use client. +} + +const myDB = "projects/my-project/instances/my-instance/database/my-db" + +func ExampleNewClientWithConfig() { + ctx := context.Background() + const myDB = "projects/my-project/instances/my-instance/database/my-db" + client, err := spanner.NewClientWithConfig(ctx, myDB, spanner.ClientConfig{ + NumChannels: 10, + }) + if err != nil { + // TODO: Handle error. + } + _ = client // TODO: Use client. + client.Close() // Close client when done. +} + +func ExampleClient_Single() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().Query(ctx, spanner.NewStatement("SELECT FirstName FROM Singers")) + _ = iter // TODO: iterate using Next or Do. +} + +func ExampleClient_ReadOnlyTransaction() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + t := client.ReadOnlyTransaction() + defer t.Close() + // TODO: Read with t using Read, ReadRow, ReadUsingIndex, or Query. +} + +func ExampleClient_ReadWriteTransaction() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + var balance int64 + row, err := txn.ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + if err != nil { + // This function will be called again if this is an + // IsAborted error. + return err + } + if err := row.Column(0, &balance); err != nil { + return err + } + + if balance <= 10 { + return errors.New("insufficient funds in account") + } + balance -= 10 + m := spanner.Update("Accounts", []string{"user", "balance"}, []interface{}{"alice", balance}) + return txn.BufferWrite([]*spanner.Mutation{m}) + // The buffered mutation will be committed. If the commit + // fails with an IsAborted error, this function will be called + // again. + }) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleUpdate() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + row, err := txn.ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + if err != nil { + return err + } + var balance int64 + if err := row.Column(0, &balance); err != nil { + return err + } + return txn.BufferWrite([]*spanner.Mutation{ + spanner.Update("Accounts", []string{"user", "balance"}, []interface{}{"alice", balance + 10}), + }) + }) + if err != nil { + // TODO: Handle error. + } +} + +// This example is the same as the one for Update, except for the use of UpdateMap. +func ExampleUpdateMap() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + row, err := txn.ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + if err != nil { + return err + } + var balance int64 + if err := row.Column(0, &balance); err != nil { + return err + } + return txn.BufferWrite([]*spanner.Mutation{ + spanner.UpdateMap("Accounts", map[string]interface{}{ + "user": "alice", + "balance": balance + 10, + }), + }) + }) + if err != nil { + // TODO: Handle error. + } +} + +// This example is the same as the one for Update, except for the use of UpdateStruct. +func ExampleUpdateStruct() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + type account struct { + User string `spanner:"user"` + Balance int64 `spanner:"balance"` + } + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + row, err := txn.ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"balance"}) + if err != nil { + return err + } + var balance int64 + if err := row.Column(0, &balance); err != nil { + return err + } + m, err := spanner.UpdateStruct("Accounts", account{ + User: "alice", + Balance: balance + 10, + }) + if err != nil { + return err + } + return txn.BufferWrite([]*spanner.Mutation{m}) + }) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleClient_Apply() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + m := spanner.Update("Users", []string{"name", "email"}, []interface{}{"alice", "a@example.com"}) + _, err = client.Apply(ctx, []*spanner.Mutation{m}) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleInsert() { + m := spanner.Insert("Users", []string{"name", "email"}, []interface{}{"alice", "a@example.com"}) + _ = m // TODO: use with Client.Apply or in a ReadWriteTransaction. +} + +func ExampleInsertMap() { + m := spanner.InsertMap("Users", map[string]interface{}{ + "name": "alice", + "email": "a@example.com", + }) + _ = m // TODO: use with Client.Apply or in a ReadWriteTransaction. +} + +func ExampleInsertStruct() { + type User struct { + Name, Email string + } + u := User{Name: "alice", Email: "a@example.com"} + m, err := spanner.InsertStruct("Users", u) + if err != nil { + // TODO: Handle error. + } + _ = m // TODO: use with Client.Apply or in a ReadWriteTransaction. +} + +func ExampleDelete() { + m := spanner.Delete("Users", spanner.Key{"alice"}) + _ = m // TODO: use with Client.Apply or in a ReadWriteTransaction. +} + +func ExampleDelete_KeyRange() { + m := spanner.Delete("Users", spanner.KeyRange{ + Start: spanner.Key{"alice"}, + End: spanner.Key{"bob"}, + Kind: spanner.ClosedClosed, + }) + _ = m // TODO: use with Client.Apply or in a ReadWriteTransaction. +} + +func ExampleRowIterator_Next() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().Query(ctx, spanner.NewStatement("SELECT FirstName FROM Singers")) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + // TODO: Handle error. + } + var firstName string + if err := row.Column(0, &firstName); err != nil { + // TODO: Handle error. + } + fmt.Println(firstName) + } +} + +func ExampleRowIterator_Do() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().Query(ctx, spanner.NewStatement("SELECT FirstName FROM Singers")) + err = iter.Do(func(r *spanner.Row) error { + var firstName string + if err := r.Column(0, &firstName); err != nil { + return err + } + fmt.Println(firstName) + return nil + }) + if err != nil { + // TODO: Handle error. + } +} + +func ExampleRow_Size() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + fmt.Println(row.Size()) // size is 2 +} + +func ExampleRow_ColumnName() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + fmt.Println(row.ColumnName(1)) // prints "balance" +} + +func ExampleRow_ColumnIndex() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + index, err := row.ColumnIndex("balance") + if err != nil { + // TODO: Handle error. + } + fmt.Println(index) +} + +func ExampleRow_ColumnNames() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + fmt.Println(row.ColumnNames()) +} + +func ExampleRow_ColumnByName() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + var balance int64 + if err := row.ColumnByName("balance", &balance); err != nil { + // TODO: Handle error. + } + fmt.Println(balance) +} + +func ExampleRow_Columns() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + var name string + var balance int64 + if err := row.Columns(&name, &balance); err != nil { + // TODO: Handle error. + } + fmt.Println(name, balance) +} + +func ExampleRow_ToStruct() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"name", "balance"}) + if err != nil { + // TODO: Handle error. + } + + type Account struct { + Name string + Balance int64 + } + + var acct Account + if err := row.ToStruct(&acct); err != nil { + // TODO: Handle error. + } + fmt.Println(acct) +} + +func ExampleReadOnlyTransaction_Read() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().Read(ctx, "Users", + spanner.KeySets(spanner.Key{"alice"}, spanner.Key{"bob"}), + []string{"name", "email"}) + _ = iter // TODO: iterate using Next or Do. +} + +func ExampleReadOnlyTransaction_ReadUsingIndex() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().ReadUsingIndex(ctx, "Users", + "UsersByEmail", + spanner.KeySets(spanner.Key{"a@example.com"}, spanner.Key{"b@example.com"}), + []string{"name", "email"}) + _ = iter // TODO: iterate using Next or Do. +} + +func ExampleReadOnlyTransaction_ReadRow() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Users", spanner.Key{"alice"}, + []string{"name", "email"}) + if err != nil { + // TODO: Handle error. + } + _ = row // TODO: use row +} + +func ExampleReadOnlyTransaction_Query() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + iter := client.Single().Query(ctx, spanner.NewStatement("SELECT FirstName FROM Singers")) + _ = iter // TODO: iterate using Next or Do. +} + +func ExampleNewStatement() { + stmt := spanner.NewStatement("SELECT FirstName, LastName FROM SINGERS WHERE LastName >= @start") + stmt.Params["start"] = "Dylan" + // TODO: Use stmt in Query. +} + +func ExampleNewStatement_structLiteral() { + stmt := spanner.Statement{ + SQL: "SELECT FirstName, LastName FROM SINGERS WHERE LastName >= @start", + Params: map[string]interface{}{"start": "Dylan"}, + } + _ = stmt // TODO: Use stmt in Query. +} + +func ExampleReadOnlyTransaction_Timestamp() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + txn := client.Single() + row, err := txn.ReadRow(ctx, "Users", spanner.Key{"alice"}, + []string{"name", "email"}) + if err != nil { + // TODO: Handle error. + } + readTimestamp, err := txn.Timestamp() + if err != nil { + // TODO: Handle error. + } + fmt.Println("read happened at", readTimestamp) + _ = row // TODO: use row +} + +func ExampleReadOnlyTransaction_WithTimestampBound() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + txn := client.Single().WithTimestampBound(spanner.MaxStaleness(30 * time.Second)) + row, err := txn.ReadRow(ctx, "Users", spanner.Key{"alice"}, []string{"name", "email"}) + if err != nil { + // TODO: Handle error. + } + _ = row // TODO: use row + readTimestamp, err := txn.Timestamp() + if err != nil { + // TODO: Handle error. + } + fmt.Println("read happened at", readTimestamp) +} + +func ExampleNewGenericColumnValue_Decode() { + // In real applications, rows can be retrieved by methods like client.Single().ReadRow(). + row, err := spanner.NewRow([]string{"intCol", "strCol"}, []interface{}{42, "my-text"}) + if err != nil { + // TODO: Handle error. + } + for i := 0; i < row.Size(); i++ { + var col spanner.GenericColumnValue + if err := row.Column(i, &col); err != nil { + // TODO: Handle error. + } + switch col.Type.Code { + case sppb.TypeCode_INT64: + var v int64 + if err := col.Decode(&v); err != nil { + // TODO: Handle error. + } + fmt.Println("int", v) + case sppb.TypeCode_STRING: + var v string + if err := col.Decode(&v); err != nil { + // TODO: Handle error. + } + fmt.Println("string", v) + } + } + // Output: + // int 42 + // string my-text +} diff --git a/vendor/cloud.google.com/go/spanner/internal/testutil/mockclient.go b/vendor/cloud.google.com/go/spanner/internal/testutil/mockclient.go new file mode 100644 index 000000000..f278c7cc6 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/internal/testutil/mockclient.go @@ -0,0 +1,355 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testutil + +import ( + "errors" + "fmt" + "reflect" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/golang/protobuf/ptypes/empty" + proto3 "github.com/golang/protobuf/ptypes/struct" + pbt "github.com/golang/protobuf/ptypes/timestamp" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +// Action is a mocked RPC activity that MockCloudSpannerClient will take. +type Action struct { + method string + err error +} + +// NewAction creates Action objects. +func NewAction(m string, e error) Action { + return Action{m, e} +} + +// MockCloudSpannerClient is a mock implementation of sppb.SpannerClient. +type MockCloudSpannerClient struct { + mu sync.Mutex + t *testing.T + // Live sessions on the client. + sessions map[string]bool + // Expected set of actions that will be executed by the client. + actions []Action + // Session ping history + pings []string + // Injected error, will be returned by all APIs + injErr map[string]error + // nice client will not fail on any request + nice bool +} + +// NewMockCloudSpannerClient creates new MockCloudSpannerClient instance. +func NewMockCloudSpannerClient(t *testing.T, acts ...Action) *MockCloudSpannerClient { + mc := &MockCloudSpannerClient{t: t, sessions: map[string]bool{}, injErr: map[string]error{}} + mc.SetActions(acts...) + return mc +} + +// MakeNice makes this a nice mock which will not fail on any request. +func (m *MockCloudSpannerClient) MakeNice() { + m.mu.Lock() + defer m.mu.Unlock() + m.nice = true +} + +// MakeStrict makes this a strict mock which will fail on any unexpected request. +func (m *MockCloudSpannerClient) MakeStrict() { + m.mu.Lock() + defer m.mu.Unlock() + m.nice = false +} + +// InjectError injects a global error that will be returned by all APIs regardless of +// the actions array. +func (m *MockCloudSpannerClient) InjectError(method string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.injErr[method] = err +} + +// SetActions sets the new set of expected actions to MockCloudSpannerClient. +func (m *MockCloudSpannerClient) SetActions(acts ...Action) { + m.mu.Lock() + defer m.mu.Unlock() + m.actions = []Action{} + for _, act := range acts { + m.actions = append(m.actions, act) + } +} + +// DumpPings dumps the ping history. +func (m *MockCloudSpannerClient) DumpPings() []string { + m.mu.Lock() + defer m.mu.Unlock() + return append([]string(nil), m.pings...) +} + +// DumpSessions dumps the internal session table. +func (m *MockCloudSpannerClient) DumpSessions() map[string]bool { + m.mu.Lock() + defer m.mu.Unlock() + st := map[string]bool{} + for s, v := range m.sessions { + st[s] = v + } + return st +} + +// CreateSession is a placeholder for SpannerClient.CreateSession. +func (m *MockCloudSpannerClient) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { + m.mu.Lock() + defer m.mu.Unlock() + if err := m.injErr["CreateSession"]; err != nil { + return nil, err + } + s := &sppb.Session{} + if r.Database != "mockdb" { + // Reject other databases + return s, grpc.Errorf(codes.NotFound, fmt.Sprintf("database not found: %v", r.Database)) + } + // Generate & record session name. + s.Name = fmt.Sprintf("mockdb-%v", time.Now().UnixNano()) + m.sessions[s.Name] = true + return s, nil +} + +// GetSession is a placeholder for SpannerClient.GetSession. +func (m *MockCloudSpannerClient) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { + m.mu.Lock() + defer m.mu.Unlock() + if err := m.injErr["GetSession"]; err != nil { + return nil, err + } + m.pings = append(m.pings, r.Name) + if _, ok := m.sessions[r.Name]; !ok { + return nil, grpc.Errorf(codes.NotFound, fmt.Sprintf("Session not found: %v", r.Name)) + } + return &sppb.Session{Name: r.Name}, nil +} + +// DeleteSession is a placeholder for SpannerClient.DeleteSession. +func (m *MockCloudSpannerClient) DeleteSession(c context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + m.mu.Lock() + defer m.mu.Unlock() + if err := m.injErr["DeleteSession"]; err != nil { + return nil, err + } + if _, ok := m.sessions[r.Name]; !ok { + // Session not found. + return &empty.Empty{}, grpc.Errorf(codes.NotFound, fmt.Sprintf("Session not found: %v", r.Name)) + } + // Delete session from in-memory table. + delete(m.sessions, r.Name) + return &empty.Empty{}, nil +} + +// ExecuteSql is a placeholder for SpannerClient.ExecuteSql. +func (m *MockCloudSpannerClient) ExecuteSql(c context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (*sppb.ResultSet, error) { + return nil, errors.New("Unimplemented") +} + +// ExecuteStreamingSql is a mock implementation of SpannerClient.ExecuteStreamingSql. +func (m *MockCloudSpannerClient) ExecuteStreamingSql(c context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) { + m.mu.Lock() + defer m.mu.Unlock() + if err := m.injErr["ExecuteStreamingSql"]; err != nil { + return nil, err + } + if len(m.actions) == 0 { + m.t.Fatalf("unexpected ExecuteStreamingSql executed") + } + act := m.actions[0] + m.actions = m.actions[1:] + if act.method != "ExecuteStreamingSql" { + m.t.Fatalf("unexpected ExecuteStreamingSql call, want action: %v", act) + } + wantReq := &sppb.ExecuteSqlRequest{ + Session: "mocksession", + Transaction: &sppb.TransactionSelector{ + Selector: &sppb.TransactionSelector_SingleUse{ + SingleUse: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: &sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + ReturnReadTimestamp: false, + }, + }, + }, + }, + }, + Sql: "mockquery", + Params: &proto3.Struct{ + Fields: map[string]*proto3.Value{"var1": &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: "abc"}}}, + }, + ParamTypes: map[string]*sppb.Type{"var1": &sppb.Type{Code: sppb.TypeCode_STRING}}, + } + if !reflect.DeepEqual(r, wantReq) { + return nil, fmt.Errorf("got query request: %v, want: %v", r, wantReq) + } + if act.err != nil { + return nil, act.err + } + return nil, errors.New("query never succeeds on mock client") +} + +// Read is a placeholder for SpannerClient.Read. +func (m *MockCloudSpannerClient) Read(c context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (*sppb.ResultSet, error) { + m.t.Fatalf("Read is unimplemented") + return nil, errors.New("Unimplemented") +} + +// StreamingRead is a placeholder for SpannerClient.StreamingRead. +func (m *MockCloudSpannerClient) StreamingRead(c context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) { + m.mu.Lock() + defer m.mu.Unlock() + if err := m.injErr["StreamingRead"]; err != nil { + return nil, err + } + if len(m.actions) == 0 { + m.t.Fatalf("unexpected StreamingRead executed") + } + act := m.actions[0] + m.actions = m.actions[1:] + if act.method != "StreamingRead" && act.method != "StreamingIndexRead" { + m.t.Fatalf("unexpected read call, want action: %v", act) + } + wantReq := &sppb.ReadRequest{ + Session: "mocksession", + Transaction: &sppb.TransactionSelector{ + Selector: &sppb.TransactionSelector_SingleUse{ + SingleUse: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: &sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + ReturnReadTimestamp: false, + }, + }, + }, + }, + }, + Table: "t_mock", + Columns: []string{"col1", "col2"}, + KeySet: &sppb.KeySet{ + []*proto3.ListValue{ + &proto3.ListValue{ + Values: []*proto3.Value{ + &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + }, + }, + }, + []*sppb.KeyRange{}, + false, + }, + } + if act.method == "StreamingIndexRead" { + wantReq.Index = "idx1" + } + if !reflect.DeepEqual(r, wantReq) { + return nil, fmt.Errorf("got query request: %v, want: %v", r, wantReq) + } + if act.err != nil { + return nil, act.err + } + return nil, errors.New("read never succeeds on mock client") +} + +// BeginTransaction is a placeholder for SpannerClient.BeginTransaction. +func (m *MockCloudSpannerClient) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) { + m.mu.Lock() + defer m.mu.Unlock() + if !m.nice { + if err := m.injErr["BeginTransaction"]; err != nil { + return nil, err + } + if len(m.actions) == 0 { + m.t.Fatalf("unexpected Begin executed") + } + act := m.actions[0] + m.actions = m.actions[1:] + if act.method != "Begin" { + m.t.Fatalf("unexpected Begin call, want action: %v", act) + } + if act.err != nil { + return nil, act.err + } + } + resp := &sppb.Transaction{Id: []byte("transaction-1")} + if _, ok := r.Options.Mode.(*sppb.TransactionOptions_ReadOnly_); ok { + resp.ReadTimestamp = &pbt.Timestamp{Seconds: 3, Nanos: 4} + } + return resp, nil +} + +// Commit is a placeholder for SpannerClient.Commit. +func (m *MockCloudSpannerClient) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + if !m.nice { + if err := m.injErr["Commit"]; err != nil { + return nil, err + } + if len(m.actions) == 0 { + m.t.Fatalf("unexpected Commit executed") + } + act := m.actions[0] + m.actions = m.actions[1:] + if act.method != "Commit" { + m.t.Fatalf("unexpected Commit call, want action: %v", act) + } + if act.err != nil { + return nil, act.err + } + } + return &sppb.CommitResponse{CommitTimestamp: &pbt.Timestamp{Seconds: 1, Nanos: 2}}, nil +} + +// Rollback is a placeholder for SpannerClient.Rollback. +func (m *MockCloudSpannerClient) Rollback(c context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + m.mu.Lock() + defer m.mu.Unlock() + if !m.nice { + if err := m.injErr["Rollback"]; err != nil { + return nil, err + } + if len(m.actions) == 0 { + m.t.Fatalf("unexpected Rollback executed") + } + act := m.actions[0] + m.actions = m.actions[1:] + if act.method != "Rollback" { + m.t.Fatalf("unexpected Rollback call, want action: %v", act) + } + if act.err != nil { + return nil, act.err + } + } + return nil, nil +} diff --git a/vendor/cloud.google.com/go/spanner/internal/testutil/mockserver.go b/vendor/cloud.google.com/go/spanner/internal/testutil/mockserver.go new file mode 100644 index 000000000..7a04e7f7f --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/internal/testutil/mockserver.go @@ -0,0 +1,255 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testutil + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/golang/protobuf/ptypes/empty" + proto3 "github.com/golang/protobuf/ptypes/struct" + pbt "github.com/golang/protobuf/ptypes/timestamp" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +var ( + // KvMeta is the Metadata for mocked KV table. + KvMeta = sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + { + Name: "Key", + Type: &sppb.Type{Code: sppb.TypeCode_STRING}, + }, + { + Name: "Value", + Type: &sppb.Type{Code: sppb.TypeCode_STRING}, + }, + }, + }, + } +) + +// MockCtlMsg encapsulates PartialResultSet/error that might be sent to +// client +type MockCtlMsg struct { + // If ResumeToken == true, mock server will generate a row with + // resume token. + ResumeToken bool + // If Err != nil, mock server will return error in RPC response. + Err error +} + +// MockCloudSpanner is a mock implementation of SpannerServer interface. +// TODO: make MockCloudSpanner a full-fleged Cloud Spanner implementation. +type MockCloudSpanner struct { + s *grpc.Server + t *testing.T + addr string + msgs chan MockCtlMsg + readTs time.Time + next int +} + +// Addr returns the listening address of mock server. +func (m *MockCloudSpanner) Addr() string { + return m.addr +} + +// AddMsg generates a new mocked row which can be received by client. +func (m *MockCloudSpanner) AddMsg(err error, resumeToken bool) { + msg := MockCtlMsg{ + ResumeToken: resumeToken, + Err: err, + } + if err == io.EOF { + close(m.msgs) + } else { + m.msgs <- msg + } +} + +// Done signals an end to a mocked stream. +func (m *MockCloudSpanner) Done() { + close(m.msgs) +} + +// CreateSession is a placeholder for SpannerServer.CreateSession. +func (m *MockCloudSpanner) CreateSession(c context.Context, r *sppb.CreateSessionRequest) (*sppb.Session, error) { + m.t.Fatalf("CreateSession is unimplemented") + return nil, errors.New("Unimplemented") +} + +// GetSession is a placeholder for SpannerServer.GetSession. +func (m *MockCloudSpanner) GetSession(c context.Context, r *sppb.GetSessionRequest) (*sppb.Session, error) { + m.t.Fatalf("GetSession is unimplemented") + return nil, errors.New("Unimplemented") +} + +// DeleteSession is a placeholder for SpannerServer.DeleteSession. +func (m *MockCloudSpanner) DeleteSession(c context.Context, r *sppb.DeleteSessionRequest) (*empty.Empty, error) { + m.t.Fatalf("DeleteSession is unimplemented") + return nil, errors.New("Unimplemented") +} + +// ExecuteSql is a placeholder for SpannerServer.ExecuteSql. +func (m *MockCloudSpanner) ExecuteSql(c context.Context, r *sppb.ExecuteSqlRequest) (*sppb.ResultSet, error) { + m.t.Fatalf("ExecuteSql is unimplemented") + return nil, errors.New("Unimplemented") +} + +// EncodeResumeToken return mock resume token encoding for an uint64 integer. +func EncodeResumeToken(t uint64) []byte { + rt := make([]byte, 16) + binary.PutUvarint(rt, t) + return rt +} + +// DecodeResumeToken decodes a mock resume token into an uint64 integer. +func DecodeResumeToken(t []byte) (uint64, error) { + s, n := binary.Uvarint(t) + if n <= 0 { + return 0, fmt.Errorf("invalid resume token: %v", t) + } + return s, nil +} + +// ExecuteStreamingSql is a mock implementation of SpannerServer.ExecuteStreamingSql. +func (m *MockCloudSpanner) ExecuteStreamingSql(r *sppb.ExecuteSqlRequest, s sppb.Spanner_ExecuteStreamingSqlServer) error { + switch r.Sql { + case "SELECT * from t_unavailable": + return grpc.Errorf(codes.Unavailable, "mock table unavailable") + case "SELECT t.key key, t.value value FROM t_mock t": + if r.ResumeToken != nil { + s, err := DecodeResumeToken(r.ResumeToken) + if err != nil { + return err + } + m.next = int(s) + 1 + } + for { + msg, more := <-m.msgs + if !more { + break + } + if msg.Err == nil { + var rt []byte + if msg.ResumeToken { + rt = EncodeResumeToken(uint64(m.next)) + } + meta := KvMeta + meta.Transaction = &sppb.Transaction{ + ReadTimestamp: &pbt.Timestamp{ + Seconds: m.readTs.Unix(), + Nanos: int32(m.readTs.Nanosecond()), + }, + } + err := s.Send(&sppb.PartialResultSet{ + Metadata: &meta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: fmt.Sprintf("foo-%02d", m.next)}}, + {Kind: &proto3.Value_StringValue{StringValue: fmt.Sprintf("bar-%02d", m.next)}}, + }, + ResumeToken: rt, + }) + m.next = m.next + 1 + if err != nil { + return err + } + continue + } + return msg.Err + } + return nil + default: + return fmt.Errorf("unsupported SQL: %v", r.Sql) + } +} + +// Read is a placeholder for SpannerServer.Read. +func (m *MockCloudSpanner) Read(c context.Context, r *sppb.ReadRequest) (*sppb.ResultSet, error) { + m.t.Fatalf("Read is unimplemented") + return nil, errors.New("Unimplemented") +} + +// StreamingRead is a placeholder for SpannerServer.StreamingRead. +func (m *MockCloudSpanner) StreamingRead(r *sppb.ReadRequest, s sppb.Spanner_StreamingReadServer) error { + m.t.Fatalf("StreamingRead is unimplemented") + return errors.New("Unimplemented") +} + +// BeginTransaction is a placeholder for SpannerServer.BeginTransaction. +func (m *MockCloudSpanner) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest) (*sppb.Transaction, error) { + m.t.Fatalf("BeginTransaction is unimplemented") + return nil, errors.New("Unimplemented") +} + +// Commit is a placeholder for SpannerServer.Commit. +func (m *MockCloudSpanner) Commit(c context.Context, r *sppb.CommitRequest) (*sppb.CommitResponse, error) { + m.t.Fatalf("Commit is unimplemented") + return nil, errors.New("Unimplemented") +} + +// Rollback is a placeholder for SpannerServer.Rollback. +func (m *MockCloudSpanner) Rollback(c context.Context, r *sppb.RollbackRequest) (*empty.Empty, error) { + m.t.Fatalf("Rollback is unimplemented") + return nil, errors.New("Unimplemented") +} + +// Serve runs a MockCloudSpanner listening on a random localhost address. +func (m *MockCloudSpanner) Serve() { + m.s = grpc.NewServer() + if m.addr == "" { + m.addr = "localhost:0" + } + lis, err := net.Listen("tcp", m.addr) + if err != nil { + m.t.Fatalf("Failed to listen: %v", err) + } + go m.s.Serve(lis) + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + m.t.Fatalf("Failed to parse listener address: %v", err) + } + sppb.RegisterSpannerServer(m.s, m) + m.addr = "localhost:" + port +} + +// Stop terminates MockCloudSpanner and closes the serving port. +func (m *MockCloudSpanner) Stop() { + m.s.Stop() +} + +// NewMockCloudSpanner creates a new MockCloudSpanner instance. +func NewMockCloudSpanner(t *testing.T, ts time.Time) *MockCloudSpanner { + mcs := &MockCloudSpanner{ + t: t, + msgs: make(chan MockCtlMsg, 1000), + readTs: ts, + } + return mcs +} diff --git a/vendor/cloud.google.com/go/spanner/key.go b/vendor/cloud.google.com/go/spanner/key.go new file mode 100644 index 000000000..1b780deff --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/key.go @@ -0,0 +1,400 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "bytes" + "fmt" + "time" + + "google.golang.org/grpc/codes" + + "cloud.google.com/go/civil" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// A Key can be either a Cloud Spanner row's primary key or a secondary index key. +// It is essentially an interface{} array, which represents a set of Cloud Spanner +// columns. A Key type has the following usages: +// +// - Used as primary key which uniquely identifies a Cloud Spanner row. +// - Used as secondary index key which maps to a set of Cloud Spanner rows +// indexed under it. +// - Used as endpoints of primary key/secondary index ranges, +// see also the KeyRange type. +// +// Rows that are identified by the Key type are outputs of read operation or targets of +// delete operation in a mutation. Note that for Insert/Update/InsertOrUpdate/Update +// mutation types, although they don't require a primary key explicitly, the column list +// provided must contain enough columns that can comprise a primary key. +// +// Keys are easy to construct. For example, suppose you have a table with a +// primary key of username and product ID. To make a key for this table: +// +// key := spanner.Key{"john", 16} +// +// See the description of Row and Mutation types for how Go types are +// mapped to Cloud Spanner types. For convenience, Key type supports a wide range +// of Go types: +// - int, int8, int16, int32, int64, and NullInt64 are mapped to Cloud Spanner's INT64 type. +// - uint8, uint16 and uint32 are also mapped to Cloud Spanner's INT64 type. +// - float32, float64, NullFloat64 are mapped to Cloud Spanner's FLOAT64 type. +// - bool and NullBool are mapped to Cloud Spanner's BOOL type. +// - []byte is mapped to Cloud Spanner's BYTES type. +// - string and NullString are mapped to Cloud Spanner's STRING type. +// - time.Time and NullTime are mapped to Cloud Spanner's TIMESTAMP type. +// - civil.Date and NullDate are mapped to Cloud Spanner's DATE type. +type Key []interface{} + +// errInvdKeyPartType returns error for unsupported key part type. +func errInvdKeyPartType(part interface{}) error { + return spannerErrorf(codes.InvalidArgument, "key part has unsupported type %T", part) +} + +// keyPartValue converts a part of the Key (which is a valid Cloud Spanner type) +// into a proto3.Value. Used for encoding Key type into protobuf. +func keyPartValue(part interface{}) (pb *proto3.Value, err error) { + switch v := part.(type) { + case int: + pb, _, err = encodeValue(int64(v)) + case int8: + pb, _, err = encodeValue(int64(v)) + case int16: + pb, _, err = encodeValue(int64(v)) + case int32: + pb, _, err = encodeValue(int64(v)) + case uint8: + pb, _, err = encodeValue(int64(v)) + case uint16: + pb, _, err = encodeValue(int64(v)) + case uint32: + pb, _, err = encodeValue(int64(v)) + case float32: + pb, _, err = encodeValue(float64(v)) + case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate: + pb, _, err = encodeValue(v) + default: + return nil, errInvdKeyPartType(v) + } + return pb, err +} + +// proto converts a spanner.Key into a proto3.ListValue. +func (key Key) proto() (*proto3.ListValue, error) { + lv := &proto3.ListValue{} + lv.Values = make([]*proto3.Value, 0, len(key)) + for _, part := range key { + v, err := keyPartValue(part) + if err != nil { + return nil, err + } + lv.Values = append(lv.Values, v) + } + return lv, nil +} + +// keySetProto lets a single Key act as a KeySet. +func (key Key) keySetProto() (*sppb.KeySet, error) { + kp, err := key.proto() + if err != nil { + return nil, err + } + return &sppb.KeySet{Keys: []*proto3.ListValue{kp}}, nil +} + +// String implements fmt.Stringer for Key. For string, []byte and NullString, it +// prints the uninterpreted bytes of their contents, leaving caller with the +// opportunity to escape the output. +func (key Key) String() string { + b := &bytes.Buffer{} + fmt.Fprint(b, "(") + for i, part := range []interface{}(key) { + if i != 0 { + fmt.Fprint(b, ",") + } + switch v := part.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, bool: + // Use %v to print numeric types and bool. + fmt.Fprintf(b, "%v", v) + case string: + fmt.Fprintf(b, "%q", v) + case []byte: + if v != nil { + fmt.Fprintf(b, "%q", v) + } else { + fmt.Fprint(b, "<null>") + } + case NullInt64, NullFloat64, NullBool, NullString, NullTime, NullDate: + // The above types implement fmt.Stringer. + fmt.Fprintf(b, "%s", v) + case civil.Date: + fmt.Fprintf(b, "%q", v) + case time.Time: + fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano)) + default: + fmt.Fprintf(b, "%v", v) + } + } + fmt.Fprint(b, ")") + return b.String() +} + +// AsPrefix returns a KeyRange for all keys where k is the prefix. +func (k Key) AsPrefix() KeyRange { + return KeyRange{ + Start: k, + End: k, + Kind: ClosedClosed, + } +} + +// KeyRangeKind describes the kind of interval represented by a KeyRange: +// whether it is open or closed on the left and right. +type KeyRangeKind int + +const ( + // ClosedOpen is closed on the left and open on the right: the Start + // key is included, the End key is excluded. + ClosedOpen KeyRangeKind = iota + + // ClosedClosed is closed on the left and the right: both keys are included. + ClosedClosed + + // OpenClosed is open on the left and closed on the right: the Start + // key is excluded, the End key is included. + OpenClosed + + // OpenOpen is open on the left and the right: neither key is included. + OpenOpen +) + +// A KeyRange represents a range of rows in a table or index. +// +// A range has a Start key and an End key. IncludeStart and IncludeEnd +// indicate whether the Start and End keys are included in the range. +// +// For example, consider the following table definition: +// +// CREATE TABLE UserEvents ( +// UserName STRING(MAX), +// EventDate STRING(10), +// ) PRIMARY KEY(UserName, EventDate); +// +// The following keys name rows in this table: +// +// spanner.Key{"Bob", "2014-09-23"} +// spanner.Key{"Alfred", "2015-06-12"} +// +// Since the UserEvents table's PRIMARY KEY clause names two columns, each +// UserEvents key has two elements; the first is the UserName, and the second +// is the EventDate. +// +// Key ranges with multiple components are interpreted lexicographically by +// component using the table or index key's declared sort order. For example, +// the following range returns all events for user "Bob" that occurred in the +// year 2015: +// +// spanner.KeyRange{ +// Start: spanner.Key{"Bob", "2015-01-01"}, +// End: spanner.Key{"Bob", "2015-12-31"}, +// Kind: ClosedClosed, +// } +// +// Start and end keys can omit trailing key components. This affects the +// inclusion and exclusion of rows that exactly match the provided key +// components: if IncludeStart is true, then rows that exactly match the +// provided components of the Start key are included; if IncludeStart is false +// then rows that exactly match are not included. IncludeEnd and End key +// behave in the same fashion. +// +// For example, the following range includes all events for "Bob" that occurred +// during and after the year 2000: +// +// spanner.KeyRange{ +// Start: spanner.Key{"Bob", "2000-01-01"}, +// End: spanner.Key{"Bob"}, +// Kind: ClosedClosed, +// } +// +// The next example retrieves all events for "Bob": +// +// spanner.Key{"Bob"}.AsPrefix() +// +// To retrieve events before the year 2000: +// +// spanner.KeyRange{ +// Start: spanner.Key{"Bob"}, +// End: spanner.Key{"Bob", "2000-01-01"}, +// Kind: ClosedOpen, +// } +// +// Although we specified a Kind for this KeyRange, we didn't need to, because +// the default is ClosedOpen. In later examples we'll omit Kind if it is +// ClosedOpen. +// +// The following range includes all rows in a table or under a +// index: +// +// spanner.AllKeys() +// +// This range returns all users whose UserName begins with any +// character from A to C: +// +// spanner.KeyRange{ +// Start: spanner.Key{"A"}, +// End: spanner.Key{"D"}, +// } +// +// This range returns all users whose UserName begins with B: +// +// spanner.KeyRange{ +// Start: spanner.Key{"B"}, +// End: spanner.Key{"C"}, +// } +// +// Key ranges honor column sort order. For example, suppose a table is defined +// as follows: +// +// CREATE TABLE DescendingSortedTable { +// Key INT64, +// ... +// ) PRIMARY KEY(Key DESC); +// +// The following range retrieves all rows with key values between 1 and 100 +// inclusive: +// +// spanner.KeyRange{ +// Start: spanner.Key{100}, +// End: spanner.Key{1}, +// Kind: ClosedClosed, +// } +// +// Note that 100 is passed as the start, and 1 is passed as the end, because +// Key is a descending column in the schema. +type KeyRange struct { + // Start specifies the left boundary of the key range; End specifies + // the right boundary of the key range. + Start, End Key + + // Kind describes whether the boundaries of the key range include + // their keys. + Kind KeyRangeKind +} + +// String implements fmt.Stringer for KeyRange type. +func (r KeyRange) String() string { + var left, right string + switch r.Kind { + case ClosedClosed: + left, right = "[", "]" + case ClosedOpen: + left, right = "[", ")" + case OpenClosed: + left, right = "(", "]" + case OpenOpen: + left, right = "(", ")" + default: + left, right = "?", "?" + } + return fmt.Sprintf("%s%s,%s%s", left, r.Start, r.End, right) +} + +// proto converts KeyRange into sppb.KeyRange. +func (r KeyRange) proto() (*sppb.KeyRange, error) { + var err error + var start, end *proto3.ListValue + pb := &sppb.KeyRange{} + if start, err = r.Start.proto(); err != nil { + return nil, err + } + if end, err = r.End.proto(); err != nil { + return nil, err + } + if r.Kind == ClosedClosed || r.Kind == ClosedOpen { + pb.StartKeyType = &sppb.KeyRange_StartClosed{StartClosed: start} + } else { + pb.StartKeyType = &sppb.KeyRange_StartOpen{StartOpen: start} + } + if r.Kind == ClosedClosed || r.Kind == OpenClosed { + pb.EndKeyType = &sppb.KeyRange_EndClosed{EndClosed: end} + } else { + pb.EndKeyType = &sppb.KeyRange_EndOpen{EndOpen: end} + } + return pb, nil +} + +// keySetProto lets a KeyRange act as a KeySet. +func (r KeyRange) keySetProto() (*sppb.KeySet, error) { + rp, err := r.proto() + if err != nil { + return nil, err + } + return &sppb.KeySet{Ranges: []*sppb.KeyRange{rp}}, nil +} + +// A KeySet defines a collection of Cloud Spanner keys and/or key ranges. All the +// keys are expected to be in the same table or index. The keys need not be sorted in +// any particular way. +// +// An individual Key can act as a KeySet, as can a KeyRange. Use the KeySets function +// to create a KeySet consisting of multiple Keys and KeyRanges. To obtain an empty +// KeySet, call KeySets with no arguments. +// +// If the same key is specified multiple times in the set (for example if two +// ranges, two keys, or a key and a range overlap), the Cloud Spanner backend behaves +// as if the key were only specified once. +type KeySet interface { + keySetProto() (*sppb.KeySet, error) +} + +// AllKeys returns a KeySet that represents all Keys of a table or a index. +func AllKeys() KeySet { + return all{} +} + +type all struct{} + +func (all) keySetProto() (*sppb.KeySet, error) { + return &sppb.KeySet{All: true}, nil +} + +// KeySets returns the union of the KeySets. If any of the KeySets is AllKeys, then +// the resulting KeySet will be equivalent to AllKeys. +func KeySets(keySets ...KeySet) KeySet { + u := make(union, len(keySets)) + copy(u, keySets) + return u +} + +type union []KeySet + +func (u union) keySetProto() (*sppb.KeySet, error) { + upb := &sppb.KeySet{} + for _, ks := range u { + pb, err := ks.keySetProto() + if err != nil { + return nil, err + } + if pb.All { + return pb, nil + } + upb.Keys = append(upb.Keys, pb.Keys...) + upb.Ranges = append(upb.Ranges, pb.Ranges...) + } + return upb, nil +} diff --git a/vendor/cloud.google.com/go/spanner/key_test.go b/vendor/cloud.google.com/go/spanner/key_test.go new file mode 100644 index 000000000..e552e8647 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/key_test.go @@ -0,0 +1,373 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "reflect" + "testing" + "time" + + "cloud.google.com/go/civil" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// Test Key.String() and Key.proto(). +func TestKey(t *testing.T) { + tm, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z") + dt, _ := civil.ParseDate("2016-11-15") + for _, test := range []struct { + k Key + wantProto *proto3.ListValue + wantStr string + }{ + { + k: Key{int(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{int8(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{int16(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{int32(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{int64(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{uint8(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{uint16(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{uint32(1)}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{true}, + wantProto: listValueProto(boolProto(true)), + wantStr: "(true)", + }, + { + k: Key{float32(1.5)}, + wantProto: listValueProto(floatProto(1.5)), + wantStr: "(1.5)", + }, + { + k: Key{float64(1.5)}, + wantProto: listValueProto(floatProto(1.5)), + wantStr: "(1.5)", + }, + { + k: Key{"value"}, + wantProto: listValueProto(stringProto("value")), + wantStr: `("value")`, + }, + { + k: Key{[]byte(nil)}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{[]byte{}}, + wantProto: listValueProto(stringProto("")), + wantStr: `("")`, + }, + { + k: Key{tm}, + wantProto: listValueProto(stringProto("2016-11-15T15:04:05.999999999Z")), + wantStr: `("2016-11-15T15:04:05.999999999Z")`, + }, + {k: Key{dt}, + wantProto: listValueProto(stringProto("2016-11-15")), + wantStr: `("2016-11-15")`, + }, + { + k: Key{[]byte("value")}, + wantProto: listValueProto(bytesProto([]byte("value"))), + wantStr: `("value")`, + }, + { + k: Key{NullInt64{1, true}}, + wantProto: listValueProto(stringProto("1")), + wantStr: "(1)", + }, + { + k: Key{NullInt64{2, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{NullFloat64{1.5, true}}, + wantProto: listValueProto(floatProto(1.5)), + wantStr: "(1.5)", + }, + { + k: Key{NullFloat64{2.0, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{NullBool{true, true}}, + wantProto: listValueProto(boolProto(true)), + wantStr: "(true)", + }, + { + k: Key{NullBool{true, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{NullString{"value", true}}, + wantProto: listValueProto(stringProto("value")), + wantStr: `("value")`, + }, + { + k: Key{NullString{"value", false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{NullTime{tm, true}}, + wantProto: listValueProto(timeProto(tm)), + wantStr: `("2016-11-15T15:04:05.999999999Z")`, + }, + + { + k: Key{NullTime{time.Now(), false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{NullDate{dt, true}}, + wantProto: listValueProto(dateProto(dt)), + wantStr: `("2016-11-15")`, + }, + { + k: Key{NullDate{civil.Date{}, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "(<null>)", + }, + { + k: Key{int(1), NullString{"value", false}, "value", 1.5, true}, + wantProto: listValueProto(stringProto("1"), nullProto(), stringProto("value"), floatProto(1.5), boolProto(true)), + wantStr: `(1,<null>,"value",1.5,true)`, + }, + } { + if got := test.k.String(); got != test.wantStr { + t.Errorf("%v.String() = %v, want %v", test.k, got, test.wantStr) + } + gotProto, err := test.k.proto() + if err != nil { + t.Errorf("%v.proto() returns error %v; want nil error", test.k, err) + } + if !reflect.DeepEqual(gotProto, test.wantProto) { + t.Errorf("%v.proto() = \n%v\nwant:\n%v", test.k, gotProto, test.wantProto) + } + } +} + +// Test KeyRange.String() and KeyRange.proto(). +func TestKeyRange(t *testing.T) { + for _, test := range []struct { + kr KeyRange + wantProto *sppb.KeyRange + wantStr string + }{ + { + kr: KeyRange{Key{"A"}, Key{"D"}, OpenOpen}, + wantProto: &sppb.KeyRange{ + &sppb.KeyRange_StartOpen{listValueProto(stringProto("A"))}, + &sppb.KeyRange_EndOpen{listValueProto(stringProto("D"))}, + }, + wantStr: `(("A"),("D"))`, + }, + { + kr: KeyRange{Key{1}, Key{10}, OpenClosed}, + wantProto: &sppb.KeyRange{ + &sppb.KeyRange_StartOpen{listValueProto(stringProto("1"))}, + &sppb.KeyRange_EndClosed{listValueProto(stringProto("10"))}, + }, + wantStr: "((1),(10)]", + }, + { + kr: KeyRange{Key{1.5, 2.1, 0.2}, Key{1.9, 0.7}, ClosedOpen}, + wantProto: &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(floatProto(1.5), floatProto(2.1), floatProto(0.2))}, + &sppb.KeyRange_EndOpen{listValueProto(floatProto(1.9), floatProto(0.7))}, + }, + wantStr: "[(1.5,2.1,0.2),(1.9,0.7))", + }, + { + kr: KeyRange{Key{NullInt64{1, true}}, Key{10}, ClosedClosed}, + wantProto: &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(stringProto("1"))}, + &sppb.KeyRange_EndClosed{listValueProto(stringProto("10"))}, + }, + wantStr: "[(1),(10)]", + }, + } { + if got := test.kr.String(); got != test.wantStr { + t.Errorf("%v.String() = %v, want %v", test.kr, got, test.wantStr) + } + gotProto, err := test.kr.proto() + if err != nil { + t.Errorf("%v.proto() returns error %v; want nil error", test.kr, err) + } + if !reflect.DeepEqual(gotProto, test.wantProto) { + t.Errorf("%v.proto() = \n%v\nwant:\n%v", test.kr, gotProto.String(), test.wantProto.String()) + } + } +} + +func TestPrefixRange(t *testing.T) { + got := Key{1}.AsPrefix() + want := KeyRange{Start: Key{1}, End: Key{1}, Kind: ClosedClosed} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestKeySets(t *testing.T) { + int1 := intProto(1) + int2 := intProto(2) + int3 := intProto(3) + int4 := intProto(4) + for i, test := range []struct { + ks KeySet + wantProto *sppb.KeySet + }{ + { + KeySets(), + &sppb.KeySet{}, + }, + { + Key{4}, + &sppb.KeySet{ + Keys: []*proto3.ListValue{listValueProto(int4)}, + }, + }, + { + AllKeys(), + &sppb.KeySet{All: true}, + }, + { + KeySets(Key{1, 2}, Key{3, 4}), + &sppb.KeySet{ + Keys: []*proto3.ListValue{ + listValueProto(int1, int2), + listValueProto(int3, int4), + }, + }, + }, + { + KeyRange{Key{1}, Key{2}, ClosedOpen}, + &sppb.KeySet{Ranges: []*sppb.KeyRange{ + &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(int1)}, + &sppb.KeyRange_EndOpen{listValueProto(int2)}, + }, + }}, + }, + { + Key{2}.AsPrefix(), + &sppb.KeySet{Ranges: []*sppb.KeyRange{ + &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(int2)}, + &sppb.KeyRange_EndClosed{listValueProto(int2)}, + }, + }}, + }, + { + KeySets( + KeyRange{Key{1}, Key{2}, ClosedClosed}, + KeyRange{Key{3}, Key{4}, OpenClosed}, + ), + &sppb.KeySet{ + Ranges: []*sppb.KeyRange{ + &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(int1)}, + &sppb.KeyRange_EndClosed{listValueProto(int2)}, + }, + &sppb.KeyRange{ + &sppb.KeyRange_StartOpen{listValueProto(int3)}, + &sppb.KeyRange_EndClosed{listValueProto(int4)}, + }, + }, + }, + }, + { + KeySets( + Key{1}, + KeyRange{Key{2}, Key{3}, ClosedClosed}, + KeyRange{Key{4}, Key{5}, OpenClosed}, + KeySets(), + Key{6}), + &sppb.KeySet{ + Keys: []*proto3.ListValue{ + listValueProto(int1), + listValueProto(intProto(6)), + }, + Ranges: []*sppb.KeyRange{ + &sppb.KeyRange{ + &sppb.KeyRange_StartClosed{listValueProto(int2)}, + &sppb.KeyRange_EndClosed{listValueProto(int3)}, + }, + &sppb.KeyRange{ + &sppb.KeyRange_StartOpen{listValueProto(int4)}, + &sppb.KeyRange_EndClosed{listValueProto(intProto(5))}, + }, + }, + }, + }, + { + KeySets( + Key{1}, + KeyRange{Key{2}, Key{3}, ClosedClosed}, + AllKeys(), + KeyRange{Key{4}, Key{5}, OpenClosed}, + Key{6}), + &sppb.KeySet{All: true}, + }, + } { + gotProto, err := test.ks.keySetProto() + if err != nil { + t.Errorf("#%d: %v.proto() returns error %v; want nil error", i, test.ks, err) + } + if !reflect.DeepEqual(gotProto, test.wantProto) { + t.Errorf("#%d: %v.proto() = \n%v\nwant:\n%v", i, test.ks, gotProto.String(), test.wantProto.String()) + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/mutation.go b/vendor/cloud.google.com/go/spanner/mutation.go new file mode 100644 index 000000000..81f25746d --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/mutation.go @@ -0,0 +1,431 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "reflect" + + proto3 "github.com/golang/protobuf/ptypes/struct" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +// op is the mutation operation. +type op int + +const ( + // opDelete removes a row from a table. Succeeds whether or not the + // key was present. + opDelete op = iota + // opInsert inserts a row into a table. If the row already exists, the + // write or transaction fails. + opInsert + // opInsertOrUpdate inserts a row into a table. If the row already + // exists, it updates it instead. Any column values not explicitly + // written are preserved. + opInsertOrUpdate + // opReplace inserts a row into a table, deleting any existing row. + // Unlike InsertOrUpdate, this means any values not explicitly written + // become NULL. + opReplace + // opUpdate updates a row in a table. If the row does not already + // exist, the write or transaction fails. + opUpdate +) + +// A Mutation describes a modification to one or more Cloud Spanner rows. The +// mutation represents an insert, update, delete, etc on a table. +// +// Many mutations can be applied in a single atomic commit. For purposes of +// constraint checking (such as foreign key constraints), the operations can be +// viewed as applying in same order as the mutations are supplied in (so that +// e.g., a row and its logical "child" can be inserted in the same commit). +// +// - The Apply function applies series of mutations. +// - A ReadWriteTransaction applies a series of mutations as part of an +// atomic read-modify-write operation. +// Example: +// +// m := spanner.Insert("User", +// []string{"user_id", "profile"}, +// []interface{}{UserID, profile}) +// _, err := client.Apply(ctx, []*spanner.Mutation{m}) +// +// In this example, we insert a new row into the User table. The primary key +// for the new row is UserID (presuming that "user_id" has been declared as the +// primary key of the "User" table). +// +// Updating a row +// +// Changing the values of columns in an existing row is very similar to +// inserting a new row: +// +// m := spanner.Update("User", +// []string{"user_id", "profile"}, +// []interface{}{UserID, profile}) +// _, err := client.Apply(ctx, []*spanner.Mutation{m}) +// +// Deleting a row +// +// To delete a row, use spanner.Delete: +// +// m := spanner.Delete("User", spanner.Key{UserId}) +// _, err := client.Apply(ctx, []*spanner.Mutation{m}) +// +// spanner.Delete accepts a KeySet, so you can also pass in a KeyRange, or use the +// spanner.KeySets function to build any combination of Keys and KeyRanges. +// +// Note that deleting a row in a table may also delete rows from other tables +// if cascading deletes are specified in those tables' schemas. Delete does +// nothing if the named row does not exist (does not yield an error). +// +// Deleting a field +// +// To delete/clear a field within a row, use spanner.Update with the value nil: +// +// m := spanner.Update("User", +// []string{"user_id", "profile"}, +// []interface{}{UserID, nil}) +// _, err := client.Apply(ctx, []*spanner.Mutation{m}) +// +// The valid Go types and their corresponding Cloud Spanner types that can be +// used in the Insert/Update/InsertOrUpdate functions are: +// +// string, NullString - STRING +// []string, []NullString - STRING ARRAY +// []byte - BYTES +// [][]byte - BYTES ARRAY +// int, int64, NullInt64 - INT64 +// []int, []int64, []NullInt64 - INT64 ARRAY +// bool, NullBool - BOOL +// []bool, []NullBool - BOOL ARRAY +// float64, NullFloat64 - FLOAT64 +// []float64, []NullFloat64 - FLOAT64 ARRAY +// time.Time, NullTime - TIMESTAMP +// []time.Time, []NullTime - TIMESTAMP ARRAY +// Date, NullDate - DATE +// []Date, []NullDate - DATE ARRAY +// +// To compare two Mutations for testing purposes, use reflect.DeepEqual. +type Mutation struct { + // op is the operation type of the mutation. + // See documentation for spanner.op for more details. + op op + // Table is the name of the taget table to be modified. + table string + // keySet is a set of primary keys that names the rows + // in a delete operation. + keySet KeySet + // columns names the set of columns that are going to be + // modified by Insert, InsertOrUpdate, Replace or Update + // operations. + columns []string + // values specifies the new values for the target columns + // named by Columns. + values []interface{} +} + +// mapToMutationParams converts Go map into mutation parameters. +func mapToMutationParams(in map[string]interface{}) ([]string, []interface{}) { + cols := []string{} + vals := []interface{}{} + for k, v := range in { + cols = append(cols, k) + vals = append(vals, v) + } + return cols, vals +} + +// errNotStruct returns error for not getting a go struct type. +func errNotStruct(in interface{}) error { + return spannerErrorf(codes.InvalidArgument, "%T is not a go struct type", in) +} + +// structToMutationParams converts Go struct into mutation parameters. +// If the input is not a valid Go struct type, structToMutationParams +// returns error. +func structToMutationParams(in interface{}) ([]string, []interface{}, error) { + if in == nil { + return nil, nil, errNotStruct(in) + } + v := reflect.ValueOf(in) + t := v.Type() + if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + // t is a pointer to a struct. + if v.IsNil() { + // Return empty results. + return nil, nil, nil + } + // Get the struct value that in points to. + v = v.Elem() + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, nil, errNotStruct(in) + } + fields, err := fieldCache.Fields(t) + if err != nil { + return nil, nil, toSpannerError(err) + } + var cols []string + var vals []interface{} + for _, f := range fields { + cols = append(cols, f.Name) + vals = append(vals, v.FieldByIndex(f.Index).Interface()) + } + return cols, vals, nil +} + +// Insert returns a Mutation to insert a row into a table. If the row already +// exists, the write or transaction fails. +func Insert(table string, cols []string, vals []interface{}) *Mutation { + return &Mutation{ + op: opInsert, + table: table, + columns: cols, + values: vals, + } +} + +// InsertMap returns a Mutation to insert a row into a table, specified by +// a map of column name to value. If the row already exists, the write or +// transaction fails. +func InsertMap(table string, in map[string]interface{}) *Mutation { + cols, vals := mapToMutationParams(in) + return Insert(table, cols, vals) +} + +// InsertStruct returns a Mutation to insert a row into a table, specified by +// a Go struct. If the row already exists, the write or transaction fails. +// +// The in argument must be a struct or a pointer to a struct. Its exported +// fields specify the column names and values. Use a field tag like "spanner:name" +// to provide an alternative column name, or use "spanner:-" to ignore the field. +func InsertStruct(table string, in interface{}) (*Mutation, error) { + cols, vals, err := structToMutationParams(in) + if err != nil { + return nil, err + } + return Insert(table, cols, vals), nil +} + +// Update returns a Mutation to update a row in a table. If the row does not +// already exist, the write or transaction fails. +func Update(table string, cols []string, vals []interface{}) *Mutation { + return &Mutation{ + op: opUpdate, + table: table, + columns: cols, + values: vals, + } +} + +// UpdateMap returns a Mutation to update a row in a table, specified by +// a map of column to value. If the row does not already exist, the write or +// transaction fails. +func UpdateMap(table string, in map[string]interface{}) *Mutation { + cols, vals := mapToMutationParams(in) + return Update(table, cols, vals) +} + +// UpdateStruct returns a Mutation to update a row in a table, specified by a Go +// struct. If the row does not already exist, the write or transaction fails. +func UpdateStruct(table string, in interface{}) (*Mutation, error) { + cols, vals, err := structToMutationParams(in) + if err != nil { + return nil, err + } + return Update(table, cols, vals), nil +} + +// InsertOrUpdate returns a Mutation to insert a row into a table. If the row +// already exists, it updates it instead. Any column values not explicitly +// written are preserved. +// +// For a similar example, See Update. +func InsertOrUpdate(table string, cols []string, vals []interface{}) *Mutation { + return &Mutation{ + op: opInsertOrUpdate, + table: table, + columns: cols, + values: vals, + } +} + +// InsertOrUpdateMap returns a Mutation to insert a row into a table, +// specified by a map of column to value. If the row already exists, it +// updates it instead. Any column values not explicitly written are preserved. +// +// For a similar example, See UpdateMap. +func InsertOrUpdateMap(table string, in map[string]interface{}) *Mutation { + cols, vals := mapToMutationParams(in) + return InsertOrUpdate(table, cols, vals) +} + +// InsertOrUpdateStruct returns a Mutation to insert a row into a table, +// specified by a Go struct. If the row already exists, it updates it instead. +// Any column values not explicitly written are preserved. +// +// The in argument must be a struct or a pointer to a struct. Its exported +// fields specify the column names and values. Use a field tag like "spanner:name" +// to provide an alternative column name, or use "spanner:-" to ignore the field. +// +// For a similar example, See UpdateStruct. +func InsertOrUpdateStruct(table string, in interface{}) (*Mutation, error) { + cols, vals, err := structToMutationParams(in) + if err != nil { + return nil, err + } + return InsertOrUpdate(table, cols, vals), nil +} + +// Replace returns a Mutation to insert a row into a table, deleting any +// existing row. Unlike InsertOrUpdate, this means any values not explicitly +// written become NULL. +// +// For a similar example, See Update. +func Replace(table string, cols []string, vals []interface{}) *Mutation { + return &Mutation{ + op: opReplace, + table: table, + columns: cols, + values: vals, + } +} + +// ReplaceMap returns a Mutation to insert a row into a table, deleting any +// existing row. Unlike InsertOrUpdateMap, this means any values not explicitly +// written become NULL. The row is specified by a map of column to value. +// +// For a similar example, See UpdateMap. +func ReplaceMap(table string, in map[string]interface{}) *Mutation { + cols, vals := mapToMutationParams(in) + return Replace(table, cols, vals) +} + +// ReplaceStruct returns a Mutation to insert a row into a table, deleting any +// existing row. Unlike InsertOrUpdateMap, this means any values not explicitly +// written become NULL. The row is specified by a Go struct. +// +// The in argument must be a struct or a pointer to a struct. Its exported +// fields specify the column names and values. Use a field tag like "spanner:name" +// to provide an alternative column name, or use "spanner:-" to ignore the field. +// +// For a similar example, See UpdateStruct. +func ReplaceStruct(table string, in interface{}) (*Mutation, error) { + cols, vals, err := structToMutationParams(in) + if err != nil { + return nil, err + } + return Replace(table, cols, vals), nil +} + +// Delete removes the rows described by the KeySet from the table. It succeeds +// whether or not the keys were present. +func Delete(table string, ks KeySet) *Mutation { + return &Mutation{ + op: opDelete, + table: table, + keySet: ks, + } +} + +// prepareWrite generates sppb.Mutation_Write from table name, column names +// and new column values. +func prepareWrite(table string, columns []string, vals []interface{}) (*sppb.Mutation_Write, error) { + v, err := encodeValueArray(vals) + if err != nil { + return nil, err + } + return &sppb.Mutation_Write{ + Table: table, + Columns: columns, + Values: []*proto3.ListValue{v}, + }, nil +} + +// errInvdMutationOp returns error for unrecognized mutation operation. +func errInvdMutationOp(m Mutation) error { + return spannerErrorf(codes.InvalidArgument, "Unknown op type: %d", m.op) +} + +// proto converts spanner.Mutation to sppb.Mutation, in preparation to send +// RPCs. +func (m Mutation) proto() (*sppb.Mutation, error) { + var pb *sppb.Mutation + switch m.op { + case opDelete: + var kp *sppb.KeySet + if m.keySet != nil { + var err error + kp, err = m.keySet.keySetProto() + if err != nil { + return nil, err + } + } + pb = &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: m.table, + KeySet: kp, + }, + }, + } + case opInsert: + w, err := prepareWrite(m.table, m.columns, m.values) + if err != nil { + return nil, err + } + pb = &sppb.Mutation{Operation: &sppb.Mutation_Insert{Insert: w}} + case opInsertOrUpdate: + w, err := prepareWrite(m.table, m.columns, m.values) + if err != nil { + return nil, err + } + pb = &sppb.Mutation{Operation: &sppb.Mutation_InsertOrUpdate{InsertOrUpdate: w}} + case opReplace: + w, err := prepareWrite(m.table, m.columns, m.values) + if err != nil { + return nil, err + } + pb = &sppb.Mutation{Operation: &sppb.Mutation_Replace{Replace: w}} + case opUpdate: + w, err := prepareWrite(m.table, m.columns, m.values) + if err != nil { + return nil, err + } + pb = &sppb.Mutation{Operation: &sppb.Mutation_Update{Update: w}} + default: + return nil, errInvdMutationOp(m) + } + return pb, nil +} + +// mutationsProto turns a spanner.Mutation array into a sppb.Mutation array, +// it is convenient for sending batch mutations to Cloud Spanner. +func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) { + l := make([]*sppb.Mutation, 0, len(ms)) + for _, m := range ms { + pb, err := m.proto() + if err != nil { + return nil, err + } + l = append(l, pb) + } + return l, nil +} diff --git a/vendor/cloud.google.com/go/spanner/mutation_test.go b/vendor/cloud.google.com/go/spanner/mutation_test.go new file mode 100644 index 000000000..795db7c89 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/mutation_test.go @@ -0,0 +1,543 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "reflect" + "sort" + "strings" + "testing" + + proto3 "github.com/golang/protobuf/ptypes/struct" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// keysetProto returns protobuf encoding of valid spanner.KeySet. +func keysetProto(t *testing.T, ks KeySet) *sppb.KeySet { + k, err := ks.keySetProto() + if err != nil { + t.Fatalf("cannot convert keyset %v to protobuf: %v", ks, err) + } + return k +} + +// Test encoding from spanner.Mutation to protobuf. +func TestMutationToProto(t *testing.T) { + for i, test := range []struct { + m *Mutation + want *sppb.Mutation + }{ + // Delete Mutation + { + &Mutation{opDelete, "t_foo", Key{"foo"}, nil, nil}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_foo", + KeySet: keysetProto(t, Key{"foo"}), + }, + }, + }, + }, + // Insert Mutation + { + &Mutation{opInsert, "t_foo", KeySets(), []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_foo", + Columns: []string{"col1", "col2"}, + Values: []*proto3.ListValue{ + &proto3.ListValue{ + Values: []*proto3.Value{intProto(1), intProto(2)}, + }, + }, + }, + }, + }, + }, + // InsertOrUpdate Mutation + { + &Mutation{opInsertOrUpdate, "t_foo", KeySets(), []string{"col1", "col2"}, []interface{}{1.0, 2.0}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{ + Table: "t_foo", + Columns: []string{"col1", "col2"}, + Values: []*proto3.ListValue{ + &proto3.ListValue{ + Values: []*proto3.Value{floatProto(1.0), floatProto(2.0)}, + }, + }, + }, + }, + }, + }, + // Replace Mutation + { + &Mutation{opReplace, "t_foo", KeySets(), []string{"col1", "col2"}, []interface{}{"one", 2.0}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Replace{ + Replace: &sppb.Mutation_Write{ + Table: "t_foo", + Columns: []string{"col1", "col2"}, + Values: []*proto3.ListValue{ + &proto3.ListValue{ + Values: []*proto3.Value{stringProto("one"), floatProto(2.0)}, + }, + }, + }, + }, + }, + }, + // Update Mutation + { + &Mutation{opUpdate, "t_foo", KeySets(), []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_foo", + Columns: []string{"col1", "col2"}, + Values: []*proto3.ListValue{ + &proto3.ListValue{ + Values: []*proto3.Value{stringProto("one"), nullProto()}, + }, + }, + }, + }, + }, + }, + } { + if got, err := test.m.proto(); err != nil || !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: (%#v).proto() = (%v, %v), want (%v, nil)", i, test.m, got, err, test.want) + } + } +} + +// mutationColumnSorter implements sort.Interface for sorting column-value pairs in a Mutation by column names. +type mutationColumnSorter struct { + Mutation +} + +// newMutationColumnSorter creates new instance of mutationColumnSorter by duplicating the input Mutation so that +// sorting won't change the input Mutation. +func newMutationColumnSorter(m *Mutation) *mutationColumnSorter { + return &mutationColumnSorter{ + Mutation{ + m.op, + m.table, + m.keySet, + append([]string(nil), m.columns...), + append([]interface{}(nil), m.values...), + }, + } +} + +// Len implements sort.Interface.Len. +func (ms *mutationColumnSorter) Len() int { + return len(ms.columns) +} + +// Swap implements sort.Interface.Swap. +func (ms *mutationColumnSorter) Swap(i, j int) { + ms.columns[i], ms.columns[j] = ms.columns[j], ms.columns[i] + ms.values[i], ms.values[j] = ms.values[j], ms.values[i] +} + +// Less implements sort.Interface.Less. +func (ms *mutationColumnSorter) Less(i, j int) bool { + return strings.Compare(ms.columns[i], ms.columns[j]) < 0 +} + +// mutationEqual returns true if two mutations in question are equal +// to each other. +func mutationEqual(t *testing.T, m1, m2 Mutation) bool { + // Two mutations are considered to be equal even if their column values have different + // orders. + ms1 := newMutationColumnSorter(&m1) + ms2 := newMutationColumnSorter(&m2) + sort.Sort(ms1) + sort.Sort(ms2) + return reflect.DeepEqual(ms1, ms2) +} + +// Test helper functions which help to generate spanner.Mutation. +func TestMutationHelpers(t *testing.T) { + for _, test := range []struct { + m string + got *Mutation + want *Mutation + }{ + { + "Insert", + Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}), + &Mutation{opInsert, "t_foo", nil, []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}}, + }, + { + "InsertMap", + InsertMap("t_foo", map[string]interface{}{"col1": int64(1), "col2": int64(2)}), + &Mutation{opInsert, "t_foo", nil, []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}}, + }, + { + "InsertStruct", + func() *Mutation { + m, err := InsertStruct( + "t_foo", + struct { + notCol bool + Col1 int64 `spanner:"col1"` + Col2 int64 `spanner:"col2"` + }{false, int64(1), int64(2)}, + ) + if err != nil { + t.Errorf("cannot convert struct into mutation: %v", err) + } + return m + }(), + &Mutation{opInsert, "t_foo", nil, []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}}, + }, + { + "Update", + Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}), + &Mutation{opUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}}, + }, + { + "UpdateMap", + UpdateMap("t_foo", map[string]interface{}{"col1": "one", "col2": []byte(nil)}), + &Mutation{opUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}}, + }, + { + "UpdateStruct", + func() *Mutation { + m, err := UpdateStruct( + "t_foo", + struct { + Col1 string `spanner:"col1"` + notCol int + Col2 []byte `spanner:"col2"` + }{"one", 1, nil}, + ) + if err != nil { + t.Errorf("cannot convert struct into mutation: %v", err) + } + return m + }(), + &Mutation{opUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}}, + }, + { + "InsertOrUpdate", + InsertOrUpdate("t_foo", []string{"col1", "col2"}, []interface{}{1.0, 2.0}), + &Mutation{opInsertOrUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{1.0, 2.0}}, + }, + { + "InsertOrUpdateMap", + InsertOrUpdateMap("t_foo", map[string]interface{}{"col1": 1.0, "col2": 2.0}), + &Mutation{opInsertOrUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{1.0, 2.0}}, + }, + { + "InsertOrUpdateStruct", + func() *Mutation { + m, err := InsertOrUpdateStruct( + "t_foo", + struct { + Col1 float64 `spanner:"col1"` + Col2 float64 `spanner:"col2"` + notCol float64 + }{1.0, 2.0, 3.0}, + ) + if err != nil { + t.Errorf("cannot convert struct into mutation: %v", err) + } + return m + }(), + &Mutation{opInsertOrUpdate, "t_foo", nil, []string{"col1", "col2"}, []interface{}{1.0, 2.0}}, + }, + { + "Replace", + Replace("t_foo", []string{"col1", "col2"}, []interface{}{"one", 2.0}), + &Mutation{opReplace, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", 2.0}}, + }, + { + "ReplaceMap", + ReplaceMap("t_foo", map[string]interface{}{"col1": "one", "col2": 2.0}), + &Mutation{opReplace, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", 2.0}}, + }, + { + "ReplaceStruct", + func() *Mutation { + m, err := ReplaceStruct( + "t_foo", + struct { + Col1 string `spanner:"col1"` + Col2 float64 `spanner:"col2"` + notCol string + }{"one", 2.0, "foo"}, + ) + if err != nil { + t.Errorf("cannot convert struct into mutation: %v", err) + } + return m + }(), + &Mutation{opReplace, "t_foo", nil, []string{"col1", "col2"}, []interface{}{"one", 2.0}}, + }, + { + "Delete", + Delete("t_foo", Key{"foo"}), + &Mutation{opDelete, "t_foo", Key{"foo"}, nil, nil}, + }, + { + "DeleteRange", + Delete("t_foo", KeyRange{Key{"bar"}, Key{"foo"}, ClosedClosed}), + &Mutation{opDelete, "t_foo", KeyRange{Key{"bar"}, Key{"foo"}, ClosedClosed}, nil, nil}, + }, + } { + if !mutationEqual(t, *test.got, *test.want) { + t.Errorf("%v: got Mutation %v, want %v", test.m, test.got, test.want) + } + } +} + +// Test encoding non-struct types by using *Struct helpers. +func TestBadStructs(t *testing.T) { + val := "i_am_not_a_struct" + wantErr := errNotStruct(val) + if _, gotErr := InsertStruct("t_test", val); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("InsertStruct(%q) returns error %v, want %v", val, gotErr, wantErr) + } + if _, gotErr := InsertOrUpdateStruct("t_test", val); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("InsertOrUpdateStruct(%q) returns error %v, want %v", val, gotErr, wantErr) + } + if _, gotErr := UpdateStruct("t_test", val); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("UpdateStruct(%q) returns error %v, want %v", val, gotErr, wantErr) + } + if _, gotErr := ReplaceStruct("t_test", val); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("ReplaceStruct(%q) returns error %v, want %v", val, gotErr, wantErr) + } +} + +// Test encoding Mutation into proto. +func TestEncodeMutation(t *testing.T) { + for _, test := range []struct { + name string + mutation Mutation + wantProto *sppb.Mutation + wantErr error + }{ + { + "OpDelete", + Mutation{opDelete, "t_test", Key{1}, nil, nil}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{listValueProto(intProto(1))}, + }, + }, + }, + }, + nil, + }, + { + "OpDelete - Key error", + Mutation{opDelete, "t_test", Key{struct{}{}}, nil, nil}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{}, + }, + }, + }, + errInvdKeyPartType(struct{}{}), + }, + { + "OpInsert", + Mutation{opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + }, + }, + }, + nil, + }, + { + "OpInsert - Value Type Error", + Mutation{opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{struct{}{}, 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{}, + }, + }, + errEncoderUnsupportedType(struct{}{}), + }, + { + "OpInsertOrUpdate", + Mutation{opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + }, + }, + }, + nil, + }, + { + "OpInsertOrUpdate - Value Type Error", + Mutation{opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{struct{}{}, 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{}, + }, + }, + errEncoderUnsupportedType(struct{}{}), + }, + { + "OpReplace", + Mutation{opReplace, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Replace{ + Replace: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + }, + }, + }, + nil, + }, + { + "OpReplace - Value Type Error", + Mutation{opReplace, "t_test", nil, []string{"key", "val"}, []interface{}{struct{}{}, 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Replace{ + Replace: &sppb.Mutation_Write{}, + }, + }, + errEncoderUnsupportedType(struct{}{}), + }, + { + "OpUpdate", + Mutation{opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + }, + }, + }, + nil, + }, + { + "OpUpdate - Value Type Error", + Mutation{opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{struct{}{}, 1}}, + &sppb.Mutation{ + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{}, + }, + }, + errEncoderUnsupportedType(struct{}{}), + }, + { + "OpKnown - Unknown Mutation Operation Code", + Mutation{op(100), "t_test", nil, nil, nil}, + &sppb.Mutation{}, + errInvdMutationOp(Mutation{op(100), "t_test", nil, nil, nil}), + }, + } { + gotProto, gotErr := test.mutation.proto() + if gotErr != nil { + if !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%s: %v.proto() returns error %v, want %v", test.name, test.mutation, gotErr, test.wantErr) + } + continue + } + if !reflect.DeepEqual(gotProto, test.wantProto) { + t.Errorf("%s: %v.proto() = (%v, nil), want (%v, nil)", test.name, test.mutation, gotProto, test.wantProto) + } + } +} + +// Test Encoding an array of mutations. +func TestEncodeMutationArray(t *testing.T) { + for _, test := range []struct { + name string + ms []*Mutation + want []*sppb.Mutation + wantErr error + }{ + { + "Multiple Mutations", + []*Mutation{ + &Mutation{opDelete, "t_test", Key{"bar"}, nil, nil}, + &Mutation{opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + }, + []*sppb.Mutation{ + &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{listValueProto(stringProto("bar"))}, + }, + }, + }, + }, + &sppb.Mutation{ + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + }, + }, + }, + }, + nil, + }, + { + "Multiple Mutations - Bad Mutation", + []*Mutation{ + &Mutation{opDelete, "t_test", Key{"bar"}, nil, nil}, + &Mutation{opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", struct{}{}}}, + }, + []*sppb.Mutation{}, + errEncoderUnsupportedType(struct{}{}), + }, + } { + gotProto, gotErr := mutationsProto(test.ms) + if gotErr != nil { + if !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: mutationsProto(%v) returns error %v, want %v", test.name, test.ms, gotErr, test.wantErr) + } + continue + } + if !reflect.DeepEqual(gotProto, test.want) { + t.Errorf("%v: mutationsProto(%v) = (%v, nil), want (%v, nil)", test.name, test.ms, gotProto, test.want) + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/protoutils.go b/vendor/cloud.google.com/go/spanner/protoutils.go new file mode 100644 index 000000000..df12432d5 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/protoutils.go @@ -0,0 +1,113 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "encoding/base64" + "strconv" + "time" + + "cloud.google.com/go/civil" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// Helpers to generate protobuf values and Cloud Spanner types. + +func stringProto(s string) *proto3.Value { + return &proto3.Value{Kind: stringKind(s)} +} + +func stringKind(s string) *proto3.Value_StringValue { + return &proto3.Value_StringValue{StringValue: s} +} + +func stringType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_STRING} +} + +func boolProto(b bool) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_BoolValue{BoolValue: b}} +} + +func boolType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_BOOL} +} + +func intProto(n int64) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: strconv.FormatInt(n, 10)}} +} + +func intType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_INT64} +} + +func floatProto(n float64) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_NumberValue{NumberValue: n}} +} + +func floatType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_FLOAT64} +} + +func bytesProto(b []byte) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}} +} + +func bytesType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_BYTES} +} + +func timeProto(t time.Time) *proto3.Value { + return stringProto(t.UTC().Format(time.RFC3339Nano)) +} + +func timeType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_TIMESTAMP} +} + +func dateProto(d civil.Date) *proto3.Value { + return stringProto(d.String()) +} + +func dateType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_DATE} +} + +func listProto(p ...*proto3.Value) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_ListValue{ListValue: &proto3.ListValue{Values: p}}} +} + +func listValueProto(p ...*proto3.Value) *proto3.ListValue { + return &proto3.ListValue{Values: p} +} + +func listType(t *sppb.Type) *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_ARRAY, ArrayElementType: t} +} + +func mkField(n string, t *sppb.Type) *sppb.StructType_Field { + return &sppb.StructType_Field{n, t} +} + +func structType(fields ...*sppb.StructType_Field) *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_STRUCT, StructType: &sppb.StructType{Fields: fields}} +} + +func nullProto() *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_NullValue{NullValue: proto3.NullValue_NULL_VALUE}} +} diff --git a/vendor/cloud.google.com/go/spanner/read.go b/vendor/cloud.google.com/go/spanner/read.go new file mode 100644 index 000000000..5d733f1ca --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/read.go @@ -0,0 +1,685 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "bytes" + "io" + "sync/atomic" + "time" + + log "github.com/golang/glog" + proto "github.com/golang/protobuf/proto" + proto3 "github.com/golang/protobuf/ptypes/struct" + "golang.org/x/net/context" + + "google.golang.org/api/iterator" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +// streamingReceiver is the interface for receiving data from a client side +// stream. +type streamingReceiver interface { + Recv() (*sppb.PartialResultSet, error) +} + +// errEarlyReadEnd returns error for read finishes when gRPC stream is still active. +func errEarlyReadEnd() error { + return spannerErrorf(codes.FailedPrecondition, "read completed with active stream") +} + +// stream is the internal fault tolerant method for streaming data from +// Cloud Spanner. +func stream(ctx context.Context, rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error), setTimestamp func(time.Time), release func(error)) *RowIterator { + ctx, cancel := context.WithCancel(ctx) + return &RowIterator{ + streamd: newResumableStreamDecoder(ctx, rpc), + rowd: &partialResultSetDecoder{}, + setTimestamp: setTimestamp, + release: release, + cancel: cancel, + } +} + +// RowIterator is an iterator over Rows. +type RowIterator struct { + streamd *resumableStreamDecoder + rowd *partialResultSetDecoder + setTimestamp func(time.Time) + release func(error) + cancel func() + err error + rows []*Row +} + +// Next returns the next result. Its second return value is iterator.Done if +// there are no more results. Once Next returns Done, all subsequent calls +// will return Done. +func (r *RowIterator) Next() (*Row, error) { + if r.err != nil { + return nil, r.err + } + for len(r.rows) == 0 && r.streamd.next() { + r.rows, r.err = r.rowd.add(r.streamd.get()) + if r.err != nil { + return nil, r.err + } + if !r.rowd.ts.IsZero() && r.setTimestamp != nil { + r.setTimestamp(r.rowd.ts) + r.setTimestamp = nil + } + } + if len(r.rows) > 0 { + row := r.rows[0] + r.rows = r.rows[1:] + return row, nil + } + if err := r.streamd.lastErr(); err != nil { + r.err = toSpannerError(err) + } else if !r.rowd.done() { + r.err = errEarlyReadEnd() + } else { + r.err = iterator.Done + } + return nil, r.err +} + +// Do calls the provided function once in sequence for each row in the iteration. If the +// function returns a non-nil error, Do immediately returns that error. +// +// If there are no rows in the iterator, Do will return nil without calling the +// provided function. +// +// Do always calls Stop on the iterator. +func (r *RowIterator) Do(f func(r *Row) error) error { + defer r.Stop() + for { + row, err := r.Next() + switch err { + case iterator.Done: + return nil + case nil: + if err = f(row); err != nil { + return err + } + default: + return err + } + } +} + +// Stop terminates the iteration. It should be called after every iteration. +func (r *RowIterator) Stop() { + if r.cancel != nil { + r.cancel() + } + if r.release != nil { + r.release(r.err) + if r.err == nil { + r.err = spannerErrorf(codes.FailedPrecondition, "Next called after Stop") + } + r.release = nil + + } +} + +// partialResultQueue implements a simple FIFO queue. The zero value is a +// valid queue. +type partialResultQueue struct { + q []*sppb.PartialResultSet + first int + last int + n int // number of elements in queue +} + +// empty returns if the partialResultQueue is empty. +func (q *partialResultQueue) empty() bool { + return q.n == 0 +} + +// errEmptyQueue returns error for dequeuing an empty queue. +func errEmptyQueue() error { + return spannerErrorf(codes.OutOfRange, "empty partialResultQueue") +} + +// peekLast returns the last item in partialResultQueue; if the queue +// is empty, it returns error. +func (q *partialResultQueue) peekLast() (*sppb.PartialResultSet, error) { + if q.empty() { + return nil, errEmptyQueue() + } + return q.q[(q.last+cap(q.q)-1)%cap(q.q)], nil +} + +// push adds an item to the tail of partialResultQueue. +func (q *partialResultQueue) push(r *sppb.PartialResultSet) { + if q.q == nil { + q.q = make([]*sppb.PartialResultSet, 8 /* arbitrary */) + } + if q.n == cap(q.q) { + buf := make([]*sppb.PartialResultSet, cap(q.q)*2) + for i := 0; i < q.n; i++ { + buf[i] = q.q[(q.first+i)%cap(q.q)] + } + q.q = buf + q.first = 0 + q.last = q.n + } + q.q[q.last] = r + q.last = (q.last + 1) % cap(q.q) + q.n++ +} + +// pop removes an item from the head of partialResultQueue and returns +// it. +func (q *partialResultQueue) pop() *sppb.PartialResultSet { + if q.n == 0 { + return nil + } + r := q.q[q.first] + q.q[q.first] = nil + q.first = (q.first + 1) % cap(q.q) + q.n-- + return r +} + +// clear empties partialResultQueue. +func (q *partialResultQueue) clear() { + *q = partialResultQueue{} +} + +// dump retrieves all items from partialResultQueue and return them in a slice. +// It is used only in tests. +func (q *partialResultQueue) dump() []*sppb.PartialResultSet { + var dq []*sppb.PartialResultSet + for i := q.first; len(dq) < q.n; i = (i + 1) % cap(q.q) { + dq = append(dq, q.q[i]) + } + return dq +} + +// resumableStreamDecoderState encodes resumableStreamDecoder's status. +// See also the comments for resumableStreamDecoder.Next. +type resumableStreamDecoderState int + +const ( + unConnected resumableStreamDecoderState = iota // 0 + queueingRetryable // 1 + queueingUnretryable // 2 + aborted // 3 + finished // 4 +) + +// resumableStreamDecoder provides a resumable interface for receiving +// sppb.PartialResultSet(s) from a given query wrapped by +// resumableStreamDecoder.rpc(). +type resumableStreamDecoder struct { + // state is the current status of resumableStreamDecoder, see also + // the comments for resumableStreamDecoder.Next. + state resumableStreamDecoderState + // stateWitness when non-nil is called to observe state change, + // used for testing. + stateWitness func(resumableStreamDecoderState) + // ctx is the caller's context, used for cancel/timeout Next(). + ctx context.Context + // rpc is a factory of streamingReceiver, which might resume + // a pervious stream from the point encoded in restartToken. + // rpc is always a wrapper of a Cloud Spanner query which is + // resumable. + rpc func(ctx context.Context, restartToken []byte) (streamingReceiver, error) + // stream is the current RPC streaming receiver. + stream streamingReceiver + // q buffers received yet undecoded partial results. + q partialResultQueue + // bytesBetweenResumeTokens is the proxy of the byte size of PartialResultSets being queued + // between two resume tokens. Once bytesBetweenResumeTokens is greater than + // maxBytesBetweenResumeTokens, resumableStreamDecoder goes into queueingUnretryable state. + bytesBetweenResumeTokens int32 + // maxBytesBetweenResumeTokens is the max number of bytes that can be buffered + // between two resume tokens. It is always copied from the global maxBytesBetweenResumeTokens + // atomically. + maxBytesBetweenResumeTokens int32 + // np is the next sppb.PartialResultSet ready to be returned + // to caller of resumableStreamDecoder.Get(). + np *sppb.PartialResultSet + // resumeToken stores the resume token that resumableStreamDecoder has + // last revealed to caller. + resumeToken []byte + // retryCount is the number of retries that have been carried out so far + retryCount int + // err is the last error resumableStreamDecoder has encountered so far. + err error + // backoff to compute delays between retries. + backoff exponentialBackoff +} + +// newResumableStreamDecoder creates a new resumeableStreamDecoder instance. +// Parameter rpc should be a function that creates a new stream +// beginning at the restartToken if non-nil. +func newResumableStreamDecoder(ctx context.Context, rpc func(ct context.Context, restartToken []byte) (streamingReceiver, error)) *resumableStreamDecoder { + return &resumableStreamDecoder{ + ctx: ctx, + rpc: rpc, + maxBytesBetweenResumeTokens: atomic.LoadInt32(&maxBytesBetweenResumeTokens), + backoff: defaultBackoff, + } +} + +// changeState fulfills state transition for resumableStateDecoder. +func (d *resumableStreamDecoder) changeState(target resumableStreamDecoderState) { + if d.state == queueingRetryable && d.state != target { + // Reset bytesBetweenResumeTokens because it is only meaningful/changed under + // queueingRetryable state. + d.bytesBetweenResumeTokens = 0 + } + d.state = target + if d.stateWitness != nil { + d.stateWitness(target) + } +} + +// isNewResumeToken returns if the observed resume token is different from +// the one returned from server last time. +func (d *resumableStreamDecoder) isNewResumeToken(rt []byte) bool { + if rt == nil { + return false + } + if bytes.Compare(rt, d.resumeToken) == 0 { + return false + } + return true +} + +// Next advances to the next available partial result set. If error or no +// more, returns false, call Err to determine if an error was encountered. +// The following diagram illustrates the state machine of resumableStreamDecoder +// that Next() implements. Note that state transition can be only triggered by +// RPC activities. +/* + rpc() fails retryable + +---------+ + | | rpc() fails unretryable/ctx timeouts or cancelled + | | +------------------------------------------------+ + | | | | + | v | v + | +---+---+---+ +--------+ +------+--+ + +-----+unConnected| |finished| | aborted |<----+ + | | ++-----+-+ +------+--+ | + +---+----+--+ ^ ^ ^ | + | ^ | | | | + | | | | recv() fails | + | | | | | | + | |recv() fails retryable | | | | + | |with valid ctx | | | | + | | | | | | + rpc() succeeds | +-----------------------+ | | | + | | | recv EOF recv EOF | | + | | | | | | + v | | Queue size exceeds | | | + +---+----+---+----+threshold +-------+-----------+ | | ++---------->+ +--------------->+ +-+ | +| |queueingRetryable| |queueingUnretryable| | +| | +<---------------+ | | +| +---+----------+--+ pop() returns +--+----+-----------+ | +| | | resume token | ^ | +| | | | | | +| | | | | | ++---------------+ | | | | + recv() succeeds | +----+ | + | recv() succeeds | + | | + | | + | | + | | + | | + +--------------------------------------------------+ + recv() fails unretryable + +*/ +var ( + // maxBytesBetweenResumeTokens is the maximum amount of bytes that resumableStreamDecoder + // in queueingRetryable state can use to queue PartialResultSets before getting + // into queueingUnretryable state. + maxBytesBetweenResumeTokens = int32(128 * 1024 * 1024) +) + +func (d *resumableStreamDecoder) next() bool { + for { + select { + case <-d.ctx.Done(): + // Do context check here so that even gRPC failed to do + // so, resumableStreamDecoder can still break the loop + // as expected. + d.err = errContextCanceled(d.ctx, d.err) + d.changeState(aborted) + default: + } + switch d.state { + case unConnected: + // If no gRPC stream is available, try to initiate one. + if d.stream, d.err = d.rpc(d.ctx, d.resumeToken); d.err != nil { + if isRetryable(d.err) { + d.doBackOff() + // Be explicit about state transition, although the + // state doesn't actually change. State transition + // will be triggered only by RPC activity, regardless of + // whether there is an actual state change or not. + d.changeState(unConnected) + continue + } + d.changeState(aborted) + continue + } + d.resetBackOff() + d.changeState(queueingRetryable) + continue + case queueingRetryable: + fallthrough + case queueingUnretryable: + // Receiving queue is not empty. + last, err := d.q.peekLast() + if err != nil { + // Only the case that receiving queue is empty could cause peekLast to + // return error and in such case, we should try to receive from stream. + d.tryRecv() + continue + } + if d.isNewResumeToken(last.ResumeToken) { + // Got new resume token, return buffered sppb.PartialResultSets to caller. + d.np = d.q.pop() + if d.q.empty() { + d.bytesBetweenResumeTokens = 0 + // The new resume token was just popped out from queue, record it. + d.resumeToken = d.np.ResumeToken + d.changeState(queueingRetryable) + } + return true + } + if d.bytesBetweenResumeTokens >= d.maxBytesBetweenResumeTokens && d.state == queueingRetryable { + d.changeState(queueingUnretryable) + continue + } + if d.state == queueingUnretryable { + // When there is no resume token observed, + // only yield sppb.PartialResultSets to caller under + // queueingUnretryable state. + d.np = d.q.pop() + return true + } + // Needs to receive more from gRPC stream till a new resume token + // is observed. + d.tryRecv() + continue + case aborted: + // Discard all pending items because none of them + // should be yield to caller. + d.q.clear() + return false + case finished: + // If query has finished, check if there are still buffered messages. + if d.q.empty() { + // No buffered PartialResultSet. + return false + } + // Although query has finished, there are still buffered PartialResultSets. + d.np = d.q.pop() + return true + + default: + log.Errorf("Unexpected resumableStreamDecoder.state: %v", d.state) + return false + } + } +} + +// tryRecv attempts to receive a PartialResultSet from gRPC stream. +func (d *resumableStreamDecoder) tryRecv() { + var res *sppb.PartialResultSet + if res, d.err = d.stream.Recv(); d.err != nil { + if d.err == io.EOF { + d.err = nil + d.changeState(finished) + return + } + if isRetryable(d.err) && d.state == queueingRetryable { + d.err = nil + // Discard all queue items (none have resume tokens). + d.q.clear() + d.stream = nil + d.changeState(unConnected) + d.doBackOff() + return + } + d.changeState(aborted) + return + } + d.q.push(res) + if d.state == queueingRetryable && !d.isNewResumeToken(res.ResumeToken) { + // adjusting d.bytesBetweenResumeTokens + d.bytesBetweenResumeTokens += int32(proto.Size(res)) + } + d.resetBackOff() + d.changeState(d.state) +} + +// resetBackOff clears the internal retry counter of +// resumableStreamDecoder so that the next exponential +// backoff will start at a fresh state. +func (d *resumableStreamDecoder) resetBackOff() { + d.retryCount = 0 +} + +// doBackoff does an exponential backoff sleep. +func (d *resumableStreamDecoder) doBackOff() { + ticker := time.NewTicker(d.backoff.delay(d.retryCount)) + defer ticker.Stop() + d.retryCount++ + select { + case <-d.ctx.Done(): + case <-ticker.C: + } +} + +// get returns the most recent PartialResultSet generated by a call to next. +func (d *resumableStreamDecoder) get() *sppb.PartialResultSet { + return d.np +} + +// lastErr returns the last non-EOF error encountered. +func (d *resumableStreamDecoder) lastErr() error { + return d.err +} + +// partialResultSetDecoder assembles PartialResultSet(s) into Cloud Spanner +// Rows. +type partialResultSetDecoder struct { + row Row + tx *sppb.Transaction + chunked bool // if true, next value should be merged with last values entry. + ts time.Time // read timestamp +} + +// yield checks we have a complete row, and if so returns it. A row is not +// complete if it doesn't have enough columns, or if this is a chunked response +// and there are no further values to process. +func (p *partialResultSetDecoder) yield(chunked, last bool) *Row { + if len(p.row.vals) == len(p.row.fields) && (!chunked || !last) { + // When partialResultSetDecoder gets enough number of + // Column values, There are two cases that a new Row + // should be yield: + // 1. The incoming PartialResultSet is not chunked; + // 2. The incoming PartialResultSet is chunked, but the + // proto3.Value being merged is not the last one in + // the PartialResultSet. + // + // Use a fresh Row to simplify clients that want to use yielded results + // after the next row is retrieved. Note that fields is never changed + // so it doesn't need to be copied. + fresh := Row{ + fields: p.row.fields, + vals: make([]*proto3.Value, len(p.row.vals)), + } + copy(fresh.vals, p.row.vals) + p.row.vals = p.row.vals[:0] // empty and reuse slice + return &fresh + } + return nil +} + +// yieldTx returns transaction information via caller supplied callback. +func errChunkedEmptyRow() error { + return spannerErrorf(codes.FailedPrecondition, "got invalid chunked PartialResultSet with empty Row") +} + +// add tries to merge a new PartialResultSet into buffered Row. It returns +// any rows that have been completed as a result. +func (p *partialResultSetDecoder) add(r *sppb.PartialResultSet) ([]*Row, error) { + var rows []*Row + if r.Metadata != nil { + // Metadata should only be returned in the first result. + if p.row.fields == nil { + p.row.fields = r.Metadata.RowType.Fields + } + if p.tx == nil && r.Metadata.Transaction != nil { + p.tx = r.Metadata.Transaction + if p.tx.ReadTimestamp != nil { + p.ts = time.Unix(p.tx.ReadTimestamp.Seconds, int64(p.tx.ReadTimestamp.Nanos)) + } + } + } + if len(r.Values) == 0 { + return nil, nil + } + if p.chunked { + p.chunked = false + // Try to merge first value in r.Values into + // uncompleted row. + last := len(p.row.vals) - 1 + if last < 0 { // sanity check + return nil, errChunkedEmptyRow() + } + var err error + // If p is chunked, then we should always try to merge p.last with r.first. + if p.row.vals[last], err = p.merge(p.row.vals[last], r.Values[0]); err != nil { + return nil, err + } + r.Values = r.Values[1:] + // Merge is done, try to yield a complete Row. + if row := p.yield(r.ChunkedValue, len(r.Values) == 0); row != nil { + rows = append(rows, row) + } + } + for i, v := range r.Values { + // The rest values in r can be appened into p directly. + p.row.vals = append(p.row.vals, v) + // Again, check to see if a complete Row can be yielded because of + // the newly added value. + if row := p.yield(r.ChunkedValue, i == len(r.Values)-1); row != nil { + rows = append(rows, row) + } + } + if r.ChunkedValue { + // After dealing with all values in r, if r is chunked then p must + // be also chunked. + p.chunked = true + } + return rows, nil +} + +// isMergeable returns if a protobuf Value can be potentially merged with +// other protobuf Values. +func (p *partialResultSetDecoder) isMergeable(a *proto3.Value) bool { + switch a.Kind.(type) { + case *proto3.Value_StringValue: + return true + case *proto3.Value_ListValue: + return true + default: + return false + } +} + +// errIncompatibleMergeTypes returns error for incompatible protobuf types +// that cannot be merged by partialResultSetDecoder. +func errIncompatibleMergeTypes(a, b *proto3.Value) error { + return spannerErrorf(codes.FailedPrecondition, "incompatible type in chunked PartialResultSet. expected (%T), got (%T)", a.Kind, b.Kind) +} + +// errUnsupportedMergeType returns error for protobuf type that cannot be +// merged to other protobufs. +func errUnsupportedMergeType(a *proto3.Value) error { + return spannerErrorf(codes.FailedPrecondition, "unsupported type merge (%T)", a.Kind) +} + +// merge tries to combine two protobuf Values if possible. +func (p *partialResultSetDecoder) merge(a, b *proto3.Value) (*proto3.Value, error) { + var err error + typeErr := errIncompatibleMergeTypes(a, b) + switch t := a.Kind.(type) { + case *proto3.Value_StringValue: + s, ok := b.Kind.(*proto3.Value_StringValue) + if !ok { + return nil, typeErr + } + return &proto3.Value{ + Kind: &proto3.Value_StringValue{StringValue: t.StringValue + s.StringValue}, + }, nil + case *proto3.Value_ListValue: + l, ok := b.Kind.(*proto3.Value_ListValue) + if !ok { + return nil, typeErr + } + if l.ListValue == nil || len(l.ListValue.Values) <= 0 { + // b is an empty list, just return a. + return a, nil + } + if t.ListValue == nil || len(t.ListValue.Values) <= 0 { + // a is an empty list, just return b. + return b, nil + } + if la := len(t.ListValue.Values) - 1; p.isMergeable(t.ListValue.Values[la]) { + // When the last item in a is of type String, + // List or Struct(encoded into List by Cloud Spanner), + // try to Merge last item in a and first item in b. + t.ListValue.Values[la], err = p.merge(t.ListValue.Values[la], l.ListValue.Values[0]) + if err != nil { + return nil, err + } + l.ListValue.Values = l.ListValue.Values[1:] + } + return &proto3.Value{ + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: append(t.ListValue.Values, l.ListValue.Values...), + }, + }, + }, nil + default: + return nil, errUnsupportedMergeType(a) + } + +} + +// Done returns if partialResultSetDecoder has already done with all buffered +// values. +func (p *partialResultSetDecoder) done() bool { + // There is no explicit end of stream marker, but ending part way + // through a row is obviously bad, or ending with the last column still + // awaiting completion. + return len(p.row.vals) == 0 && !p.chunked +} diff --git a/vendor/cloud.google.com/go/spanner/read_test.go b/vendor/cloud.google.com/go/spanner/read_test.go new file mode 100644 index 000000000..db50110ab --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/read_test.go @@ -0,0 +1,1733 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "errors" + "fmt" + "io" + "reflect" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/context" + + proto "github.com/golang/protobuf/proto" + proto3 "github.com/golang/protobuf/ptypes/struct" + + "cloud.google.com/go/spanner/internal/testutil" + "google.golang.org/api/iterator" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +var ( + // Mocked transaction timestamp. + trxTs = time.Unix(1, 2) + // Metadata for mocked KV table, its rows are returned by SingleUse transactions. + kvMeta = func() *sppb.ResultSetMetadata { + meta := testutil.KvMeta + meta.Transaction = &sppb.Transaction{ + ReadTimestamp: timestampProto(trxTs), + } + return &meta + }() + // Metadata for mocked ListKV table, which uses List for its key and value. + // Its rows are returned by snapshot readonly transactions, as indicated in the transaction metadata. + kvListMeta = &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + { + Name: "Key", + Type: &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + ArrayElementType: &sppb.Type{ + Code: sppb.TypeCode_STRING, + }, + }, + }, + { + Name: "Value", + Type: &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + ArrayElementType: &sppb.Type{ + Code: sppb.TypeCode_STRING, + }, + }, + }, + }, + }, + Transaction: &sppb.Transaction{ + Id: transactionID{5, 6, 7, 8, 9}, + ReadTimestamp: timestampProto(trxTs), + }, + } + // Metadata for mocked schema of a query result set, which has two struct + // columns named "Col1" and "Col2", the struct's schema is like the + // following: + // + // STRUCT { + // INT + // LIST<STRING> + // } + // + // Its rows are returned in readwrite transaction, as indicated in the transaction metadata. + kvObjectMeta = &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + { + Name: "Col1", + Type: &sppb.Type{ + Code: sppb.TypeCode_STRUCT, + StructType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + { + Name: "foo-f1", + Type: &sppb.Type{ + Code: sppb.TypeCode_INT64, + }, + }, + { + Name: "foo-f2", + Type: &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + ArrayElementType: &sppb.Type{ + Code: sppb.TypeCode_STRING, + }, + }, + }, + }, + }, + }, + }, + { + Name: "Col2", + Type: &sppb.Type{ + Code: sppb.TypeCode_STRUCT, + StructType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + { + Name: "bar-f1", + Type: &sppb.Type{ + Code: sppb.TypeCode_INT64, + }, + }, + { + Name: "bar-f2", + Type: &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + ArrayElementType: &sppb.Type{ + Code: sppb.TypeCode_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Transaction: &sppb.Transaction{ + Id: transactionID{1, 2, 3, 4, 5}, + }, + } +) + +// String implements fmt.stringer. +func (r *Row) String() string { + return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals) +} + +func describeRows(l []*Row) string { + // generate a nice test failure description + var s = "[" + for i, r := range l { + if i != 0 { + s += ",\n " + } + s += fmt.Sprint(r) + } + s += "]" + return s +} + +// Helper for generating proto3 Value_ListValue instances, making +// test code shorter and readable. +func genProtoListValue(v ...string) *proto3.Value_ListValue { + r := &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{}, + }, + } + for _, e := range v { + r.ListValue.Values = append( + r.ListValue.Values, + &proto3.Value{ + Kind: &proto3.Value_StringValue{StringValue: e}, + }, + ) + } + return r +} + +// Test Row generation logics of partialResultSetDecoder. +func TestPartialResultSetDecoder(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + var tests = []struct { + input []*sppb.PartialResultSet + wantF []*Row + wantTxID transactionID + wantTs time.Time + wantD bool + }{ + { + // Empty input. + wantD: true, + }, + // String merging examples. + { + // Single KV result. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + }, + }, + }, + wantF: []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + }, + }, + }, + wantTs: trxTs, + wantD: true, + }, + { + // Incomplete partial result. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + }, + }, + }, + wantTs: trxTs, + wantD: false, + }, + { + // Complete splitted result. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + }, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + }, + }, + }, + wantF: []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + }, + }, + }, + wantTs: trxTs, + wantD: true, + }, + { + // Multi-row example with splitted row in the middle. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + {Kind: &proto3.Value_StringValue{StringValue: "A"}}, + }, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "1"}}, + {Kind: &proto3.Value_StringValue{StringValue: "B"}}, + {Kind: &proto3.Value_StringValue{StringValue: "2"}}, + }, + }, + }, + wantF: []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "foo"}}, + {Kind: &proto3.Value_StringValue{StringValue: "bar"}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "A"}}, + {Kind: &proto3.Value_StringValue{StringValue: "1"}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "B"}}, + {Kind: &proto3.Value_StringValue{StringValue: "2"}}, + }, + }, + }, + wantTs: trxTs, + wantD: true, + }, + { + // Merging example in result_set.proto. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, + {Kind: &proto3.Value_StringValue{StringValue: "W"}}, + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "orl"}}, + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "d"}}, + }, + }, + }, + wantF: []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, + {Kind: &proto3.Value_StringValue{StringValue: "World"}}, + }, + }, + }, + wantTs: trxTs, + wantD: true, + }, + { + // More complex example showing completing a merge and + // starting a new merge in the same partialResultSet. + input: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, + {Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value + {Kind: &proto3.Value_StringValue{StringValue: "i"}}, // start split in key + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key + {Kind: &proto3.Value_StringValue{StringValue: "not"}}, + {Kind: &proto3.Value_StringValue{StringValue: "a"}}, + {Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value + }, + }, + }, + wantF: []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "Hello"}}, + {Kind: &proto3.Value_StringValue{StringValue: "World"}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "is"}}, + {Kind: &proto3.Value_StringValue{StringValue: "not"}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: "a"}}, + {Kind: &proto3.Value_StringValue{StringValue: "question"}}, + }, + }, + }, + wantTs: trxTs, + wantD: true, + }, + // List merging examples. + { + // Non-splitting Lists. + input: []*sppb.PartialResultSet{ + { + Metadata: kvListMeta, + Values: []*proto3.Value{ + { + Kind: genProtoListValue("foo-1", "foo-2"), + }, + }, + }, + { + Values: []*proto3.Value{ + { + Kind: genProtoListValue("bar-1", "bar-2"), + }, + }, + }, + }, + wantF: []*Row{ + { + fields: kvListMeta.RowType.Fields, + vals: []*proto3.Value{ + { + Kind: genProtoListValue("foo-1", "foo-2"), + }, + { + Kind: genProtoListValue("bar-1", "bar-2"), + }, + }, + }, + }, + wantTxID: transactionID{5, 6, 7, 8, 9}, + wantTs: trxTs, + wantD: true, + }, + { + // Simple List merge case: splitted string element. + input: []*sppb.PartialResultSet{ + { + Metadata: kvListMeta, + Values: []*proto3.Value{ + { + Kind: genProtoListValue("foo-1", "foo-"), + }, + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + { + Kind: genProtoListValue("2"), + }, + }, + }, + { + Values: []*proto3.Value{ + { + Kind: genProtoListValue("bar-1", "bar-2"), + }, + }, + }, + }, + wantF: []*Row{ + { + fields: kvListMeta.RowType.Fields, + vals: []*proto3.Value{ + { + Kind: genProtoListValue("foo-1", "foo-2"), + }, + { + Kind: genProtoListValue("bar-1", "bar-2"), + }, + }, + }, + }, + wantTxID: transactionID{5, 6, 7, 8, 9}, + wantTs: trxTs, + wantD: true, + }, + { + // Struct merging is also implemented by List merging. Note that + // Cloud Spanner uses proto.ListValue to encode Structs as well. + input: []*sppb.PartialResultSet{ + { + Metadata: kvObjectMeta, + Values: []*proto3.Value{ + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: &proto3.Value_NumberValue{NumberValue: 23}}, + {Kind: genProtoListValue("foo-1", "fo")}, + }, + }, + }, + }, + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: genProtoListValue("o-2", "f")}, + }, + }, + }, + }, + }, + ChunkedValue: true, + }, + { + Values: []*proto3.Value{ + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: genProtoListValue("oo-3")}, + }, + }, + }, + }, + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: &proto3.Value_NumberValue{NumberValue: 45}}, + {Kind: genProtoListValue("bar-1")}, + }, + }, + }, + }, + }, + }, + }, + wantF: []*Row{ + { + fields: kvObjectMeta.RowType.Fields, + vals: []*proto3.Value{ + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: &proto3.Value_NumberValue{NumberValue: 23}}, + {Kind: genProtoListValue("foo-1", "foo-2", "foo-3")}, + }, + }, + }, + }, + { + Kind: &proto3.Value_ListValue{ + ListValue: &proto3.ListValue{ + Values: []*proto3.Value{ + {Kind: &proto3.Value_NumberValue{NumberValue: 45}}, + {Kind: genProtoListValue("bar-1")}, + }, + }, + }, + }, + }, + }, + }, + wantTxID: transactionID{1, 2, 3, 4, 5}, + wantD: true, + }, + } + +nextTest: + for i, test := range tests { + var rows []*Row + p := &partialResultSetDecoder{} + for j, v := range test.input { + rs, err := p.add(v) + if err != nil { + t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err) + continue nextTest + } + rows = append(rows, rs...) + } + if !reflect.DeepEqual(p.ts, test.wantTs) { + t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs) + } + if !reflect.DeepEqual(rows, test.wantF) { + t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row) + } + if got := p.done(); got != test.wantD { + t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got) + } + } +} + +const ( + maxBuffers = 16 // max number of PartialResultSets that will be buffered in tests. +) + +// setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to a smaller +// value more suitable for tests. It returns a function which should be called to restore +// the maxBytesBetweenResumeTokens to its old value +func setMaxBytesBetweenResumeTokens() func() { + o := atomic.LoadInt32(&maxBytesBetweenResumeTokens) + atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, + }, + }))) + return func() { + atomic.StoreInt32(&maxBytesBetweenResumeTokens, o) + } +} + +// keyStr generates key string for kvMeta schema. +func keyStr(i int) string { + return fmt.Sprintf("foo-%02d", i) +} + +// valStr generates value string for kvMeta schema. +func valStr(i int) string { + return fmt.Sprintf("bar-%02d", i) +} + +// Test state transitions of resumableStreamDecoder where state machine +// ends up to a non-blocking state(resumableStreamDecoder.Next returns +// on non-blocking state). +func TestRsdNonblockingStates(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + tests := []struct { + name string + msgs []testutil.MockCtlMsg + rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) + sql string + // Expected values + want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller + queue []*sppb.PartialResultSet // PartialResultSets that should be buffered + resumeToken []byte // Resume token that is maintained by resumableStreamDecoder + stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder + wantErr error + }{ + { + // unConnected->queueingRetryable->finished + name: "unConnected->queueingRetryable->finished", + msgs: []testutil.MockCtlMsg{ + {}, + {}, + {Err: io.EOF, ResumeToken: false}, + }, + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, + }, + }, + }, + queue: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, + }, + }, + }, + stateHistory: []resumableStreamDecoderState{ + queueingRetryable, // do RPC + queueingRetryable, // got foo-00 + queueingRetryable, // got foo-01 + finished, // got EOF + }, + }, + { + // unConnected->queueingRetryable->aborted + name: "unConnected->queueingRetryable->aborted", + msgs: []testutil.MockCtlMsg{ + {}, + {Err: nil, ResumeToken: true}, + {}, + {Err: errors.New("I quit"), ResumeToken: false}, + }, + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, + }, + }, + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, + }, + ResumeToken: testutil.EncodeResumeToken(1), + }, + }, + stateHistory: []resumableStreamDecoderState{ + queueingRetryable, // do RPC + queueingRetryable, // got foo-00 + queueingRetryable, // got foo-01 + queueingRetryable, // foo-01, resume token + queueingRetryable, // got foo-02 + aborted, // got error + }, + wantErr: grpc.Errorf(codes.Unknown, "I quit"), + }, + { + // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable + name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable", + msgs: func() (m []testutil.MockCtlMsg) { + for i := 0; i < maxBuffers+1; i++ { + m = append(m, testutil.MockCtlMsg{}) + } + return m + }(), + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: func() (s []*sppb.PartialResultSet) { + for i := 0; i < maxBuffers+1; i++ { + s = append(s, &sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + } + return s + }(), + stateHistory: func() (s []resumableStreamDecoderState) { + s = append(s, queueingRetryable) // RPC + for i := 0; i < maxBuffers; i++ { + s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up + } + // the first item fills up the queue and triggers state transition; + // the second item is received under queueingUnretryable state. + s = append(s, queueingUnretryable) + s = append(s, queueingUnretryable) + return s + }(), + }, + { + // unConnected->queueingRetryable->queueingUnretryable->aborted + name: "unConnected->queueingRetryable->queueingUnretryable->aborted", + msgs: func() (m []testutil.MockCtlMsg) { + for i := 0; i < maxBuffers; i++ { + m = append(m, testutil.MockCtlMsg{}) + } + m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) + return m + }(), + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: func() (s []*sppb.PartialResultSet) { + for i := 0; i < maxBuffers; i++ { + s = append(s, &sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + } + return s + }(), + stateHistory: func() (s []resumableStreamDecoderState) { + s = append(s, queueingRetryable) // RPC + for i := 0; i < maxBuffers; i++ { + s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up + } + s = append(s, queueingUnretryable) // the last row triggers state change + s = append(s, aborted) // Error happens + return s + }(), + wantErr: grpc.Errorf(codes.Unknown, "Just Abort It"), + }, + } +nextTest: + for _, test := range tests { + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + opts := []grpc.DialOption{ + grpc.WithInsecure(), + } + cc, err := grpc.Dial(ms.Addr(), opts...) + if err != nil { + t.Fatalf("%v: Dial(%q) = %v", test.name, ms.Addr(), err) + } + mc := sppb.NewSpannerClient(cc) + if test.rpc == nil { + test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: test.sql, + ResumeToken: resumeToken, + }) + } + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + r := newResumableStreamDecoder( + ctx, + test.rpc, + ) + st := []resumableStreamDecoderState{} + var lastErr error + // Once the expected number of state transitions are observed, + // send a signal by setting stateDone = true. + stateDone := false + // Set stateWitness to listen to state changes. + hl := len(test.stateHistory) // To avoid data race on test. + r.stateWitness = func(rs resumableStreamDecoderState) { + if !stateDone { + // Record state transitions. + st = append(st, rs) + if len(st) == hl { + lastErr = r.lastErr() + stateDone = true + } + } + } + // Let mock server stream given messages to resumableStreamDecoder. + for _, m := range test.msgs { + ms.AddMsg(m.Err, m.ResumeToken) + } + var rs []*sppb.PartialResultSet + for { + select { + case <-ctx.Done(): + t.Errorf("context cancelled or timeout during test") + continue nextTest + default: + } + if stateDone { + // Check if resumableStreamDecoder carried out expected + // state transitions. + if !reflect.DeepEqual(st, test.stateHistory) { + t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n", + test.name, st, test.stateHistory) + } + // Check if resumableStreamDecoder returns expected array of + // PartialResultSets. + if !reflect.DeepEqual(rs, test.want) { + t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want) + } + // Verify that resumableStreamDecoder's internal buffering is also correct. + var q []*sppb.PartialResultSet + for { + item := r.q.pop() + if item == nil { + break + } + q = append(q, item) + } + if !reflect.DeepEqual(q, test.queue) { + t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue) + } + // Verify resume token. + if test.resumeToken != nil && !reflect.DeepEqual(r.resumeToken, test.resumeToken) { + t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken) + } + // Verify error message. + if !reflect.DeepEqual(lastErr, test.wantErr) { + t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr) + } + // Proceed to next test + continue nextTest + } + // Receive next decoded item. + if r.next() { + rs = append(rs, r.get()) + } + } + } +} + +// Test state transitions of resumableStreamDecoder where state machine +// ends up to a blocking state(resumableStreamDecoder.Next blocks +// on blocking state). +func TestRsdBlockingStates(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + tests := []struct { + name string + msgs []testutil.MockCtlMsg + rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) + sql string + // Expected values + want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller + queue []*sppb.PartialResultSet // PartialResultSets that should be buffered + resumeToken []byte // Resume token that is maintained by resumableStreamDecoder + stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder + wantErr error + }{ + { + // unConnected -> unConnected + name: "unConnected -> unConnected", + rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return nil, grpc.Errorf(codes.Unavailable, "trust me: server is unavailable") + }, + sql: "SELECT * from t_whatever", + stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected}, + wantErr: grpc.Errorf(codes.Unavailable, "trust me: server is unavailable"), + }, + { + // unConnected -> queueingRetryable + name: "unConnected -> queueingRetryable", + sql: "SELECT t.key key, t.value value FROM t_mock t", + stateHistory: []resumableStreamDecoderState{queueingRetryable}, + }, + { + // unConnected->queueingRetryable->queueingRetryable + name: "unConnected->queueingRetryable->queueingRetryable", + msgs: []testutil.MockCtlMsg{ + {}, + {Err: nil, ResumeToken: true}, + {Err: nil, ResumeToken: true}, + {}, + }, + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, + }, + }, + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, + }, + ResumeToken: testutil.EncodeResumeToken(1), + }, + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, + }, + ResumeToken: testutil.EncodeResumeToken(2), + }, + }, + queue: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(3)}}, + }, + }, + }, + resumeToken: testutil.EncodeResumeToken(2), + stateHistory: []resumableStreamDecoderState{ + queueingRetryable, // do RPC + queueingRetryable, // got foo-00 + queueingRetryable, // got foo-01 + queueingRetryable, // foo-01, resume token + queueingRetryable, // got foo-02 + queueingRetryable, // foo-02, resume token + queueingRetryable, // got foo-03 + }, + }, + { + // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable + name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable", + msgs: func() (m []testutil.MockCtlMsg) { + for i := 0; i < maxBuffers+1; i++ { + m = append(m, testutil.MockCtlMsg{}) + } + m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true}) + m = append(m, testutil.MockCtlMsg{}) + return m + }(), + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: func() (s []*sppb.PartialResultSet) { + for i := 0; i < maxBuffers+2; i++ { + s = append(s, &sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + } + s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1) + return s + }(), + resumeToken: testutil.EncodeResumeToken(maxBuffers + 1), + queue: []*sppb.PartialResultSet{ + { + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}}, + }, + }, + }, + stateHistory: func() (s []resumableStreamDecoderState) { + s = append(s, queueingRetryable) // RPC + for i := 0; i < maxBuffers; i++ { + s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up + } + for i := maxBuffers - 1; i < maxBuffers+1; i++ { + // the first item fills up the queue and triggers state change; + // the second item is received under queueingUnretryable state. + s = append(s, queueingUnretryable) + } + s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state + s = append(s, queueingRetryable) // (maxBuffers+1)th row has resume token + s = append(s, queueingRetryable) // (maxBuffers+2)th row has no resume token + return s + }(), + }, + { + // unConnected->queueingRetryable->queueingUnretryable->finished + name: "unConnected->queueingRetryable->queueingUnretryable->finished", + msgs: func() (m []testutil.MockCtlMsg) { + for i := 0; i < maxBuffers; i++ { + m = append(m, testutil.MockCtlMsg{}) + } + m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false}) + return m + }(), + sql: "SELECT t.key key, t.value value FROM t_mock t", + want: func() (s []*sppb.PartialResultSet) { + for i := 0; i < maxBuffers; i++ { + s = append(s, &sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + } + return s + }(), + stateHistory: func() (s []resumableStreamDecoderState) { + s = append(s, queueingRetryable) // RPC + for i := 0; i < maxBuffers; i++ { + s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up + } + s = append(s, queueingUnretryable) // last row triggers state change + s = append(s, finished) // query finishes + return s + }(), + }, + } + for _, test := range tests { + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + opts := []grpc.DialOption{ + grpc.WithInsecure(), + } + cc, err := grpc.Dial(ms.Addr(), opts...) + if err != nil { + t.Fatalf("%v: Dial(%q) = %v", test.name, ms.Addr(), err) + } + mc := sppb.NewSpannerClient(cc) + if test.rpc == nil { + // Avoid using test.sql directly in closure because for loop changes test. + sql := test.sql + test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: sql, + ResumeToken: resumeToken, + }) + } + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResumableStreamDecoder( + ctx, + test.rpc, + ) + // Override backoff to make the test run faster. + r.backoff = exponentialBackoff{1 * time.Nanosecond, 1 * time.Nanosecond} + // st is the set of observed state transitions. + st := []resumableStreamDecoderState{} + // q is the content of the decoder's partial result queue when expected number of state transitions are done. + q := []*sppb.PartialResultSet{} + var lastErr error + // Once the expected number of state transitions are observed, + // send a signal to channel stateDone. + stateDone := make(chan int) + // Set stateWitness to listen to state changes. + hl := len(test.stateHistory) // To avoid data race on test. + r.stateWitness = func(rs resumableStreamDecoderState) { + select { + case <-stateDone: + // Noop after expected number of state transitions + default: + // Record state transitions. + st = append(st, rs) + if len(st) == hl { + lastErr = r.lastErr() + q = r.q.dump() + close(stateDone) + } + } + } + // Let mock server stream given messages to resumableStreamDecoder. + for _, m := range test.msgs { + ms.AddMsg(m.Err, m.ResumeToken) + } + var rs []*sppb.PartialResultSet + go func() { + for { + if !r.next() { + // Note that r.Next also exits on context cancel/timeout. + return + } + rs = append(rs, r.get()) + } + }() + // Verify that resumableStreamDecoder reaches expected state. + select { + case <-stateDone: // Note that at this point, receiver is still blocking on r.next(). + // Check if resumableStreamDecoder carried out expected + // state transitions. + if !reflect.DeepEqual(st, test.stateHistory) { + t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n", + test.name, st, test.stateHistory) + } + // Check if resumableStreamDecoder returns expected array of + // PartialResultSets. + if !reflect.DeepEqual(rs, test.want) { + t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want) + } + // Verify that resumableStreamDecoder's internal buffering is also correct. + if !reflect.DeepEqual(q, test.queue) { + t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue) + } + // Verify resume token. + if test.resumeToken != nil && !reflect.DeepEqual(r.resumeToken, test.resumeToken) { + t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken) + } + // Verify error message. + if !reflect.DeepEqual(lastErr, test.wantErr) { + t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr) + } + case <-time.After(1 * time.Second): + t.Errorf("%v: Timeout in waiting for state change", test.name) + } + ms.Stop() + cc.Close() + } +} + +// sReceiver signals every receiving attempt through a channel, +// used by TestResumeToken to determine if the receiving of a certain +// PartialResultSet will be attempted next. +type sReceiver struct { + c chan int + rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient +} + +// Recv() implements streamingReceiver.Recv for sReceiver. +func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) { + sr.c <- 1 + return sr.rpcReceiver.Recv() +} + +// waitn waits for nth receiving attempt from now on, until +// the signal for nth Recv() attempts is received or timeout. +// Note that because the way stream() works, the signal for the +// nth Recv() means that the previous n - 1 PartialResultSets +// has already been returned to caller or queued, if no error happened. +func (sr *sReceiver) waitn(n int) error { + for i := 0; i < n; i++ { + select { + case <-sr.c: + case <-time.After(10 * time.Second): + return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1) + } + } + return nil +} + +// Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens. +func TestQueueBytes(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + defer ms.Stop() + opts := []grpc.DialOption{ + grpc.WithInsecure(), + } + cc, err := grpc.Dial(ms.Addr(), opts...) + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + defer cc.Close() + mc := sppb.NewSpannerClient(cc) + sr := &sReceiver{ + c: make(chan int, 1000), // will never block in this test + } + wantQueueBytes := 0 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + r := newResumableStreamDecoder( + ctx, + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + sr.rpcReceiver = r + return sr, err + }, + ) + go func() { + for r.next() { + } + }() + // Let server send maxBuffers / 2 rows. + for i := 0; i < maxBuffers/2; i++ { + wantQueueBytes += proto.Size(&sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + ms.AddMsg(nil, false) + } + if err := sr.waitn(maxBuffers/2 + 1); err != nil { + t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers, err) + } + if int32(wantQueueBytes) != r.bytesBetweenResumeTokens { + t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes) + } + // Now send a resume token to drain the queue. + ms.AddMsg(nil, true) + // Wait for all rows to be processes. + if err := sr.waitn(1); err != nil { + t.Fatalf("failed to wait for rows to be processed: %v", err) + } + if r.bytesBetweenResumeTokens != 0 { + t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) + } + // Let server send maxBuffers - 1 rows. + wantQueueBytes = 0 + for i := 0; i < maxBuffers-1; i++ { + wantQueueBytes += proto.Size(&sppb.PartialResultSet{ + Metadata: kvMeta, + Values: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + ms.AddMsg(nil, false) + } + if err := sr.waitn(maxBuffers - 1); err != nil { + t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err) + } + if int32(wantQueueBytes) != r.bytesBetweenResumeTokens { + t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) + } + // Trigger a state transition: queueingRetryable -> queueingUnretryable. + ms.AddMsg(nil, false) + if err := sr.waitn(1); err != nil { + t.Fatalf("failed to wait for state transition: %v", err) + } + if r.bytesBetweenResumeTokens != 0 { + t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens) + } +} + +// Verify that client can deal with resume token correctly +func TestResumeToken(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + opts := []grpc.DialOption{ + grpc.WithInsecure(), + } + cc, err := grpc.Dial(ms.Addr(), opts...) + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + defer func() { + ms.Stop() + cc.Close() + }() + mc := sppb.NewSpannerClient(cc) + sr := &sReceiver{ + c: make(chan int, 1000), // will never block in this test + } + rows := []*Row{} + done := make(chan int) + streaming := func() { + // Establish a stream to mock cloud spanner server. + iter := stream(context.Background(), + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + sr.rpcReceiver = r + return sr, err + }, + nil, + func(error) {}) + defer iter.Stop() + for { + var row *Row + row, err = iter.Next() + if err == iterator.Done { + err = nil + break + } + if err != nil { + break + } + rows = append(rows, row) + } + done <- 1 + } + go streaming() + // Server streaming row 0 - 2, only row 1 has resume token. + // Client will receive row 0 - 2, so it will try receiving for + // 4 times (the last recv will block), and only row 0 - 1 will + // be yielded. + for i := 0; i < 3; i++ { + if i == 1 { + ms.AddMsg(nil, true) + } else { + ms.AddMsg(nil, false) + } + } + // Wait for 4 receive attempts, as explained above. + if err = sr.waitn(4); err != nil { + t.Fatalf("failed to wait for row 0 - 2: %v", err) + } + want := []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, + }, + }, + } + if !reflect.DeepEqual(rows, want) { + t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want) + } + // Inject resumable failure. + ms.AddMsg( + grpc.Errorf(codes.Unavailable, "mock server unavailable"), + false, + ) + // Test if client detects the resumable failure and retries. + if err = sr.waitn(1); err != nil { + t.Fatalf("failed to wait for client to retry: %v", err) + } + // Client has resumed the query, now server resend row 2. + ms.AddMsg(nil, true) + if err = sr.waitn(1); err != nil { + t.Fatalf("failed to wait for resending row 2: %v", err) + } + // Now client should have received row 0 - 2. + want = append(want, &Row{ + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, + }, + }) + if !reflect.DeepEqual(rows, want) { + t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want) + } + // Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them. + for i := 3; i < maxBuffers+2; i++ { + ms.AddMsg(nil, false) + } + if err = sr.waitn(maxBuffers - 1); err != nil { + t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err) + } + // Received rows should be unchanged. + if !reflect.DeepEqual(rows, want) { + t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want) + } + // Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder: + // queueingRetryable -> queueingUnretryable + ms.AddMsg(nil, false) + if err = sr.waitn(1); err != nil { + t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err) + } + // Client should yield row 3rd - (maxBuffers+2)th to application. Therefore, application should + // see row 0 - (maxBuffers+2)th so far. + for i := 3; i < maxBuffers+3; i++ { + want = append(want, &Row{ + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}}, + }, + }) + } + if !reflect.DeepEqual(rows, want) { + t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want) + } + // Inject resumable error, but since resumableStreamDecoder is already at queueingUnretryable + // state, query will just fail. + ms.AddMsg( + grpc.Errorf(codes.Unavailable, "mock server wants some sleep"), + false, + ) + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatalf("timeout in waiting for failed query to return.") + } + if wantErr := toSpannerError(grpc.Errorf(codes.Unavailable, "mock server wants some sleep")); !reflect.DeepEqual(err, wantErr) { + t.Fatalf("stream() returns error: %v, but want error: %v", err, wantErr) + } + + // Reconnect to mock Cloud Spanner. + rows = []*Row{} + go streaming() + // Let server send two rows without resume token. + for i := maxBuffers + 3; i < maxBuffers+5; i++ { + ms.AddMsg(nil, false) + } + if err = sr.waitn(3); err != nil { + t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+5, err) + } + if len(rows) > 0 { + t.Errorf("client received some rows unexpectedly: %v, want nothing", rows) + } + // Let server end the query. + ms.AddMsg(io.EOF, false) + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatalf("timeout in waiting for failed query to return") + } + if err != nil { + t.Fatalf("stream() returns unexpected error: %v, but want no error", err) + } + // Verify if a normal server side EOF flushes all queued rows. + want = []*Row{ + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 3)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 3)}}, + }, + }, + { + fields: kvMeta.RowType.Fields, + vals: []*proto3.Value{ + {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 4)}}, + {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 4)}}, + }, + }, + } + if !reflect.DeepEqual(rows, want) { + t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want) + } +} + +// Verify that streaming query get retried upon real gRPC server transport failures. +func TestGrpcReconnect(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + defer ms.Stop() + cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + defer cc.Close() + mc := sppb.NewSpannerClient(cc) + retry := make(chan int) + row := make(chan int) + go func() { + r := 0 + // Establish a stream to mock cloud spanner server. + iter := stream(context.Background(), + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + if r > 0 { + // This RPC attempt is a retry, signal it. + retry <- r + } + r++ + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + + }, + nil, + func(error) {}) + defer iter.Stop() + for { + _, err = iter.Next() + if err == iterator.Done { + err = nil + break + } + if err != nil { + break + } + row <- 0 + } + }() + // Add a message and wait for the receipt. + ms.AddMsg(nil, true) + select { + case <-row: + case <-time.After(10 * time.Second): + t.Fatalf("expect stream to be established within 10 seconds, but it didn't") + } + // Error injection: force server to close all connections. + ms.Stop() + // Test to see if client respond to the real RPC failure correctly by + // retrying RPC. + select { + case r, ok := <-retry: + if ok && r == 1 { + break + } + t.Errorf("retry count = %v, want 1", r) + case <-time.After(10 * time.Second): + t.Errorf("client library failed to respond after 10 seconds, aborting") + return + } +} + +// Test cancel/timeout for client operations. +func TestCancelTimeout(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + defer ms.Stop() + opts := []grpc.DialOption{ + grpc.WithInsecure(), + } + cc, err := grpc.Dial(ms.Addr(), opts...) + defer cc.Close() + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + mc := sppb.NewSpannerClient(cc) + done := make(chan int) + go func() { + for { + ms.AddMsg(nil, true) + } + }() + // Test cancelling query. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + // Establish a stream to mock cloud spanner server. + iter := stream(ctx, + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + }, + nil, + func(error) {}) + defer iter.Stop() + for { + _, err = iter.Next() + if err == iterator.Done { + break + } + if err != nil { + done <- 0 + break + } + } + }() + cancel() + select { + case <-done: + if ErrCode(err) != codes.Canceled { + t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled) + } + case <-time.After(1 * time.Second): + t.Errorf("query doesn't exit timely after being cancelled") + } + // Test query timeout. + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + go func() { + // Establish a stream to mock cloud spanner server. + iter := stream(ctx, + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + }, + nil, + func(error) {}) + defer iter.Stop() + for { + _, err = iter.Next() + if err == iterator.Done { + err = nil + break + } + if err != nil { + break + } + } + done <- 0 + }() + select { + case <-done: + if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr { + t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr) + } + case <-time.After(2 * time.Second): + t.Errorf("query doesn't timeout as expected") + } +} + +func TestRowIteratorDo(t *testing.T) { + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + defer ms.Stop() + cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + defer cc.Close() + mc := sppb.NewSpannerClient(cc) + + for i := 0; i < 3; i++ { + ms.AddMsg(nil, false) + } + ms.AddMsg(io.EOF, true) + nRows := 0 + iter := stream(context.Background(), + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + }, + nil, + func(error) {}) + err = iter.Do(func(r *Row) error { nRows++; return nil }) + if err != nil { + t.Errorf("Using Do: %v", err) + } + if nRows != 3 { + t.Errorf("got %d rows, want 3", nRows) + } +} + +func TestIteratorStopEarly(t *testing.T) { + ctx := context.Background() + restore := setMaxBytesBetweenResumeTokens() + defer restore() + ms := testutil.NewMockCloudSpanner(t, trxTs) + ms.Serve() + defer ms.Stop() + cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Dial(%q) = %v", ms.Addr(), err) + } + defer cc.Close() + mc := sppb.NewSpannerClient(cc) + + ms.AddMsg(nil, false) + ms.AddMsg(nil, false) + ms.AddMsg(io.EOF, true) + + iter := stream(ctx, + func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ + Sql: "SELECT t.key key, t.value value FROM t_mock t", + ResumeToken: resumeToken, + }) + }, + nil, + func(error) {}) + _, err = iter.Next() + if err != nil { + t.Fatalf("before Stop: %v", err) + } + iter.Stop() + // Stop sets r.err to the FailedPrecondition error "Next called after Stop". + // Override that here so this test can observe the Canceled error from the stream. + iter.err = nil + iter.Next() + if ErrCode(iter.streamd.lastErr()) != codes.Canceled { + t.Errorf("after Stop: got %v, wanted Canceled", err) + } +} + +func TestIteratorWithError(t *testing.T) { + injected := errors.New("Failed iterator") + iter := RowIterator{err: injected} + defer iter.Stop() + if _, err := iter.Next(); err != injected { + t.Fatalf("Expected error: %v, got %v", injected, err) + } +} diff --git a/vendor/cloud.google.com/go/spanner/retry.go b/vendor/cloud.google.com/go/spanner/retry.go new file mode 100644 index 000000000..6d535ef41 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/retry.go @@ -0,0 +1,192 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + "strings" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "golang.org/x/net/context" + edpb "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +const ( + retryInfoKey = "google.rpc.retryinfo-bin" +) + +// errRetry returns an unavailable error under error namespace EsOther. It is a +// generic retryable error that is used to mask and recover unretryable errors +// in a retry loop. +func errRetry(err error) error { + if se, ok := err.(*Error); ok { + return &Error{codes.Unavailable, fmt.Sprintf("generic Cloud Spanner retryable error: { %v }", se.Error()), se.trailers} + } + return spannerErrorf(codes.Unavailable, "generic Cloud Spanner retryable error: { %v }", err.Error()) +} + +// isErrorClosing reports whether the error is generated by gRPC layer talking to a closed server. +func isErrorClosing(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Internal && strings.Contains(ErrDesc(err), "transport is closing") { + // Handle the case when connection is closed unexpectedly. + // TODO: once gRPC is able to categorize + // this as retryable error, we should stop parsing the + // error message here. + return true + } + return false +} + +// isErrorRST reports whether the error is generated by gRPC client receiving a RST frame from server. +func isErrorRST(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Internal && strings.Contains(ErrDesc(err), "stream terminated by RST_STREAM") { + // TODO: once gRPC is able to categorize this error as "go away" or "retryable", + // we should stop parsing the error message. + return true + } + return false +} + +// isErrorUnexpectedEOF returns true if error is generated by gRPC layer +// receiving io.EOF unexpectedly. +func isErrorUnexpectedEOF(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Unknown && strings.Contains(ErrDesc(err), "unexpected EOF") { + // Unexpected EOF is an transport layer issue that + // could be recovered by retries. The most likely + // scenario is a flaky RecvMsg() call due to network + // issues. + // TODO: once gRPC is able to categorize + // this as retryable error, we should stop parsing the + // error message here. + return true + } + return false +} + +// isErrorUnavailable returns true if the error is about server being unavailable. +func isErrorUnavailable(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Unavailable { + return true + } + return false +} + +// isRetryable returns true if the Cloud Spanner error being checked is a retryable error. +func isRetryable(err error) bool { + if isErrorClosing(err) { + return true + } + if isErrorUnexpectedEOF(err) { + return true + } + if isErrorRST(err) { + return true + } + if isErrorUnavailable(err) { + return true + } + return false +} + +// errContextCanceled returns *spanner.Error for canceled context. +func errContextCanceled(ctx context.Context, lastErr error) error { + if ctx.Err() == context.DeadlineExceeded { + return spannerErrorf(codes.DeadlineExceeded, "%v, lastErr is <%v>", ctx.Err(), lastErr) + } + return spannerErrorf(codes.Canceled, "%v, lastErr is <%v>", ctx.Err(), lastErr) +} + +// extractRetryDelay extracts retry backoff if present. +func extractRetryDelay(err error) (time.Duration, bool) { + trailers := errTrailers(err) + if trailers == nil { + return 0, false + } + elem, ok := trailers[retryInfoKey] + if !ok || len(elem) <= 0 { + return 0, false + } + _, b, err := metadata.DecodeKeyValue(retryInfoKey, elem[0]) + if err != nil { + return 0, false + } + var retryInfo edpb.RetryInfo + if proto.Unmarshal([]byte(b), &retryInfo) != nil { + return 0, false + } + delay, err := ptypes.Duration(retryInfo.RetryDelay) + if err != nil { + return 0, false + } + return delay, true +} + +// runRetryable keeps attempting to run f until one of the following happens: +// 1) f returns nil error or an unretryable error; +// 2) context is cancelled or timeout. +// TODO: consider using https://github.com/googleapis/gax-go once it +// becomes available internally. +func runRetryable(ctx context.Context, f func(context.Context) error) error { + var funcErr error + retryCount := 0 + for { + select { + case <-ctx.Done(): + // Do context check here so that even f() failed to do + // so (for example, gRPC implementation bug), the loop + // can still have a chance to exit as expected. + return errContextCanceled(ctx, funcErr) + default: + } + funcErr = f(ctx) + if funcErr == nil { + return nil + } + if isRetryable(funcErr) { + // Error is retryable, do exponential backoff and continue. + b, ok := extractRetryDelay(funcErr) + if !ok { + b = defaultBackoff.delay(retryCount) + } + select { + case <-ctx.Done(): + return errContextCanceled(ctx, funcErr) + case <-time.After(b): + } + retryCount++ + continue + } + // Error isn't retryable / no error, return immediately. + return toSpannerError(funcErr) + } +} diff --git a/vendor/cloud.google.com/go/spanner/retry_test.go b/vendor/cloud.google.com/go/spanner/retry_test.go new file mode 100644 index 000000000..49c5051ab --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/retry_test.go @@ -0,0 +1,108 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "golang.org/x/net/context" + edpb "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +// Test if runRetryable loop deals with various errors correctly. +func TestRetry(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + responses := []error{ + grpc.Errorf(codes.Internal, "transport is closing"), + grpc.Errorf(codes.Unknown, "unexpected EOF"), + grpc.Errorf(codes.Internal, "stream terminated by RST_STREAM with error code: 2"), + grpc.Errorf(codes.Unavailable, "service is currently unavailable"), + errRetry(fmt.Errorf("just retry it")), + } + err := runRetryable(context.Background(), func(ct context.Context) error { + var r error + if len(responses) > 0 { + r = responses[0] + responses = responses[1:] + } + return r + }) + if err != nil { + t.Errorf("runRetryable should be able to survive all retryable errors, but it returns %v", err) + } + // Unretryable errors + injErr := errors.New("this is unretryable") + err = runRetryable(context.Background(), func(ct context.Context) error { + return injErr + }) + if wantErr := toSpannerError(injErr); !reflect.DeepEqual(err, wantErr) { + t.Errorf("runRetryable returns error %v, want %v", err, wantErr) + } + // Timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + retryErr := errRetry(fmt.Errorf("still retrying")) + err = runRetryable(ctx, func(ct context.Context) error { + // Expect to trigger timeout in retryable runner after 10 executions. + <-time.After(100 * time.Millisecond) + // Let retryable runner to retry so that timeout will eventually happen. + return retryErr + }) + // Check error code and error message + if wantErrCode, wantErr := codes.DeadlineExceeded, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !reflect.DeepEqual(err, wantErr) { + t.Errorf("<err code, err>=\n<%v, %v>, want:\n<%v, %v>", ErrCode(err), err, wantErrCode, wantErr) + } + // Cancellation + ctx, cancel = context.WithCancel(context.Background()) + retries := 3 + retryErr = errRetry(fmt.Errorf("retry before cancel")) + err = runRetryable(ctx, func(ct context.Context) error { + retries-- + if retries == 0 { + cancel() + } + return retryErr + }) + // Check error code, error message, retry count + if wantErrCode, wantErr := codes.Canceled, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !reflect.DeepEqual(err, wantErr) || retries != 0 { + t.Errorf("<err code, err, retries>=\n<%v, %v, %v>, want:\n<%v, %v, %v>", ErrCode(err), err, retries, wantErrCode, wantErr, 0) + } +} + +func TestRetryInfo(t *testing.T) { + b, _ := proto.Marshal(&edpb.RetryInfo{ + RetryDelay: ptypes.DurationProto(time.Second), + }) + trailers := map[string]string{ + retryInfoKey: string(b), + } + gotDelay, ok := extractRetryDelay(errRetry(toSpannerErrorWithMetadata(grpc.Errorf(codes.Aborted, ""), metadata.New(trailers)))) + if !ok || !reflect.DeepEqual(time.Second, gotDelay) { + t.Errorf("<ok, retryDelay> = <%t, %v>, want <true, %v>", ok, gotDelay, time.Second) + } +} diff --git a/vendor/cloud.google.com/go/spanner/row.go b/vendor/cloud.google.com/go/spanner/row.go new file mode 100644 index 000000000..e8f30103d --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/row.go @@ -0,0 +1,308 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + "reflect" + + proto3 "github.com/golang/protobuf/ptypes/struct" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +// A Row is a view of a row of data produced by a Cloud Spanner read. +// +// A row consists of a number of columns; the number depends on the columns +// used to construct the read. +// +// The column values can be accessed by index, where the indices are with +// respect to the columns. For instance, if the read specified +// []string{"photo_id", "caption", "metadata"}, then each row will +// contain three columns: the 0th column corresponds to "photo_id", the +// 1st column corresponds to "caption", etc. +// +// Column values are decoded by using one of the Column, ColumnByName, or +// Columns methods. The valid values passed to these methods depend on the +// column type. For example: +// +// var photoID int64 +// err := row.Column(0, &photoID) // Decode column 0 as an integer. +// +// var caption string +// err := row.Column(1, &caption) // Decode column 1 as a string. +// +// // The above two operations at once. +// err := row.Columns(&photoID, &caption) +// +// Supported types and their corresponding Cloud Spanner column type(s) are: +// +// *string(not NULL), *NullString - STRING +// *[]NullString - STRING ARRAY +// *[]byte - BYTES +// *[][]byte - BYTES ARRAY +// *int64(not NULL), *NullInt64 - INT64 +// *[]NullInt64 - INT64 ARRAY +// *bool(not NULL), *NullBool - BOOL +// *[]NullBool - BOOL ARRAY +// *float64(not NULL), *NullFloat64 - FLOAT64 +// *[]NullFloat64 - FLOAT64 ARRAY +// *time.Time(not NULL), *NullTime - TIMESTAMP +// *[]NullTime - TIMESTAMP ARRAY +// *Date(not NULL), *NullDate - DATE +// *[]NullDate - DATE ARRAY +// *[]*some_go_struct, *[]NullRow - STRUCT ARRAY +// *GenericColumnValue - any Cloud Spanner type +// +// For TIMESTAMP columns, returned time.Time object will be in UTC. +// +// To fetch an array of BYTES, pass a *[][]byte. To fetch an array of +// (sub)rows, pass a *[]spanner.NullRow or a *[]*some_go_struct where +// some_go_struct holds all information of the subrow, see spannr.Row.ToStruct +// for the mapping between Cloud Spanner row and Go struct. To fetch an array of +// other types, pass a *[]spanner.Null* type of the appropriate type. Use +// *GenericColumnValue when you don't know in advance what column type to +// expect. +// +// Row decodes the row contents lazily; as a result, each call to a getter has +// a chance of returning an error. +// +// A column value may be NULL if the corresponding value is not present in +// Cloud Spanner. The spanner.Null* types (spanner.NullInt64 et al.) allow fetching +// values that may be null. A NULL BYTES can be fetched into a *[]byte as nil. +// It is an error to fetch a NULL value into any other type. +type Row struct { + fields []*sppb.StructType_Field + vals []*proto3.Value // keep decoded for now +} + +// errNamesValuesMismatch returns error for when columnNames count is not equal +// to columnValues count. +func errNamesValuesMismatch(columnNames []string, columnValues []interface{}) error { + return spannerErrorf(codes.FailedPrecondition, + "different number of names(%v) and values(%v)", len(columnNames), len(columnValues)) +} + +// NewRow returns a Row containing the supplied data. This can be useful for +// mocking Cloud Spanner Read and Query responses for unit testing. +func NewRow(columnNames []string, columnValues []interface{}) (*Row, error) { + if len(columnValues) != len(columnNames) { + return nil, errNamesValuesMismatch(columnNames, columnValues) + } + r := Row{ + fields: make([]*sppb.StructType_Field, len(columnValues)), + vals: make([]*proto3.Value, len(columnValues)), + } + for i := range columnValues { + val, typ, err := encodeValue(columnValues[i]) + if err != nil { + return nil, err + } + r.fields[i] = &sppb.StructType_Field{ + Name: columnNames[i], + Type: typ, + } + r.vals[i] = val + } + return &r, nil +} + +// Size is the number of columns in the row. +func (r *Row) Size() int { + return len(r.fields) +} + +// ColumnName returns the name of column i, or empty string for invalid column. +func (r *Row) ColumnName(i int) string { + if i < 0 || i >= len(r.fields) { + return "" + } + return r.fields[i].Name +} + +// ColumnIndex returns the index of the column with the given name. The +// comparison is case-sensitive. +func (r *Row) ColumnIndex(name string) (int, error) { + found := false + var index int + if len(r.vals) != len(r.fields) { + return 0, errFieldsMismatchVals(r) + } + for i, f := range r.fields { + if f == nil { + return 0, errNilColType(i) + } + if name == f.Name { + if found { + return 0, errDupColName(name) + } + found = true + index = i + } + } + if !found { + return 0, errColNotFound(name) + } + return index, nil +} + +// ColumnNames returns all column names of the row. +func (r *Row) ColumnNames() []string { + var n []string + for _, c := range r.fields { + n = append(n, c.Name) + } + return n +} + +// errColIdxOutOfRange returns error for requested column index is out of the +// range of the target Row's columns. +func errColIdxOutOfRange(i int, r *Row) error { + return spannerErrorf(codes.OutOfRange, "column index %d out of range [0,%d)", i, len(r.vals)) +} + +// errDecodeColumn returns error for not being able to decode a indexed column. +func errDecodeColumn(i int, err error) error { + if err == nil { + return nil + } + se, ok := toSpannerError(err).(*Error) + if !ok { + return spannerErrorf(codes.InvalidArgument, "failed to decode column %v, error = <%v>", i, err) + } + se.decorate(fmt.Sprintf("failed to decode column %v", i)) + return se +} + +// errFieldsMismatchVals returns error for field count isn't equal to value count in a Row. +func errFieldsMismatchVals(r *Row) error { + return spannerErrorf(codes.FailedPrecondition, "row has different number of fields(%v) and values(%v)", + len(r.fields), len(r.vals)) +} + +// errNilColType returns error for column type for column i being nil in the row. +func errNilColType(i int) error { + return spannerErrorf(codes.FailedPrecondition, "column(%v)'s type is nil", i) +} + +// Column fetches the value from the ith column, decoding it into ptr. +// See the Row documentation for the list of acceptable argument types. +// see Client.ReadWriteTransaction for an example. +func (r *Row) Column(i int, ptr interface{}) error { + if len(r.vals) != len(r.fields) { + return errFieldsMismatchVals(r) + } + if i < 0 || i >= len(r.fields) { + return errColIdxOutOfRange(i, r) + } + if r.fields[i] == nil { + return errNilColType(i) + } + if err := decodeValue(r.vals[i], r.fields[i].Type, ptr); err != nil { + return errDecodeColumn(i, err) + } + return nil +} + +// errDupColName returns error for duplicated column name in the same row. +func errDupColName(n string) error { + return spannerErrorf(codes.FailedPrecondition, "ambiguous column name %q", n) +} + +// errColNotFound returns error for not being able to find a named column. +func errColNotFound(n string) error { + return spannerErrorf(codes.NotFound, "column %q not found", n) +} + +// ColumnByName fetches the value from the named column, decoding it into ptr. +// See the Row documentation for the list of acceptable argument types. +func (r *Row) ColumnByName(name string, ptr interface{}) error { + index, err := r.ColumnIndex(name) + if err != nil { + return err + } + return r.Column(index, ptr) +} + +// errNumOfColValue returns error for providing wrong number of values to Columns. +func errNumOfColValue(n int, r *Row) error { + return spannerErrorf(codes.InvalidArgument, + "Columns(): number of arguments (%d) does not match row size (%d)", n, len(r.vals)) +} + +// Columns fetches all the columns in the row at once. +// +// The value of the kth column will be decoded into the kth argument to +// Columns. See above for the list of acceptable argument types. The number of +// arguments must be equal to the number of columns. Pass nil to specify that a +// column should be ignored. +func (r *Row) Columns(ptrs ...interface{}) error { + if len(ptrs) != len(r.vals) { + return errNumOfColValue(len(ptrs), r) + } + if len(r.vals) != len(r.fields) { + return errFieldsMismatchVals(r) + } + for i, p := range ptrs { + if p == nil { + continue + } + if err := r.Column(i, p); err != nil { + return err + } + } + return nil +} + +// errToStructArgType returns error for p not having the correct data type(pointer to Go struct) to +// be the argument of Row.ToStruct. +func errToStructArgType(p interface{}) error { + return spannerErrorf(codes.InvalidArgument, "ToStruct(): type %T is not a valid pointer to Go struct", p) +} + +// ToStruct fetches the columns in a row into the fields of a struct. +// The rules for mapping a row's columns into a struct's exported fields +// are as the following: +// 1. If a field has a `spanner: "column_name"` tag, then decode column +// 'column_name' into the field. A special case is the `spanner: "-"` +// tag, which instructs ToStruct to ignore the field during decoding. +// 2. Otherwise, if the name of a field matches the name of a column (ignoring case), +// decode the column into the field. +// +// The fields of the destination struct can be of any type that is acceptable +// to (*spanner.Row).Column. +// +// Slice and pointer fields will be set to nil if the source column +// is NULL, and a non-nil value if the column is not NULL. To decode NULL +// values of other types, use one of the spanner.Null* as the type of the +// destination field. +func (r *Row) ToStruct(p interface{}) error { + // Check if p is a pointer to a struct + if t := reflect.TypeOf(p); t == nil || t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct { + return errToStructArgType(p) + } + if len(r.vals) != len(r.fields) { + return errFieldsMismatchVals(r) + } + // Call decodeStruct directly to decode the row as a typed proto.ListValue. + return decodeStruct( + &sppb.StructType{Fields: r.fields}, + &proto3.ListValue{Values: r.vals}, + p, + ) +} diff --git a/vendor/cloud.google.com/go/spanner/row_test.go b/vendor/cloud.google.com/go/spanner/row_test.go new file mode 100644 index 000000000..2120421ac --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/row_test.go @@ -0,0 +1,1775 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "encoding/base64" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "cloud.google.com/go/civil" + proto "github.com/golang/protobuf/proto" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +var ( + tm = time.Date(2016, 11, 15, 0, 0, 0, 0, time.UTC) + dt, _ = civil.ParseDate("2016-11-15") + // row contains a column for each unique Cloud Spanner type. + row = Row{ + []*sppb.StructType_Field{ + // STRING / STRING ARRAY + {"STRING", stringType()}, + {"NULL_STRING", stringType()}, + {"STRING_ARRAY", listType(stringType())}, + {"NULL_STRING_ARRAY", listType(stringType())}, + // BYTES / BYTES ARRAY + {"BYTES", bytesType()}, + {"NULL_BYTES", bytesType()}, + {"BYTES_ARRAY", listType(bytesType())}, + {"NULL_BYTES_ARRAY", listType(bytesType())}, + // INT64 / INT64 ARRAY + {"INT64", intType()}, + {"NULL_INT64", intType()}, + {"INT64_ARRAY", listType(intType())}, + {"NULL_INT64_ARRAY", listType(intType())}, + // BOOL / BOOL ARRAY + {"BOOL", boolType()}, + {"NULL_BOOL", boolType()}, + {"BOOL_ARRAY", listType(boolType())}, + {"NULL_BOOL_ARRAY", listType(boolType())}, + // FLOAT64 / FLOAT64 ARRAY + {"FLOAT64", floatType()}, + {"NULL_FLOAT64", floatType()}, + {"FLOAT64_ARRAY", listType(floatType())}, + {"NULL_FLOAT64_ARRAY", listType(floatType())}, + // TIMESTAMP / TIMESTAMP ARRAY + {"TIMESTAMP", timeType()}, + {"NULL_TIMESTAMP", timeType()}, + {"TIMESTAMP_ARRAY", listType(timeType())}, + {"NULL_TIMESTAMP_ARRAY", listType(timeType())}, + // DATE / DATE ARRAY + {"DATE", dateType()}, + {"NULL_DATE", dateType()}, + {"DATE_ARRAY", listType(dateType())}, + {"NULL_DATE_ARRAY", listType(dateType())}, + + // STRUCT ARRAY + { + "STRUCT_ARRAY", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + { + "NULL_STRUCT_ARRAY", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{ + // STRING / STRING ARRAY + stringProto("value"), + nullProto(), + listProto(stringProto("value1"), nullProto(), stringProto("value3")), + nullProto(), + // BYTES / BYTES ARRAY + bytesProto([]byte("value")), + nullProto(), + listProto(bytesProto([]byte("value1")), nullProto(), bytesProto([]byte("value3"))), + nullProto(), + // INT64 / INT64 ARRAY + intProto(17), + nullProto(), + listProto(intProto(1), intProto(2), nullProto()), + nullProto(), + // BOOL / BOOL ARRAY + boolProto(true), + nullProto(), + listProto(nullProto(), boolProto(true), boolProto(false)), + nullProto(), + // FLOAT64 / FLOAT64 ARRAY + floatProto(1.7), + nullProto(), + listProto(nullProto(), nullProto(), floatProto(1.7)), + nullProto(), + // TIMESTAMP / TIMESTAMP ARRAY + timeProto(tm), + nullProto(), + listProto(nullProto(), timeProto(tm)), + nullProto(), + // DATE / DATE ARRAY + dateProto(dt), + nullProto(), + listProto(nullProto(), dateProto(dt)), + nullProto(), + // STRUCT ARRAY + listProto( + nullProto(), + listProto(intProto(3), floatProto(33.3), stringProto("three")), + nullProto(), + ), + nullProto(), + }, + } +) + +// Test helpers for getting column values. +func TestColumnValues(t *testing.T) { + vals := []interface{}{} + wantVals := []interface{}{} + // Test getting column values. + for i, wants := range [][]interface{}{ + // STRING / STRING ARRAY + {"value", NullString{"value", true}}, + {NullString{}}, + {[]NullString{{"value1", true}, {}, {"value3", true}}}, + {[]NullString(nil)}, + // BYTES / BYTES ARRAY + {[]byte("value")}, + {[]byte(nil)}, + {[][]byte{[]byte("value1"), nil, []byte("value3")}}, + {[][]byte(nil)}, + // INT64 / INT64 ARRAY + {int64(17), NullInt64{17, true}}, + {NullInt64{}}, + {[]NullInt64{{1, true}, {2, true}, {}}}, + {[]NullInt64(nil)}, + // BOOL / BOOL ARRAY + {true, NullBool{true, true}}, + {NullBool{}}, + {[]NullBool{{}, {true, true}, {false, true}}}, + {[]NullBool(nil)}, + // FLOAT64 / FLOAT64 ARRAY + {1.7, NullFloat64{1.7, true}}, + {NullFloat64{}}, + {[]NullFloat64{{}, {}, {1.7, true}}}, + {[]NullFloat64(nil)}, + // TIMESTAMP / TIMESTAMP ARRAY + {tm, NullTime{tm, true}}, + {NullTime{}}, + {[]NullTime{{}, {tm, true}}}, + {[]NullTime(nil)}, + // DATE / DATE ARRAY + {dt, NullDate{dt, true}}, + {NullDate{}}, + {[]NullDate{{}, {dt, true}}}, + {[]NullDate(nil)}, + // STRUCT ARRAY + { + []*struct { + Col1 NullInt64 + Col2 NullFloat64 + Col3 string + }{ + nil, + &struct { + Col1 NullInt64 + Col2 NullFloat64 + Col3 string + }{ + NullInt64{3, true}, + NullFloat64{33.3, true}, + "three", + }, + nil, + }, + []NullRow{ + {}, + { + Row: Row{ + fields: []*sppb.StructType_Field{ + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + }, + vals: []*proto3.Value{ + intProto(3), + floatProto(33.3), + stringProto("three"), + }, + }, + Valid: true, + }, + {}, + }, + }, + { + []*struct { + Col1 NullInt64 + Col2 NullFloat64 + Col3 string + }(nil), + []NullRow(nil), + }, + } { + for j, want := range wants { + // Prepare Value vector to test Row.Columns. + if j == 0 { + vals = append(vals, reflect.New(reflect.TypeOf(want)).Interface()) + wantVals = append(wantVals, want) + } + // Column + gotp := reflect.New(reflect.TypeOf(want)) + err := row.Column(i, gotp.Interface()) + if err != nil { + t.Errorf("\t row.Column(%v, %T) returns error: %v, want nil", i, gotp.Interface(), err) + } + if got := reflect.Indirect(gotp).Interface(); !reflect.DeepEqual(got, want) { + t.Errorf("\t row.Column(%v, %T) retrives %v, want %v", i, gotp.Interface(), got, want) + } + // ColumnByName + gotp = reflect.New(reflect.TypeOf(want)) + err = row.ColumnByName(row.fields[i].Name, gotp.Interface()) + if err != nil { + t.Errorf("\t row.ColumnByName(%v, %T) returns error: %v, want nil", row.fields[i].Name, gotp.Interface(), err) + } + if got := reflect.Indirect(gotp).Interface(); !reflect.DeepEqual(got, want) { + t.Errorf("\t row.ColumnByName(%v, %T) retrives %v, want %v", row.fields[i].Name, gotp.Interface(), got, want) + } + } + } + // Test Row.Columns. + if err := row.Columns(vals...); err != nil { + t.Errorf("row.Columns() returns error: %v, want nil", err) + } + for i, want := range wantVals { + if got := reflect.Indirect(reflect.ValueOf(vals[i])).Interface(); !reflect.DeepEqual(got, want) { + t.Errorf("\t got %v(%T) for column[%v], want %v(%T)", got, got, row.fields[i].Name, want, want) + } + } +} + +// Test decoding into nil destination. +func TestNilDst(t *testing.T) { + for i, test := range []struct { + r *Row + dst interface{} + wantErr error + structDst interface{} + wantToStructErr error + }{ + { + &Row{ + []*sppb.StructType_Field{ + {"Col0", stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, + nil, + errDecodeColumn(0, errNilDst(nil)), + nil, + errToStructArgType(nil), + }, + { + &Row{ + []*sppb.StructType_Field{ + {"Col0", stringType()}, + }, + []*proto3.Value{stringProto("value")}, + }, + (*string)(nil), + errDecodeColumn(0, errNilDst((*string)(nil))), + (*struct{ STRING string })(nil), + errNilDst((*struct{ STRING string })(nil)), + }, + { + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + ), + ), + }, + }, + []*proto3.Value{listProto( + listProto(intProto(3), floatProto(33.3)), + )}, + }, + (*[]*struct { + Col1 int + Col2 float64 + })(nil), + errDecodeColumn(0, errNilDst((*[]*struct { + Col1 int + Col2 float64 + })(nil))), + (*struct { + StructArray []*struct { + Col1 int + Col2 float64 + } `spanner:"STRUCT_ARRAY"` + })(nil), + errNilDst((*struct { + StructArray []*struct { + Col1 int + Col2 float64 + } `spanner:"STRUCT_ARRAY"` + })(nil)), + }, + } { + if gotErr := test.r.Column(0, test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.r.Column() returns error %v, want %v", i, gotErr, test.wantErr) + } + if gotErr := test.r.ColumnByName("Col0", test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.r.ColumnByName() returns error %v, want %v", i, gotErr, test.wantErr) + } + // Row.Columns(T) should return nil on T == nil, otherwise, it should return test.wantErr. + wantColumnsErr := test.wantErr + if test.dst == nil { + wantColumnsErr = nil + } + if gotErr := test.r.Columns(test.dst); !reflect.DeepEqual(gotErr, wantColumnsErr) { + t.Errorf("%v: test.r.Columns() returns error %v, want %v", i, gotErr, wantColumnsErr) + } + if gotErr := test.r.ToStruct(test.structDst); !reflect.DeepEqual(gotErr, test.wantToStructErr) { + t.Errorf("%v: test.r.ToStruct() returns error %v, want %v", i, gotErr, test.wantToStructErr) + } + } +} + +// Test decoding NULL columns using Go types that don't support NULL. +func TestNullTypeErr(t *testing.T) { + var tm time.Time + ntoi := func(n string) int { + for i, f := range row.fields { + if f.Name == n { + return i + } + } + t.Errorf("cannot find column name %q in row", n) + return 0 + } + for _, test := range []struct { + colName string + dst interface{} + }{ + { + "NULL_STRING", + proto.String(""), + }, + { + "NULL_INT64", + proto.Int64(0), + }, + { + "NULL_BOOL", + proto.Bool(false), + }, + { + "NULL_FLOAT64", + proto.Float64(0.0), + }, + { + "NULL_TIMESTAMP", + &tm, + }, + { + "NULL_DATE", + &dt, + }, + } { + wantErr := errDecodeColumn(ntoi(test.colName), errDstNotForNull(test.dst)) + if gotErr := row.ColumnByName(test.colName, test.dst); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("row.ColumnByName(%v) returns error %v, want %v", test.colName, gotErr, wantErr) + } + } +} + +// Test using wrong destination type in column decoders. +func TestColumnTypeErr(t *testing.T) { + // badDst cannot hold any of the column values. + badDst := &struct{}{} + for i, f := range row.fields { // For each of the columns, try to decode it into badDst. + tc := f.Type.Code + isArray := strings.Contains(f.Name, "ARRAY") + if isArray { + tc = f.Type.ArrayElementType.Code + } + wantErr := errDecodeColumn(i, errTypeMismatch(tc, isArray, badDst)) + if gotErr := row.Column(i, badDst); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("Column(%v): decoding into destination with wrong type %T returns error %v, want %v", + i, badDst, gotErr, wantErr) + } + if gotErr := row.ColumnByName(f.Name, badDst); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("ColumnByName(%v): decoding into destination with wrong type %T returns error %v, want %v", + f.Name, badDst, gotErr, wantErr) + } + } + wantErr := errDecodeColumn(1, errTypeMismatch(sppb.TypeCode_STRING, false, badDst)) + // badDst is used to receive column 1. + vals := []interface{}{nil, badDst} // Row.Column() is expected to fail at column 1. + // Skip decoding the rest columns by providing nils as the destinations. + for i := 2; i < len(row.fields); i++ { + vals = append(vals, nil) + } + if gotErr := row.Columns(vals...); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("Columns(): decoding column 1 with wrong type %T returns error %v, want %v", + badDst, gotErr, wantErr) + } +} + +// Test the handling of invalid column decoding requests which cannot be mapped to correct column(s). +func TestInvalidColumnRequest(t *testing.T) { + for _, test := range []struct { + desc string + f func() error + wantErr error + }{ + { + "Request column index is out of range", + func() error { + return row.Column(10000, &struct{}{}) + }, + errColIdxOutOfRange(10000, &row), + }, + { + "Cannot find the named column", + func() error { + return row.ColumnByName("string", &struct{}{}) + }, + errColNotFound("string"), + }, + { + "Not enough arguments to call row.Columns()", + func() error { + return row.Columns(nil, nil) + }, + errNumOfColValue(2, &row), + }, + { + "Call ColumnByName on row with duplicated column names", + func() error { + var s string + r := &Row{ + []*sppb.StructType_Field{ + {"Val", stringType()}, + {"Val", stringType()}, + }, + []*proto3.Value{stringProto("value1"), stringProto("value2")}, + } + return r.ColumnByName("Val", &s) + }, + errDupColName("Val"), + }, + { + "Call ToStruct on row with duplicated column names", + func() error { + s := &struct { + Val string + }{} + r := &Row{ + []*sppb.StructType_Field{ + {"Val", stringType()}, + {"Val", stringType()}, + }, + []*proto3.Value{stringProto("value1"), stringProto("value2")}, + } + return r.ToStruct(s) + }, + errDupSpannerField("Val", &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {"Val", stringType()}, + {"Val", stringType()}, + }, + }), + }, + { + "Call ToStruct on a row with unnamed field", + func() error { + s := &struct { + Val string + }{} + r := &Row{ + []*sppb.StructType_Field{ + {"", stringType()}, + }, + []*proto3.Value{stringProto("value1")}, + } + return r.ToStruct(s) + }, + errUnnamedField(&sppb.StructType{Fields: []*sppb.StructType_Field{{"", stringType()}}}, 0), + }, + } { + if gotErr := test.f(); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.f() returns error %v, want %v", test.desc, gotErr, test.wantErr) + } + } +} + +// Test decoding the row with row.ToStruct into an invalid destination. +func TestToStructInvalidDst(t *testing.T) { + for _, test := range []struct { + desc string + dst interface{} + wantErr error + }{ + { + "Decode row as STRUCT into int32", + proto.Int(1), + errToStructArgType(proto.Int(1)), + }, + { + "Decode row as STRUCT to nil Go struct", + (*struct{})(nil), + errNilDst((*struct{})(nil)), + }, + { + "Decode row as STRUCT to Go struct with duplicated fields for the PK column", + &struct { + PK1 string `spanner:"STRING"` + PK2 string `spanner:"STRING"` + }{}, + errNoOrDupGoField(&struct { + PK1 string `spanner:"STRING"` + PK2 string `spanner:"STRING"` + }{}, "STRING"), + }, + { + "Decode row as STRUCT to Go struct with no field for the PK column", + &struct { + PK1 string `spanner:"_STRING"` + }{}, + errNoOrDupGoField(&struct { + PK1 string `spanner:"_STRING"` + }{}, "STRING"), + }, + { + "Decode row as STRUCT to Go struct with wrong type for the PK column", + &struct { + PK1 int64 `spanner:"STRING"` + }{}, + errDecodeStructField(&sppb.StructType{Fields: row.fields}, "STRING", + errTypeMismatch(sppb.TypeCode_STRING, false, proto.Int64(0))), + }, + } { + if gotErr := row.ToStruct(test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: decoding:\ngot %v\nwant %v", test.desc, gotErr, test.wantErr) + } + } +} + +// Test decoding a broken row. +func TestBrokenRow(t *testing.T) { + for i, test := range []struct { + row *Row + dst interface{} + wantErr error + }{ + { + // A row with no field. + &Row{ + []*sppb.StructType_Field{}, + []*proto3.Value{stringProto("value")}, + }, + &NullString{"value", true}, + errFieldsMismatchVals(&Row{ + []*sppb.StructType_Field{}, + []*proto3.Value{stringProto("value")}, + }), + }, + { + // A row with nil field. + &Row{ + []*sppb.StructType_Field{nil}, + []*proto3.Value{stringProto("value")}, + }, + &NullString{"value", true}, + errNilColType(0), + }, + { + // Field is not nil, but its type is nil. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + nil, + }, + }, + []*proto3.Value{listProto(stringProto("value1"), stringProto("value2"))}, + }, + &[]NullString{}, + errDecodeColumn(0, errNilSpannerType()), + }, + { + // Field is not nil, field type is not nil, but it is an array and its array element type is nil. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + }, + }, + }, + []*proto3.Value{listProto(stringProto("value1"), stringProto("value2"))}, + }, + &[]NullString{}, + errDecodeColumn(0, errNilArrElemType(&sppb.Type{Code: sppb.TypeCode_ARRAY})), + }, + { + // Field specifies valid type, value is nil. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + intType(), + }, + }, + []*proto3.Value{nil}, + }, + &NullInt64{1, true}, + errDecodeColumn(0, errNilSrc()), + }, + { + // Field specifies INT64 type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + intType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_StringValue)(nil)}}, + }, + &NullInt64{1, true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_StringValue)(nil)}, "String")), + }, + { + // Field specifies INT64 type, but value is for Number type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + intType(), + }, + }, + []*proto3.Value{floatProto(1.0)}, + }, + &NullInt64{1, true}, + errDecodeColumn(0, errSrcVal(floatProto(1.0), "String")), + }, + { + // Field specifies INT64 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + intType(), + }, + }, + []*proto3.Value{stringProto("&1")}, + }, + proto.Int64(0), + errDecodeColumn(0, errBadEncoding(stringProto("&1"), func() error { + _, err := strconv.ParseInt("&1", 10, 64) + return err + }())), + }, + { + // Field specifies INT64 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + intType(), + }, + }, + []*proto3.Value{stringProto("&1")}, + }, + &NullInt64{}, + errDecodeColumn(0, errBadEncoding(stringProto("&1"), func() error { + _, err := strconv.ParseInt("&1", 10, 64) + return err + }())), + }, + { + // Field specifies STRING type, but value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + stringType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_StringValue)(nil)}}, + }, + &NullString{"value", true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_StringValue)(nil)}, "String")), + }, + { + // Field specifies STRING type, but value is for ARRAY type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + stringType(), + }, + }, + []*proto3.Value{listProto(stringProto("value"))}, + }, + &NullString{"value", true}, + errDecodeColumn(0, errSrcVal(listProto(stringProto("value")), "String")), + }, + { + // Field specifies FLOAT64 type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + floatType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_NumberValue)(nil)}}, + }, + &NullFloat64{1.0, true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_NumberValue)(nil)}, "Number")), + }, + { + // Field specifies FLOAT64 type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + floatType(), + }, + }, + []*proto3.Value{boolProto(true)}, + }, + &NullFloat64{1.0, true}, + errDecodeColumn(0, errSrcVal(boolProto(true), "Number")), + }, + { + // Field specifies FLOAT64 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + floatType(), + }, + }, + []*proto3.Value{stringProto("nan")}, + }, + &NullFloat64{}, + errDecodeColumn(0, errUnexpectedNumStr("nan")), + }, + { + // Field specifies FLOAT64 type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + floatType(), + }, + }, + []*proto3.Value{stringProto("nan")}, + }, + proto.Float64(0), + errDecodeColumn(0, errUnexpectedNumStr("nan")), + }, + { + // Field specifies BYTES type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + bytesType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_StringValue)(nil)}}, + }, + &[]byte{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_StringValue)(nil)}, "String")), + }, + { + // Field specifies BYTES type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + bytesType(), + }, + }, + []*proto3.Value{boolProto(false)}, + }, + &[]byte{}, + errDecodeColumn(0, errSrcVal(boolProto(false), "String")), + }, + { + // Field specifies BYTES type, but value is wrongly encoded. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + bytesType(), + }, + }, + []*proto3.Value{stringProto("&&")}, + }, + &[]byte{}, + errDecodeColumn(0, errBadEncoding(stringProto("&&"), func() error { + _, err := base64.StdEncoding.DecodeString("&&") + return err + }())), + }, + { + // Field specifies BOOL type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + boolType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_BoolValue)(nil)}}, + }, + &NullBool{false, true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_BoolValue)(nil)}, "Bool")), + }, + { + // Field specifies BOOL type, but value is for STRING type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + boolType(), + }, + }, + []*proto3.Value{stringProto("false")}, + }, + &NullBool{false, true}, + errDecodeColumn(0, errSrcVal(stringProto("false"), "Bool")), + }, + { + // Field specifies TIMESTAMP type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + timeType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_StringValue)(nil)}}, + }, + &NullTime{time.Now(), true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_StringValue)(nil)}, "String")), + }, + { + // Field specifies TIMESTAMP type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + timeType(), + }, + }, + []*proto3.Value{boolProto(false)}, + }, + &NullTime{time.Now(), true}, + errDecodeColumn(0, errSrcVal(boolProto(false), "String")), + }, + { + // Field specifies TIMESTAMP type, but value is invalid timestamp. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + timeType(), + }, + }, + []*proto3.Value{stringProto("junk")}, + }, + &NullTime{time.Now(), true}, + errDecodeColumn(0, errBadEncoding(stringProto("junk"), func() error { + _, err := time.Parse(time.RFC3339Nano, "junk") + return err + }())), + }, + { + // Field specifies DATE type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + dateType(), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_StringValue)(nil)}}, + }, + &NullDate{civil.Date{}, true}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_StringValue)(nil)}, "String")), + }, + { + // Field specifies DATE type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + dateType(), + }, + }, + []*proto3.Value{boolProto(false)}, + }, + &NullDate{civil.Date{}, true}, + errDecodeColumn(0, errSrcVal(boolProto(false), "String")), + }, + { + // Field specifies DATE type, but value is invalid timestamp. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + dateType(), + }, + }, + []*proto3.Value{stringProto("junk")}, + }, + &NullDate{civil.Date{}, true}, + errDecodeColumn(0, errBadEncoding(stringProto("junk"), func() error { + _, err := civil.ParseDate("junk") + return err + }())), + }, + + { + // Field specifies ARRAY<INT64> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(intType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullInt64{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<INT64> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(intType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullInt64{}, + errDecodeColumn(0, errNilListValue("INT64")), + }, + { + // Field specifies ARRAY<INT64> type, but value is for BYTES type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(intType()), + }, + }, + []*proto3.Value{bytesProto([]byte("value"))}, + }, + &[]NullInt64{}, + errDecodeColumn(0, errSrcVal(bytesProto([]byte("value")), "List")), + }, + { + // Field specifies ARRAY<INT64> type, but value is for ARRAY<BOOL> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(intType()), + }, + }, + []*proto3.Value{listProto(boolProto(true))}, + }, + &[]NullInt64{}, + errDecodeColumn(0, errDecodeArrayElement(0, boolProto(true), + "INT64", errSrcVal(boolProto(true), "String"))), + }, + { + // Field specifies ARRAY<STRING> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(stringType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullString{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<STRING> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(stringType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullString{}, + errDecodeColumn(0, errNilListValue("STRING")), + }, + { + // Field specifies ARRAY<STRING> type, but value is for BOOL type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(stringType()), + }, + }, + []*proto3.Value{boolProto(true)}, + }, + &[]NullString{}, + errDecodeColumn(0, errSrcVal(boolProto(true), "List")), + }, + { + // Field specifies ARRAY<STRING> type, but value is for ARRAY<BOOL> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(stringType()), + }, + }, + []*proto3.Value{listProto(boolProto(true))}, + }, + &[]NullString{}, + errDecodeColumn(0, errDecodeArrayElement(0, boolProto(true), + "STRING", errSrcVal(boolProto(true), "String"))), + }, + { + // Field specifies ARRAY<FLOAT64> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(floatType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullFloat64{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<FLOAT64> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(floatType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullFloat64{}, + errDecodeColumn(0, errNilListValue("FLOAT64")), + }, + { + // Field specifies ARRAY<FLOAT64> type, but value is for STRING type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(floatType()), + }, + }, + []*proto3.Value{stringProto("value")}, + }, + &[]NullFloat64{}, + errDecodeColumn(0, errSrcVal(stringProto("value"), "List")), + }, + { + // Field specifies ARRAY<FLOAT64> type, but value is for ARRAY<BOOL> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(floatType()), + }, + }, + []*proto3.Value{listProto(boolProto(true))}, + }, + &[]NullFloat64{}, + errDecodeColumn(0, errDecodeArrayElement(0, boolProto(true), + "FLOAT64", errSrcVal(boolProto(true), "Number"))), + }, + { + // Field specifies ARRAY<BYTES> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(bytesType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[][]byte{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<BYTES> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(bytesType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[][]byte{}, + errDecodeColumn(0, errNilListValue("BYTES")), + }, + { + // Field specifies ARRAY<BYTES> type, but value is for FLOAT64 type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(bytesType()), + }, + }, + []*proto3.Value{floatProto(1.0)}, + }, + &[][]byte{}, + errDecodeColumn(0, errSrcVal(floatProto(1.0), "List")), + }, + { + // Field specifies ARRAY<BYTES> type, but value is for ARRAY<FLOAT64> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(bytesType()), + }, + }, + []*proto3.Value{listProto(floatProto(1.0))}, + }, + &[][]byte{}, + errDecodeColumn(0, errDecodeArrayElement(0, floatProto(1.0), + "BYTES", errSrcVal(floatProto(1.0), "String"))), + }, + { + // Field specifies ARRAY<BOOL> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(boolType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullBool{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<BOOL> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(boolType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullBool{}, + errDecodeColumn(0, errNilListValue("BOOL")), + }, + { + // Field specifies ARRAY<BOOL> type, but value is for FLOAT64 type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(boolType()), + }, + }, + []*proto3.Value{floatProto(1.0)}, + }, + &[]NullBool{}, + errDecodeColumn(0, errSrcVal(floatProto(1.0), "List")), + }, + { + // Field specifies ARRAY<BOOL> type, but value is for ARRAY<FLOAT64> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(boolType()), + }, + }, + []*proto3.Value{listProto(floatProto(1.0))}, + }, + &[]NullBool{}, + errDecodeColumn(0, errDecodeArrayElement(0, floatProto(1.0), + "BOOL", errSrcVal(floatProto(1.0), "Bool"))), + }, + { + // Field specifies ARRAY<TIMESTAMP> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(timeType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullTime{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<TIMESTAMP> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(timeType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullTime{}, + errDecodeColumn(0, errNilListValue("TIMESTAMP")), + }, + { + // Field specifies ARRAY<TIMESTAMP> type, but value is for FLOAT64 type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(timeType()), + }, + }, + []*proto3.Value{floatProto(1.0)}, + }, + &[]NullTime{}, + errDecodeColumn(0, errSrcVal(floatProto(1.0), "List")), + }, + { + // Field specifies ARRAY<TIMESTAMP> type, but value is for ARRAY<FLOAT64> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(timeType()), + }, + }, + []*proto3.Value{listProto(floatProto(1.0))}, + }, + &[]NullTime{}, + errDecodeColumn(0, errDecodeArrayElement(0, floatProto(1.0), + "TIMESTAMP", errSrcVal(floatProto(1.0), "String"))), + }, + { + // Field specifies ARRAY<DATE> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(dateType()), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]NullDate{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<DATE> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(dateType()), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullDate{}, + errDecodeColumn(0, errNilListValue("DATE")), + }, + { + // Field specifies ARRAY<DATE> type, but value is for FLOAT64 type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(dateType()), + }, + }, + []*proto3.Value{floatProto(1.0)}, + }, + &[]NullDate{}, + errDecodeColumn(0, errSrcVal(floatProto(1.0), "List")), + }, + { + // Field specifies ARRAY<DATE> type, but value is for ARRAY<FLOAT64> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType(dateType()), + }, + }, + []*proto3.Value{listProto(floatProto(1.0))}, + }, + &[]NullDate{}, + errDecodeColumn(0, errDecodeArrayElement(0, floatProto(1.0), + "DATE", errSrcVal(floatProto(1.0), "String"))), + }, + { + // Field specifies ARRAY<STRUCT> type, value is having a nil Kind. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{{Kind: (*proto3.Value_ListValue)(nil)}}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn(0, errSrcVal(&proto3.Value{Kind: (*proto3.Value_ListValue)(nil)}, "List")), + }, + { + // Field specifies ARRAY<STRUCT> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn(0, errNilListValue("STRUCT")), + }, + { + // Field specifies ARRAY<STRUCT> type, value is having a nil ListValue. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{{Kind: &proto3.Value_ListValue{}}}, + }, + &[]NullRow{}, + errDecodeColumn(0, errNilListValue("STRUCT")), + }, + { + // Field specifies ARRAY<STRUCT> type, value is for BYTES type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{bytesProto([]byte("value"))}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn(0, errSrcVal(bytesProto([]byte("value")), "List")), + }, + { + // Field specifies ARRAY<STRUCT> type, value is for BYTES type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{listProto(bytesProto([]byte("value")))}, + }, + &[]NullRow{}, + errDecodeColumn(0, errNotStructElement(0, bytesProto([]byte("value")))), + }, + { + // Field specifies ARRAY<STRUCT> type, value is for ARRAY<BYTES> type. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{listProto(bytesProto([]byte("value")))}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn(0, errDecodeArrayElement(0, bytesProto([]byte("value")), + "STRUCT", errSrcVal(bytesProto([]byte("value")), "List"))), + }, + { + // Field specifies ARRAY<STRUCT>, but is having nil StructType. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + &sppb.Type{Code: sppb.TypeCode_STRUCT}, + ), + }, + }, + []*proto3.Value{listProto(listProto(intProto(1), floatProto(2.0), stringProto("3")))}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn(0, errDecodeArrayElement(0, listProto(intProto(1), floatProto(2.0), stringProto("3")), + "STRUCT", errNilSpannerStructType())), + }, + { + // Field specifies ARRAY<STRUCT>, but the second struct value is for BOOL type instead of FLOAT64. + &Row{ + []*sppb.StructType_Field{ + { + "Col0", + listType( + structType( + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + ), + ), + }, + }, + []*proto3.Value{listProto(listProto(intProto(1), boolProto(true), stringProto("3")))}, + }, + &[]*struct { + Col1 int64 + Col2 float64 + Col3 string + }{}, + errDecodeColumn( + 0, + errDecodeArrayElement( + 0, listProto(intProto(1), boolProto(true), stringProto("3")), "STRUCT", + errDecodeStructField( + &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + mkField("Col1", intType()), + mkField("Col2", floatType()), + mkField("Col3", stringType()), + }, + }, + "Col2", + errSrcVal(boolProto(true), "Number"), + ), + ), + ), + }, + } { + if gotErr := test.row.Column(0, test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.row.Column(0) got error %v, want %v", i, gotErr, test.wantErr) + } + if gotErr := test.row.ColumnByName("Col0", test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.row.ColumnByName(%q) got error %v, want %v", i, "Col0", gotErr, test.wantErr) + } + if gotErr := test.row.Columns(test.dst); !reflect.DeepEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.row.Columns(%T) got error %v, want %v", i, test.dst, gotErr, test.wantErr) + } + } +} + +// Test Row.ToStruct(). +func TestToStruct(t *testing.T) { + s := []struct { + // STRING / STRING ARRAY + PrimaryKey string `spanner:"STRING"` + NullString NullString `spanner:"NULL_STRING"` + StringArray []NullString `spanner:"STRING_ARRAY"` + NullStringArray []NullString `spanner:"NULL_STRING_ARRAY"` + // BYTES / BYTES ARRAY + Bytes []byte `spanner:"BYTES"` + NullBytes []byte `spanner:"NULL_BYTES"` + BytesArray [][]byte `spanner:"BYTES_ARRAY"` + NullBytesArray [][]byte `spanner:"NULL_BYTES_ARRAY"` + // INT64 / INT64 ARRAY + Int64 int64 `spanner:"INT64"` + NullInt64 NullInt64 `spanner:"NULL_INT64"` + Int64Array []NullInt64 `spanner:"INT64_ARRAY"` + NullInt64Array []NullInt64 `spanner:"NULL_INT64_ARRAY"` + // BOOL / BOOL ARRAY + Bool bool `spanner:"BOOL"` + NullBool NullBool `spanner:"NULL_BOOL"` + BoolArray []NullBool `spanner:"BOOL_ARRAY"` + NullBoolArray []NullBool `spanner:"NULL_BOOL_ARRAY"` + // FLOAT64 / FLOAT64 ARRAY + Float64 float64 `spanner:"FLOAT64"` + NullFloat64 NullFloat64 `spanner:"NULL_FLOAT64"` + Float64Array []NullFloat64 `spanner:"FLOAT64_ARRAY"` + NullFloat64Array []NullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` + // TIMESTAMP / TIMESTAMP ARRAY + Timestamp time.Time `spanner:"TIMESTAMP"` + NullTimestamp NullTime `spanner:"NULL_TIMESTAMP"` + TimestampArray []NullTime `spanner:"TIMESTAMP_ARRAY"` + NullTimestampArray []NullTime `spanner:"NULL_TIMESTAMP_ARRAY"` + // DATE / DATE ARRAY + Date civil.Date `spanner:"DATE"` + NullDate NullDate `spanner:"NULL_DATE"` + DateArray []NullDate `spanner:"DATE_ARRAY"` + NullDateArray []NullDate `spanner:"NULL_DATE_ARRAY"` + + // STRUCT ARRAY + StructArray []*struct { + Col1 int64 + Col2 float64 + Col3 string + } `spanner:"STRUCT_ARRAY"` + NullStructArray []*struct { + Col1 int64 + Col2 float64 + Col3 string + } `spanner:"NULL_STRUCT_ARRAY"` + }{ + {}, // got + { + // STRING / STRING ARRAY + "value", + NullString{}, + []NullString{{"value1", true}, {}, {"value3", true}}, + []NullString(nil), + // BYTES / BYTES ARRAY + []byte("value"), + []byte(nil), + [][]byte{[]byte("value1"), nil, []byte("value3")}, + [][]byte(nil), + // INT64 / INT64 ARRAY + int64(17), + NullInt64{}, + []NullInt64{{int64(1), true}, {int64(2), true}, {}}, + []NullInt64(nil), + // BOOL / BOOL ARRAY + true, + NullBool{}, + []NullBool{{}, {true, true}, {false, true}}, + []NullBool(nil), + // FLOAT64 / FLOAT64 ARRAY + 1.7, + NullFloat64{}, + []NullFloat64{{}, {}, {1.7, true}}, + []NullFloat64(nil), + // TIMESTAMP / TIMESTAMP ARRAY + tm, + NullTime{}, + []NullTime{{}, {tm, true}}, + []NullTime(nil), + // DATE / DATE ARRAY + dt, + NullDate{}, + []NullDate{{}, {dt, true}}, + []NullDate(nil), + // STRUCT ARRAY + []*struct { + Col1 int64 + Col2 float64 + Col3 string + }{ + nil, + &struct { + Col1 int64 + Col2 float64 + Col3 string + }{3, 33.3, "three"}, + nil, + }, + []*struct { + Col1 int64 + Col2 float64 + Col3 string + }(nil), + }, // want + } + err := row.ToStruct(&s[0]) + if err != nil { + t.Errorf("row.ToStruct() returns error: %v, want nil", err) + } + if !reflect.DeepEqual(s[0], s[1]) { + t.Errorf("row.ToStruct() fetches struct %v, want %v", s[0], s[1]) + } +} + +// Test helpers for getting column names. +func TestColumnNameAndIndex(t *testing.T) { + // Test Row.Size(). + if rs := row.Size(); rs != len(row.fields) { + t.Errorf("row.Size() returns %v, want %v", rs, len(row.fields)) + } + // Test Row.Size() on empty Row. + if rs := (&Row{}).Size(); rs != 0 { + t.Errorf("empty_row.Size() returns %v, want %v", rs, 0) + } + // Test Row.ColumnName() + for i, col := range row.fields { + if cn := row.ColumnName(i); cn != col.Name { + t.Errorf("row.ColumnName(%v) returns %q, want %q", i, cn, col.Name) + } + goti, err := row.ColumnIndex(col.Name) + if err != nil { + t.Errorf("ColumnIndex(%q) error %v", col.Name, err) + continue + } + if goti != i { + t.Errorf("ColumnIndex(%q) = %d, want %d", col.Name, goti, i) + } + } + // Test Row.ColumnName on empty Row. + if cn := (&Row{}).ColumnName(0); cn != "" { + t.Errorf("empty_row.ColumnName(%v) returns %q, want %q", 0, cn, "") + } + // Test Row.ColumnIndex on empty Row. + if _, err := (&Row{}).ColumnIndex(""); err == nil { + t.Error("empty_row.ColumnIndex returns nil, want error") + } +} + +func TestNewRow(t *testing.T) { + for _, test := range []struct { + names []string + values []interface{} + want *Row + wantErr error + }{ + { + want: &Row{fields: []*sppb.StructType_Field{}, vals: []*proto3.Value{}}, + }, + { + names: []string{}, + values: []interface{}{}, + want: &Row{fields: []*sppb.StructType_Field{}, vals: []*proto3.Value{}}, + }, + { + names: []string{"a", "b"}, + values: []interface{}{}, + want: nil, + wantErr: errNamesValuesMismatch([]string{"a", "b"}, []interface{}{}), + }, + { + names: []string{"a", "b", "c"}, + values: []interface{}{5, "abc", GenericColumnValue{listType(intType()), listProto(intProto(91), nullProto(), intProto(87))}}, + want: &Row{ + []*sppb.StructType_Field{ + {"a", intType()}, + {"b", stringType()}, + {"c", listType(intType())}, + }, + []*proto3.Value{ + intProto(5), + stringProto("abc"), + listProto(intProto(91), nullProto(), intProto(87)), + }, + }, + }, + } { + got, err := NewRow(test.names, test.values) + if !reflect.DeepEqual(err, test.wantErr) { + t.Errorf("NewRow(%v,%v).err = %s, want %s", test.names, test.values, err, test.wantErr) + continue + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("NewRow(%v,%v) = %s, want %s", test.names, test.values, got, test.want) + continue + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/session.go b/vendor/cloud.google.com/go/spanner/session.go new file mode 100644 index 000000000..6930a3aba --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/session.go @@ -0,0 +1,968 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "container/heap" + "container/list" + "fmt" + "math/rand" + "strings" + "sync" + "time" + + log "github.com/golang/glog" + "golang.org/x/net/context" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +// sessionHandle is an interface for transactions to access Cloud Spanner sessions safely. It is generated by sessionPool.take(). +type sessionHandle struct { + // mu guarantees that inner session object is returned / destroyed only once. + mu sync.Mutex + // session is a pointer to a session object. Transactions never need to access it directly. + session *session +} + +// recycle gives the inner session object back to its home session pool. It is safe to call recycle multiple times but only the first one would take effect. +func (sh *sessionHandle) recycle() { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.session == nil { + // sessionHandle has already been recycled. + return + } + sh.session.recycle() + sh.session = nil +} + +// getID gets the Cloud Spanner session ID from the internal session object. getID returns empty string if the sessionHandle is nil or the inner session +// object has been released by recycle / destroy. +func (sh *sessionHandle) getID() string { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.session == nil { + // sessionHandle has already been recycled/destroyed. + return "" + } + return sh.session.getID() +} + +// getClient gets the Cloud Spanner RPC client associated with the session ID in sessionHandle. +func (sh *sessionHandle) getClient() sppb.SpannerClient { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.session == nil { + return nil + } + return sh.session.client +} + +// getMetadata returns the metadata associated with the session in sessionHandle. +func (sh *sessionHandle) getMetadata() metadata.MD { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.session == nil { + return nil + } + return sh.session.md +} + +// getTransactionID returns the transaction id in the session if available. +func (sh *sessionHandle) getTransactionID() transactionID { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.session == nil { + return nil + } + return sh.session.tx +} + +// destroy destroys the inner session object. It is safe to call destroy multiple times and only the first call would attempt to +// destroy the inner session object. +func (sh *sessionHandle) destroy() { + sh.mu.Lock() + s := sh.session + sh.session = nil + sh.mu.Unlock() + if s == nil { + // sessionHandle has already been destroyed. + return + } + s.destroy(false) +} + +// session wraps a Cloud Spanner session ID through which transactions are created and executed. +type session struct { + // client is the RPC channel to Cloud Spanner. It is set only once during session's creation. + client sppb.SpannerClient + // id is the unique id of the session in Cloud Spanner. It is set only once during session's creation. + id string + // pool is the session's home session pool where it was created. It is set only once during session's creation. + pool *sessionPool + // createTime is the timestamp of the session's creation. It is set only once during session's creation. + createTime time.Time + + // mu protects the following fields from concurrent access: both healthcheck workers and transactions can modify them. + mu sync.Mutex + // valid marks the validity of a session. + valid bool + // hcIndex is the index of the session inside the global healthcheck queue. If hcIndex < 0, session has been unregistered from the queue. + hcIndex int + // idleList is the linkedlist node which links the session to its home session pool's idle list. If idleList == nil, the + // session is not in idle list. + idleList *list.Element + // nextCheck is the timestamp of next scheduled healthcheck of the session. It is maintained by the global health checker. + nextCheck time.Time + // checkingHelath is true if currently this session is being processed by health checker. Must be modified under health checker lock. + checkingHealth bool + // md is the Metadata to be sent with each request. + md metadata.MD + // tx contains the transaction id if the session has been prepared for write. + tx transactionID +} + +// isValid returns true if the session is still valid for use. +func (s *session) isValid() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.valid +} + +// isWritePrepared returns true if the session is prepared for write. +func (s *session) isWritePrepared() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.tx != nil +} + +// String implements fmt.Stringer for session. +func (s *session) String() string { + s.mu.Lock() + defer s.mu.Unlock() + return fmt.Sprintf("<id=%v, hcIdx=%v, idleList=%p, valid=%v, create=%v, nextcheck=%v>", + s.id, s.hcIndex, s.idleList, s.valid, s.createTime, s.nextCheck) +} + +// ping verifies if the session is still alive in Cloud Spanner. +func (s *session) ping() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + return runRetryable(ctx, func(ctx context.Context) error { + _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) // s.getID is safe even when s is invalid. + return err + }) +} + +// refreshIdle refreshes the session's session ID if it is in its home session pool's idle list +// and returns true if successful. +func (s *session) refreshIdle() bool { + s.mu.Lock() + validAndIdle := s.valid && s.idleList != nil + s.mu.Unlock() + if !validAndIdle { + // Optimization: return early if s is not valid or if s is not in idle list. + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var sid string + err := runRetryable(ctx, func(ctx context.Context) error { + session, e := s.client.CreateSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.CreateSessionRequest{Database: s.pool.db}) + if e != nil { + return e + } + sid = session.Name + return nil + }) + if err != nil { + return false + } + s.pool.mu.Lock() + s.mu.Lock() + var recycle bool + if s.valid && s.idleList != nil { + // session is in idle list, refresh its session id. + sid, s.id = s.id, sid + if s.tx != nil { + s.tx = nil + s.pool.idleWriteList.Remove(s.idleList) + // We need to put this session back into the pool. + recycle = true + } + } + s.mu.Unlock() + s.pool.mu.Unlock() + if recycle { + s.pool.recycle(s) + } + // If we fail to explicitly destroy the session, it will be eventually garbage collected by + // Cloud Spanner. + if err = runRetryable(ctx, func(ctx context.Context) error { + _, e := s.client.DeleteSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.DeleteSessionRequest{Name: sid}) + return e + }); err != nil && log.V(2) { + log.Warningf("Failed to delete session %v. Error: %v", sid, err) + } + return true +} + +// setHcIndex atomically sets the session's index in the healthcheck queue and returns the old index. +func (s *session) setHcIndex(i int) int { + s.mu.Lock() + defer s.mu.Unlock() + oi := s.hcIndex + s.hcIndex = i + return oi +} + +// setIdleList atomically sets the session's idle list link and returns the old link. +func (s *session) setIdleList(le *list.Element) *list.Element { + s.mu.Lock() + defer s.mu.Unlock() + old := s.idleList + s.idleList = le + return old +} + +// invalidate marks a session as invalid and returns the old validity. +func (s *session) invalidate() bool { + s.mu.Lock() + defer s.mu.Unlock() + ov := s.valid + s.valid = false + return ov +} + +// setNextCheck sets the timestamp for next healthcheck on the session. +func (s *session) setNextCheck(t time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + s.nextCheck = t +} + +// setTransactionID sets the transaction id in the session +func (s *session) setTransactionID(tx transactionID) { + s.mu.Lock() + defer s.mu.Unlock() + s.tx = tx +} + +// getID returns the session ID which uniquely identifies the session in Cloud Spanner. +func (s *session) getID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.id +} + +// getHcIndex returns the session's index into the global healthcheck priority queue. +func (s *session) getHcIndex() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.hcIndex +} + +// getIdleList returns the session's link in its home session pool's idle list. +func (s *session) getIdleList() *list.Element { + s.mu.Lock() + defer s.mu.Unlock() + return s.idleList +} + +// getNextCheck returns the timestamp for next healthcheck on the session. +func (s *session) getNextCheck() time.Time { + s.mu.Lock() + defer s.mu.Unlock() + return s.nextCheck +} + +// recycle turns the session back to its home session pool. +func (s *session) recycle() { + s.setTransactionID(nil) + if !s.pool.recycle(s) { + // s is rejected by its home session pool because it expired and the session pool is currently having enough number of open sessions. + s.destroy(false) + } +} + +// destroy removes the session from its home session pool, healthcheck queue and Cloud Spanner service. +func (s *session) destroy(isExpire bool) bool { + // Remove s from session pool. + if !s.pool.remove(s, isExpire) { + return false + } + // Unregister s from healthcheck queue. + s.pool.hc.unregister(s) + // Remove s from Cloud Spanner service. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + // Ignore the error returned by runRetryable because even if we fail to explicitly destroy the session, + // it will be eventually garbage collected by Cloud Spanner. + err := runRetryable(ctx, func(ctx context.Context) error { + _, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()}) + return e + }) + if err != nil && log.V(2) { + log.Warningf("Failed to delete session %v. Error: %v", s.getID(), err) + } + return true +} + +// prepareForWrite prepares the session for write if it is not already in that state. +func (s *session) prepareForWrite(ctx context.Context) error { + if s.isWritePrepared() { + return nil + } + tx, err := beginTransaction(ctx, s.getID(), s.client) + if err != nil { + return err + } + s.setTransactionID(tx) + return nil +} + +// SessionPoolConfig stores configurations of a session pool. +type SessionPoolConfig struct { + // getRPCClient is the caller supplied method for getting a gRPC client to Cloud Spanner, this makes session pool able to use client pooling. + getRPCClient func() (sppb.SpannerClient, error) + // MaxOpened is the maximum number of opened sessions that is allowed by the + // session pool. Default to NumChannels * 100. + MaxOpened uint64 + // MinOpened is the minimum number of opened sessions that the session pool + // tries to maintain. Session pool won't continue to expire sessions if number + // of opened connections drops below MinOpened. However, if session is found + // to be broken, it will still be evicted from session pool, therefore it is + // posssible that the number of opened sessions drops below MinOpened. + MinOpened uint64 + // MaxSessionAge is the maximum duration that a session can be reused, zero + // means session pool will never expire sessions. + MaxSessionAge time.Duration + // MaxBurst is the maximum number of concurrent session creation requests. Defaults to 10. + MaxBurst uint64 + // WriteSessions is the fraction of sessions we try to keep prepared for write. + WriteSessions float64 + // HealthCheckWorkers is number of workers used by health checker for this pool. + HealthCheckWorkers int + // HealthCheckInterval is how often the health checker pings a session. + HealthCheckInterval time.Duration +} + +// errNoRPCGetter returns error for SessionPoolConfig missing getRPCClient method. +func errNoRPCGetter() error { + return spannerErrorf(codes.InvalidArgument, "require SessionPoolConfig.getRPCClient != nil, got nil") +} + +// errMinOpenedGTMapOpened returns error for SessionPoolConfig.MaxOpened < SessionPoolConfig.MinOpened when SessionPoolConfig.MaxOpened is set. +func errMinOpenedGTMaxOpened(spc *SessionPoolConfig) error { + return spannerErrorf(codes.InvalidArgument, + "require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %v and %v", spc.MaxOpened, spc.MinOpened) +} + +// validate verifies that the SessionPoolConfig is good for use. +func (spc *SessionPoolConfig) validate() error { + if spc.getRPCClient == nil { + return errNoRPCGetter() + } + if spc.MinOpened > spc.MaxOpened && spc.MaxOpened > 0 { + return errMinOpenedGTMaxOpened(spc) + } + return nil +} + +// sessionPool creates and caches Cloud Spanner sessions. +type sessionPool struct { + // mu protects sessionPool from concurrent access. + mu sync.Mutex + // valid marks the validity of the session pool. + valid bool + // db is the database name that all sessions in the pool are associated with. + db string + // idleList caches idle session IDs. Session IDs in this list can be allocated for use. + idleList list.List + // idleWriteList caches idle sessions which have been prepared for write. + idleWriteList list.List + // mayGetSession is for broadcasting that session retrival/creation may proceed. + mayGetSession chan struct{} + // numOpened is the total number of open sessions from the session pool. + numOpened uint64 + // createReqs is the number of ongoing session creation requests. + createReqs uint64 + // prepareReqs is the number of ongoing session preparation request. + prepareReqs uint64 + // configuration of the session pool. + SessionPoolConfig + // Metadata to be sent with each request + md metadata.MD + // hc is the health checker + hc *healthChecker +} + +// newSessionPool creates a new session pool. +func newSessionPool(db string, config SessionPoolConfig, md metadata.MD) (*sessionPool, error) { + if err := config.validate(); err != nil { + return nil, err + } + pool := &sessionPool{ + db: db, + valid: true, + mayGetSession: make(chan struct{}), + SessionPoolConfig: config, + md: md, + } + if config.HealthCheckWorkers == 0 { + // With 10 workers and assuming average latency of 5 ms for BeginTransaction, we will be able to + // prepare 2000 tx/sec in advance. If the rate of takeWriteSession is more than that, it will + // degrade to doing BeginTransaction inline. + // TODO: consider resizing the worker pool dynamically according to the load. + config.HealthCheckWorkers = 10 + } + if config.HealthCheckInterval == 0 { + config.HealthCheckInterval = 5 * time.Minute + } + // On GCE VM, within the same region an healthcheck ping takes on average 10ms to finish, given a 5 minutes interval and + // 10 healthcheck workers, a healthChecker can effectively mantain 100 checks_per_worker/sec * 10 workers * 300 seconds = 300K sessions. + pool.hc = newHealthChecker(config.HealthCheckInterval, config.HealthCheckWorkers, pool) + return pool, nil +} + +// isValid checks if the session pool is still valid. +func (p *sessionPool) isValid() bool { + if p == nil { + return false + } + p.mu.Lock() + defer p.mu.Unlock() + return p.valid +} + +// close marks the session pool as closed. +func (p *sessionPool) close() { + if p == nil { + return + } + p.mu.Lock() + if !p.valid { + p.mu.Unlock() + return + } + p.valid = false + p.mu.Unlock() + p.hc.close() + // destroy all the sessions + p.hc.mu.Lock() + allSessions := make([]*session, len(p.hc.queue.sessions)) + copy(allSessions, p.hc.queue.sessions) + p.hc.mu.Unlock() + for _, s := range allSessions { + s.destroy(false) + } +} + +// errInvalidSessionPool returns error for using an invalid session pool. +func errInvalidSessionPool() error { + return spannerErrorf(codes.InvalidArgument, "invalid session pool") +} + +// errGetSessionTimeout returns error for context timeout during sessionPool.take(). +func errGetSessionTimeout() error { + return spannerErrorf(codes.Canceled, "timeout / context canceled during getting session") +} + +// shouldPrepareWrite returns true if we should prepare more sessions for write. +func (p *sessionPool) shouldPrepareWrite() bool { + return float64(p.numOpened)*p.WriteSessions > float64(p.idleWriteList.Len()+int(p.prepareReqs)) +} + +func (p *sessionPool) createSession(ctx context.Context) (*session, error) { + doneCreate := func(done bool) { + p.mu.Lock() + if !done { + // Session creation failed, give budget back. + p.numOpened-- + } + p.createReqs-- + // Notify other waiters blocking on session creation. + close(p.mayGetSession) + p.mayGetSession = make(chan struct{}) + p.mu.Unlock() + } + sc, err := p.getRPCClient() + if err != nil { + doneCreate(false) + return nil, err + } + var s *session + err = runRetryable(ctx, func(ctx context.Context) error { + sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{Database: p.db}) + if e != nil { + return e + } + // If no error, construct the new session. + s = &session{valid: true, client: sc, id: sid.Name, pool: p, createTime: time.Now(), md: p.md} + p.hc.register(s) + return nil + }) + if err != nil { + doneCreate(false) + // Should return error directly because of the previous retries on CreateSession RPC. + return nil, err + } + doneCreate(true) + return s, nil +} + +func (p *sessionPool) isHealthy(s *session) bool { + if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) { + // TODO: figure out if we need to schedule a new healthcheck worker here. + if err := s.ping(); shouldDropSession(err) { + // The session is already bad, continue to fetch/create a new one. + s.destroy(false) + return false + } + p.hc.scheduledHC(s) + } + return true +} + +// take returns a cached session if there are available ones; if there isn't any, it tries to allocate a new one. +// Session returned by take should be used for read operations. +func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) { + ctx = contextWithOutgoingMetadata(ctx, p.md) + for { + var ( + s *session + err error + ) + + p.mu.Lock() + if !p.valid { + p.mu.Unlock() + return nil, errInvalidSessionPool() + } + if p.idleList.Len() > 0 { + // Idle sessions are available, get one from the top of the idle list. + s = p.idleList.Remove(p.idleList.Front()).(*session) + } else if p.idleWriteList.Len() > 0 { + s = p.idleWriteList.Remove(p.idleWriteList.Front()).(*session) + } + if s != nil { + s.setIdleList(nil) + p.mu.Unlock() + // From here, session is no longer in idle list, so healthcheck workers won't destroy it. + // If healthcheck workers failed to schedule healthcheck for the session timely, do the check here. + // Because session check is still much cheaper than session creation, they should be reused as much as possible. + if !p.isHealthy(s) { + continue + } + return &sessionHandle{session: s}, nil + } + // Idle list is empty, block if session pool has reached max session creation concurrency or max number of open sessions. + if (p.MaxOpened > 0 && p.numOpened >= p.MaxOpened) || (p.MaxBurst > 0 && p.createReqs >= p.MaxBurst) { + mayGetSession := p.mayGetSession + p.mu.Unlock() + select { + case <-ctx.Done(): + return nil, errGetSessionTimeout() + case <-mayGetSession: + } + continue + } + // Take budget before the actual session creation. + p.numOpened++ + p.createReqs++ + p.mu.Unlock() + if s, err = p.createSession(ctx); err != nil { + return nil, toSpannerError(err) + } + return &sessionHandle{session: s}, nil + } +} + +// takeWriteSession returns a write prepared cached session if there are available ones; if there isn't any, it tries to allocate a new one. +// Session returned should be used for read write transactions. +func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, error) { + ctx = contextWithOutgoingMetadata(ctx, p.md) + for { + var ( + s *session + err error + ) + + p.mu.Lock() + if !p.valid { + p.mu.Unlock() + return nil, errInvalidSessionPool() + } + if p.idleWriteList.Len() > 0 { + // Idle sessions are available, get one from the top of the idle list. + s = p.idleWriteList.Remove(p.idleWriteList.Front()).(*session) + } else if p.idleList.Len() > 0 { + s = p.idleList.Remove(p.idleList.Front()).(*session) + } + if s != nil { + s.setIdleList(nil) + p.mu.Unlock() + // From here, session is no longer in idle list, so healthcheck workers won't destroy it. + // If healthcheck workers failed to schedule healthcheck for the session timely, do the check here. + // Because session check is still much cheaper than session creation, they should be reused as much as possible. + if !p.isHealthy(s) { + continue + } + if !s.isWritePrepared() { + if err = s.prepareForWrite(ctx); err != nil { + return nil, toSpannerError(err) + } + } + return &sessionHandle{session: s}, nil + } + // Idle list is empty, block if session pool has reached max session creation concurrency or max number of open sessions. + if (p.MaxOpened > 0 && p.numOpened >= p.MaxOpened) || (p.MaxBurst > 0 && p.createReqs >= p.MaxBurst) { + mayGetSession := p.mayGetSession + p.mu.Unlock() + select { + case <-ctx.Done(): + return nil, errGetSessionTimeout() + case <-mayGetSession: + } + continue + } + + // Take budget before the actual session creation. + p.numOpened++ + p.createReqs++ + p.mu.Unlock() + if s, err = p.createSession(ctx); err != nil { + return nil, toSpannerError(err) + } + if err = s.prepareForWrite(ctx); err != nil { + return nil, toSpannerError(err) + } + return &sessionHandle{session: s}, nil + } +} + +// recycle puts session s back to the session pool's idle list, it returns true if the session pool successfully recycles session s. +func (p *sessionPool) recycle(s *session) bool { + p.mu.Lock() + defer p.mu.Unlock() + if !s.isValid() || !p.valid { + // Reject the session if session is invalid or pool itself is invalid. + return false + } + if p.MaxSessionAge != 0 && s.createTime.Add(p.MaxSessionAge).Before(time.Now()) && p.numOpened > p.MinOpened { + // session expires and number of opened sessions exceeds MinOpened, let the session destroy itself. + return false + } + // Hot sessions will be converging at the front of the list, cold sessions will be evicted by healthcheck workers. + if s.isWritePrepared() { + s.setIdleList(p.idleWriteList.PushFront(s)) + } else { + s.setIdleList(p.idleList.PushFront(s)) + } + // Broadcast that a session has been returned to idle list. + close(p.mayGetSession) + p.mayGetSession = make(chan struct{}) + return true +} + +// remove atomically removes session s from the session pool and invalidates s. +// If isExpire == true, the removal is triggered by session expiration and in such cases, only idle sessions can be removed. +func (p *sessionPool) remove(s *session, isExpire bool) bool { + p.mu.Lock() + defer p.mu.Unlock() + if isExpire && (p.numOpened <= p.MinOpened || s.getIdleList() == nil) { + // Don't expire session if the session is not in idle list (in use), or if number of open sessions is going below p.MinOpened. + return false + } + ol := s.setIdleList(nil) + // If the session is in the idlelist, remove it. + if ol != nil { + // Remove from whichever list it is in. + p.idleList.Remove(ol) + p.idleWriteList.Remove(ol) + } + if s.invalidate() { + // Decrease the number of opened sessions. + p.numOpened-- + // Broadcast that a session has been destroyed. + close(p.mayGetSession) + p.mayGetSession = make(chan struct{}) + return true + } + return false +} + +// hcHeap implements heap.Interface. It is used to create the priority queue for session healthchecks. +type hcHeap struct { + sessions []*session +} + +// Len impelemnts heap.Interface.Len. +func (h hcHeap) Len() int { + return len(h.sessions) +} + +// Less implements heap.Interface.Less. +func (h hcHeap) Less(i, j int) bool { + return h.sessions[i].getNextCheck().Before(h.sessions[j].getNextCheck()) +} + +// Swap implements heap.Interface.Swap. +func (h hcHeap) Swap(i, j int) { + h.sessions[i], h.sessions[j] = h.sessions[j], h.sessions[i] + h.sessions[i].setHcIndex(i) + h.sessions[j].setHcIndex(j) +} + +// Push implements heap.Interface.Push. +func (h *hcHeap) Push(s interface{}) { + ns := s.(*session) + ns.setHcIndex(len(h.sessions)) + h.sessions = append(h.sessions, ns) +} + +// Pop implements heap.Interface.Pop. +func (h *hcHeap) Pop() interface{} { + old := h.sessions + n := len(old) + s := old[n-1] + h.sessions = old[:n-1] + s.setHcIndex(-1) + return s +} + +// healthChecker performs periodical healthchecks on registered sessions. +type healthChecker struct { + // mu protects concurrent access to hcQueue. + mu sync.Mutex + // queue is the priority queue for session healthchecks. Sessions with lower nextCheck rank higher in the queue. + queue hcHeap + // interval is the average interval between two healthchecks on a session. + interval time.Duration + // workers is the number of concurrent healthcheck workers. + workers int + // waitWorkers waits for all healthcheck workers to exit + waitWorkers sync.WaitGroup + // pool is the underlying session pool. + pool *sessionPool + // closed marks if a healthChecker has been closed. + closed bool +} + +// newHealthChecker initializes new instance of healthChecker. +func newHealthChecker(interval time.Duration, workers int, pool *sessionPool) *healthChecker { + if workers <= 0 { + workers = 1 + } + hc := &healthChecker{ + interval: interval, + workers: workers, + pool: pool, + } + for i := 0; i < hc.workers; i++ { + hc.waitWorkers.Add(1) + go hc.worker(i) + } + return hc +} + +// close closes the healthChecker and waits for all healthcheck workers to exit. +func (hc *healthChecker) close() { + hc.mu.Lock() + hc.closed = true + hc.mu.Unlock() + hc.waitWorkers.Wait() +} + +// isClosing checks if a healthChecker is already closing. +func (hc *healthChecker) isClosing() bool { + hc.mu.Lock() + defer hc.mu.Unlock() + return hc.closed +} + +// getInterval gets the healthcheck interval. +func (hc *healthChecker) getInterval() time.Duration { + hc.mu.Lock() + defer hc.mu.Unlock() + return hc.interval +} + +// scheduledHCLocked schedules next healthcheck on session s with the assumption that hc.mu is being held. +func (hc *healthChecker) scheduledHCLocked(s *session) { + // The next healthcheck will be scheduled after [interval*0.5, interval*1.5) nanoseconds. + nsFromNow := rand.Int63n(int64(hc.interval)) + int64(hc.interval)/2 + s.setNextCheck(time.Now().Add(time.Duration(nsFromNow))) + if hi := s.getHcIndex(); hi != -1 { + // Session is still being tracked by healthcheck workers. + heap.Fix(&hc.queue, hi) + } +} + +// scheduledHC schedules next healthcheck on session s. It is safe to be called concurrently. +func (hc *healthChecker) scheduledHC(s *session) { + hc.mu.Lock() + defer hc.mu.Unlock() + hc.scheduledHCLocked(s) +} + +// register registers a session with healthChecker for periodical healthcheck. +func (hc *healthChecker) register(s *session) { + hc.mu.Lock() + defer hc.mu.Unlock() + hc.scheduledHCLocked(s) + heap.Push(&hc.queue, s) +} + +// unregister unregisters a session from healthcheck queue. +func (hc *healthChecker) unregister(s *session) { + hc.mu.Lock() + defer hc.mu.Unlock() + oi := s.setHcIndex(-1) + if oi >= 0 { + heap.Remove(&hc.queue, oi) + } +} + +// markDone marks that health check for session has been performed. +func (hc *healthChecker) markDone(s *session) { + hc.mu.Lock() + defer hc.mu.Unlock() + s.checkingHealth = false +} + +// healthCheck checks the health of the session and pings it if needed. +func (hc *healthChecker) healthCheck(s *session) { + defer hc.markDone(s) + if !s.pool.isValid() { + // Session pool is closed, perform a garbage collection. + s.destroy(false) + return + } + if s.pool.MaxSessionAge != 0 && s.createTime.Add(s.pool.MaxSessionAge).Before(time.Now()) { + // Session reaches its maximum age, retire it. Failing that try to refresh it. + if s.destroy(true) || !s.refreshIdle() { + return + } + } + if err := s.ping(); shouldDropSession(err) { + // Ping failed, destroy the session. + s.destroy(false) + } +} + +// worker performs the healthcheck on sessions in healthChecker's priority queue. +func (hc *healthChecker) worker(i int) { + if log.V(2) { + log.Infof("Starting health check worker %v", i) + } + // Returns a session which we should ping to keep it alive. + getNextForPing := func() *session { + hc.pool.mu.Lock() + defer hc.pool.mu.Unlock() + hc.mu.Lock() + defer hc.mu.Unlock() + if hc.queue.Len() <= 0 { + // Queue is empty. + return nil + } + s := hc.queue.sessions[0] + if s.getNextCheck().After(time.Now()) && hc.pool.valid { + // All sessions have been checked recently. + return nil + } + hc.scheduledHCLocked(s) + if !s.checkingHealth { + s.checkingHealth = true + return s + } + return nil + } + + // Returns a session which we should prepare for write. + getNextForTx := func() *session { + hc.pool.mu.Lock() + defer hc.pool.mu.Unlock() + if hc.pool.shouldPrepareWrite() { + if hc.pool.idleList.Len() > 0 && hc.pool.valid { + hc.mu.Lock() + defer hc.mu.Unlock() + if hc.pool.idleList.Front().Value.(*session).checkingHealth { + return nil + } + session := hc.pool.idleList.Remove(hc.pool.idleList.Front()).(*session) + session.checkingHealth = true + hc.pool.prepareReqs++ + return session + } + } + return nil + } + + for { + if hc.isClosing() { + if log.V(2) { + log.Infof("Closing health check worker %v", i) + } + // Exit when the pool has been closed and all sessions have been destroyed + // or when health checker has been closed. + hc.waitWorkers.Done() + return + } + ws := getNextForTx() + if ws != nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ws.prepareForWrite(contextWithOutgoingMetadata(ctx, hc.pool.md)) + hc.pool.recycle(ws) + hc.pool.mu.Lock() + hc.pool.prepareReqs-- + hc.pool.mu.Unlock() + hc.markDone(ws) + } + rs := getNextForPing() + if rs == nil { + if ws == nil { + // No work to be done so sleep to avoid burning cpu + pause := int64(100 * time.Millisecond) + if pause > int64(hc.interval) { + pause = int64(hc.interval) + } + <-time.After(time.Duration(rand.Int63n(pause) + pause/2)) + } + continue + } + hc.healthCheck(rs) + } +} + +// shouldDropSession returns true if a particular error leads to the removal of a session +func shouldDropSession(err error) bool { + if err == nil { + return false + } + // If a Cloud Spanner can no longer locate the session (for example, if session is garbage collected), then caller + // should not try to return the session back into the session pool. + // TODO: once gRPC can return auxilary error information, stop parsing the error message. + if ErrCode(err) == codes.NotFound && strings.Contains(ErrDesc(err), "Session not found:") { + return true + } + return false +} diff --git a/vendor/cloud.google.com/go/spanner/session_test.go b/vendor/cloud.google.com/go/spanner/session_test.go new file mode 100644 index 000000000..7c3d4f88e --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/session_test.go @@ -0,0 +1,792 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "container/heap" + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + + "cloud.google.com/go/spanner/internal/testutil" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +// setup prepares test environment for regular session pool tests. +func setup(t *testing.T, spc SessionPoolConfig) (sp *sessionPool, sc *testutil.MockCloudSpannerClient, cancel func()) { + sc = testutil.NewMockCloudSpannerClient(t) + spc.getRPCClient = func() (sppb.SpannerClient, error) { + return sc, nil + } + spc.HealthCheckInterval = 50 * time.Millisecond + sp, err := newSessionPool("mockdb", spc, nil) + if err != nil { + t.Fatalf("cannot create session pool: %v", err) + } + cancel = func() { + sp.close() + } + return +} + +// TestSessionCreation tests session creation during sessionPool.Take(). +func TestSessionCreation(t *testing.T) { + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + // Take three sessions from session pool, this should trigger session pool to create three new sessions. + shs := make([]*sessionHandle, 3) + // gotDs holds the unique sessions taken from session pool. + gotDs := map[string]bool{} + for i := 0; i < len(shs); i++ { + var err error + shs[i], err = sp.take(context.Background()) + if err != nil { + t.Errorf("failed to get session(%v): %v", i, err) + } + gotDs[shs[i].getID()] = true + } + if len(gotDs) != len(shs) { + t.Errorf("session pool created %v sessions, want %v", len(gotDs), len(shs)) + } + if wantDs := sc.DumpSessions(); !reflect.DeepEqual(gotDs, wantDs) { + t.Errorf("session pool creates sessions %v, want %v", gotDs, wantDs) + } + // Verify that created sessions are recorded correctly in session pool. + sp.mu.Lock() + if int(sp.numOpened) != len(shs) { + t.Errorf("session pool reports %v open sessions, want %v", sp.numOpened, len(shs)) + } + if sp.createReqs != 0 { + t.Errorf("session pool reports %v session create requests, want 0", int(sp.createReqs)) + } + sp.mu.Unlock() + // Verify that created sessions are tracked correctly by healthcheck queue. + hc := sp.hc + hc.mu.Lock() + if hc.queue.Len() != len(shs) { + t.Errorf("healthcheck queue length = %v, want %v", hc.queue.Len(), len(shs)) + } + for _, s := range hc.queue.sessions { + if !gotDs[s.getID()] { + t.Errorf("session %v is in healthcheck queue, but it is not created by session pool", s.getID()) + } + } + hc.mu.Unlock() +} + +// TestTakeFromIdleList tests taking sessions from session pool's idle list. +func TestTakeFromIdleList(t *testing.T) { + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + // Take ten sessions from session pool and recycle them. + shs := make([]*sessionHandle, 10) + for i := 0; i < len(shs); i++ { + var err error + shs[i], err = sp.take(context.Background()) + if err != nil { + t.Errorf("failed to get session(%v): %v", i, err) + } + } + for i := 0; i < len(shs); i++ { + shs[i].recycle() + } + // Further session requests from session pool won't cause mockclient to create more sessions. + wantSessions := sc.DumpSessions() + // Take ten sessions from session pool again, this time all sessions should come from idle list. + gotSessions := map[string]bool{} + for i := 0; i < len(shs); i++ { + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot take session from session pool: %v", err) + } + gotSessions[sh.getID()] = true + } + if len(gotSessions) != 10 { + t.Errorf("got %v unique sessions, want 10", len(gotSessions)) + } + if !reflect.DeepEqual(gotSessions, wantSessions) { + t.Errorf("got sessions: %v, want %v", gotSessions, wantSessions) + } +} + +// TesttakeWriteSessionFromIdleList tests taking write sessions from session pool's idle list. +func TestTakeWriteSessionFromIdleList(t *testing.T) { + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + act := testutil.NewAction("Begin", nil) + acts := make([]testutil.Action, 20) + for i := 0; i < len(acts); i++ { + acts[i] = act + } + sc.SetActions(acts...) + // Take ten sessions from session pool and recycle them. + shs := make([]*sessionHandle, 10) + for i := 0; i < len(shs); i++ { + var err error + shs[i], err = sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("failed to get session(%v): %v", i, err) + } + } + for i := 0; i < len(shs); i++ { + shs[i].recycle() + } + // Further session requests from session pool won't cause mockclient to create more sessions. + wantSessions := sc.DumpSessions() + // Take ten sessions from session pool again, this time all sessions should come from idle list. + gotSessions := map[string]bool{} + for i := 0; i < len(shs); i++ { + sh, err := sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("cannot take session from session pool: %v", err) + } + gotSessions[sh.getID()] = true + } + if len(gotSessions) != 10 { + t.Errorf("got %v unique sessions, want 10", len(gotSessions)) + } + if !reflect.DeepEqual(gotSessions, wantSessions) { + t.Errorf("got sessions: %v, want %v", gotSessions, wantSessions) + } +} + +// TestTakeFromIdleListChecked tests taking sessions from session pool's idle list, but with a extra ping check. +func TestTakeFromIdleListChecked(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + // Stop healthcheck workers to simulate slow pings. + sp.hc.close() + // Create a session and recycle it. + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("failed to get session: %v", err) + } + wantSid := sh.getID() + sh.recycle() + <-time.After(time.Second) + // Two back-to-back session requests, both of them should return the same session created before and + // none of them should trigger a session ping. + for i := 0; i < 2; i++ { + // Take the session from the idle list and recycle it. + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("%v - failed to get session: %v", i, err) + } + if gotSid := sh.getID(); gotSid != wantSid { + t.Errorf("%v - got session id: %v, want %v", i, gotSid, wantSid) + } + // The two back-to-back session requests shouldn't trigger any session pings because sessionPool.Take + // reschedules the next healthcheck. + if got, want := sc.DumpPings(), ([]string{wantSid}); !reflect.DeepEqual(got, want) { + t.Errorf("%v - got ping session requests: %v, want %v", i, got, want) + } + sh.recycle() + } + // Inject session error to mockclient, and take the session from the session pool, the old session should be destroyed and + // the session pool will create a new session. + sc.InjectError("GetSession", grpc.Errorf(codes.NotFound, "Session not found:")) + // Delay to trigger sessionPool.Take to ping the session. + <-time.After(time.Second) + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("failed to get session: %v", err) + } + ds := sc.DumpSessions() + if len(ds) != 1 { + t.Errorf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) + } + if sh.getID() == wantSid { + t.Errorf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid) + } +} + +// TestTakeFromIdleWriteListChecked tests taking sessions from session pool's idle list, but with a extra ping check. +func TestTakeFromIdleWriteListChecked(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + sc.MakeNice() + // Stop healthcheck workers to simulate slow pings. + sp.hc.close() + // Create a session and recycle it. + sh, err := sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("failed to get session: %v", err) + } + wantSid := sh.getID() + sh.recycle() + <-time.After(time.Second) + // Two back-to-back session requests, both of them should return the same session created before and + // none of them should trigger a session ping. + for i := 0; i < 2; i++ { + // Take the session from the idle list and recycle it. + sh, err = sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("%v - failed to get session: %v", i, err) + } + if gotSid := sh.getID(); gotSid != wantSid { + t.Errorf("%v - got session id: %v, want %v", i, gotSid, wantSid) + } + // The two back-to-back session requests shouldn't trigger any session pings because sessionPool.Take + // reschedules the next healthcheck. + if got, want := sc.DumpPings(), ([]string{wantSid}); !reflect.DeepEqual(got, want) { + t.Errorf("%v - got ping session requests: %v, want %v", i, got, want) + } + sh.recycle() + } + // Inject session error to mockclient, and take the session from the session pool, the old session should be destroyed and + // the session pool will create a new session. + sc.InjectError("GetSession", grpc.Errorf(codes.NotFound, "Session not found:")) + // Delay to trigger sessionPool.Take to ping the session. + <-time.After(time.Second) + sh, err = sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("failed to get session: %v", err) + } + ds := sc.DumpSessions() + if len(ds) != 1 { + t.Errorf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) + } + if sh.getID() == wantSid { + t.Errorf("sessionPool.Take still returns the same session %v, want it to create a new one", wantSid) + } +} + +// TestMaxOpenedSessions tests max open sessions constraint. +func TestMaxOpenedSessions(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, _, cancel := setup(t, SessionPoolConfig{MaxOpened: 1}) + defer cancel() + sh1, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot take session from session pool: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // Session request will timeout due to the max open sessions constraint. + sh2, gotErr := sp.take(ctx) + if wantErr := errGetSessionTimeout(); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("the second session retrival returns error %v, want %v", gotErr, wantErr) + } + go func() { + <-time.After(time.Second) + // destroy the first session to allow the next session request to proceed. + sh1.destroy() + }() + // Now session request can be processed because the first session will be destroyed. + sh2, err = sp.take(context.Background()) + if err != nil { + t.Errorf("after the first session is destroyed, session retrival still returns error %v, want nil", err) + } + if !sh2.session.isValid() || sh2.getID() == "" { + t.Errorf("got invalid session: %v", sh2.session) + } +} + +// TestMinOpenedSessions tests min open session constraint. +func TestMinOpenedSessions(t *testing.T) { + sp, _, cancel := setup(t, SessionPoolConfig{MinOpened: 1}) + defer cancel() + // Take ten sessions from session pool and recycle them. + var ss []*session + var shs []*sessionHandle + for i := 0; i < 10; i++ { + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("failed to get session(%v): %v", i, err) + } + ss = append(ss, sh.session) + shs = append(shs, sh) + sh.recycle() + } + for _, sh := range shs { + sh.recycle() + } + // Simulate session expiration. + for _, s := range ss { + s.destroy(true) + } + sp.mu.Lock() + defer sp.mu.Unlock() + // There should be still one session left in idle list due to the min open sessions constraint. + if sp.idleList.Len() != 1 { + t.Errorf("got %v sessions in idle list, want 1", sp.idleList.Len()) + } +} + +// TestMaxBurst tests max burst constraint. +func TestMaxBurst(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{MaxBurst: 1}) + defer cancel() + // Will cause session creation RPC to be retried forever. + sc.InjectError("CreateSession", grpc.Errorf(codes.Unavailable, "try later")) + // This session request will never finish until the injected error is cleared. + go sp.take(context.Background()) + // Poll for the execution of the first session request. + for { + sp.mu.Lock() + cr := sp.createReqs + sp.mu.Unlock() + if cr == 0 { + <-time.After(time.Second) + continue + } + // The first session request is being executed. + break + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + sh, gotErr := sp.take(ctx) + // Since MaxBurst == 1, the second session request should block. + if wantErr := errGetSessionTimeout(); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("session retrival returns error %v, want %v", gotErr, wantErr) + } + // Let the first session request succeed. + sc.InjectError("CreateSession", nil) + // Now new session request can proceed because the first session request will eventually succeed. + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("session retrival returns error %v, want nil", err) + } + if !sh.session.isValid() || sh.getID() == "" { + t.Errorf("got invalid session: %v", sh.session) + } +} + +// TestSessionrecycle tests recycling sessions. +func TestSessionRecycle(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, _, cancel := setup(t, SessionPoolConfig{MaxSessionAge: 100 * time.Millisecond, MinOpened: 1}) + // Healthcheck is explicitly turned off in this test because it might aggressively expire sessions in idle list. + sp.hc.close() + defer cancel() + var ss []*session + shs := make([]*sessionHandle, 2) + for i := 0; i < len(shs); i++ { + var err error + shs[i], err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get the session %v: %v", i, err) + } + ss = append(ss, shs[i].session) + } + // recycle the first session immediately. + shs[0].recycle() + // Let the second session expire. + <-time.After(time.Second) + // recycle the second session. + shs[1].recycle() + // Now the first session should be still valid, but the second session should have been destroyed. + if !ss[0].isValid() { + t.Errorf("the first session (%v) is invalid, want it to be valid", ss[0]) + } + if ss[1].isValid() { + t.Errorf("the second session (%v) is valid, want it to be invalid", ss[1]) + } +} + +// TestSessionDestroy tests destroying sessions. +func TestSessionDestroy(t *testing.T) { + sp, _, cancel := setup(t, SessionPoolConfig{MinOpened: 1}) + defer cancel() + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + s := sh.session + sh.recycle() + if d := s.destroy(true); d || !s.isValid() { + // Session should be remaining because of min open sessions constraint. + t.Errorf("session %v was destroyed in expiration mode, want it to stay alive", s) + } + if d := s.destroy(false); !d || s.isValid() { + // Session should be destroyed. + t.Errorf("failed to destroy session %s", s) + } +} + +// TestHcHeap tests heap operation on top of hcHeap. +func TestHcHeap(t *testing.T) { + in := []*session{ + &session{nextCheck: time.Unix(10, 0)}, + &session{nextCheck: time.Unix(0, 5)}, + &session{nextCheck: time.Unix(1, 8)}, + &session{nextCheck: time.Unix(11, 7)}, + &session{nextCheck: time.Unix(6, 3)}, + } + want := []*session{ + &session{nextCheck: time.Unix(1, 8), hcIndex: 0}, + &session{nextCheck: time.Unix(6, 3), hcIndex: 1}, + &session{nextCheck: time.Unix(8, 2), hcIndex: 2}, + &session{nextCheck: time.Unix(10, 0), hcIndex: 3}, + &session{nextCheck: time.Unix(11, 7), hcIndex: 4}, + } + hh := hcHeap{} + for _, s := range in { + heap.Push(&hh, s) + } + // Change top of the heap and do a adjustment. + hh.sessions[0].nextCheck = time.Unix(8, 2) + heap.Fix(&hh, 0) + for idx := 0; hh.Len() > 0; idx++ { + got := heap.Pop(&hh).(*session) + want[idx].hcIndex = -1 + if !reflect.DeepEqual(got, want[idx]) { + t.Errorf("%v: heap.Pop returns %v, want %v", idx, got, want[idx]) + } + } +} + +// TestHealthCheckScheduler tests if healthcheck workers can schedule and perform healthchecks properly. +func TestHealthCheckScheduler(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{}) + defer cancel() + // Create 50 sessions. + ss := []string{} + for i := 0; i < 50; i++ { + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + ss = append(ss, sh.getID()) + } + // Sleep for 1s, allowing healthcheck workers to perform some session pings. + <-time.After(time.Second) + dp := sc.DumpPings() + gotPings := map[string]int64{} + for _, p := range dp { + gotPings[p]++ + } + for _, s := range ss { + // The average ping interval is 50ms. + want := int64(time.Second) / int64(50*time.Millisecond) + if got := gotPings[s]; got < want/2 || got > want+want/2 { + t.Errorf("got %v healthchecks on session %v, want it between (%v, %v)", got, s, want/2, want+want/2) + } + } +} + +// Tests that a fractions of sessions are prepared for write by health checker. +func TestWriteSessionsPrepared(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{WriteSessions: 0.5}) + sc.MakeNice() + defer cancel() + shs := make([]*sessionHandle, 10) + var err error + for i := 0; i < 10; i++ { + shs[i], err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + } + // Now there are 10 sessions in the pool. Release them. + for _, sh := range shs { + sh.recycle() + } + // Sleep for 1s, allowing healthcheck workers to invoke begin transaction. + <-time.After(time.Second) + wshs := make([]*sessionHandle, 5) + for i := 0; i < 5; i++ { + wshs[i], err = sp.takeWriteSession(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + if wshs[i].getTransactionID() == nil { + t.Errorf("got nil transaction id from session pool") + } + } + for _, sh := range wshs { + sh.recycle() + } + <-time.After(time.Second) + // Now force creation of 10 more sessions. + shs = make([]*sessionHandle, 20) + for i := 0; i < 20; i++ { + shs[i], err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + } + // Now there are 20 sessions in the pool. Release them. + for _, sh := range shs { + sh.recycle() + } + <-time.After(time.Second) + if sp.idleWriteList.Len() != 10 { + t.Errorf("Expect 10 write prepared session, got: %d", sp.idleWriteList.Len()) + } +} + +// TestTakeFromWriteQueue tests that sessionPool.take() returns write prepared sessions as well. +func TestTakeFromWriteQueue(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{MaxOpened: 1, WriteSessions: 1.0}) + sc.MakeNice() + defer cancel() + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + sh.recycle() + <-time.After(time.Second) + // The session should now be in write queue but take should also return it. + if sp.idleWriteList.Len() == 0 { + t.Errorf("write queue unexpectedly empty") + } + if sp.idleList.Len() != 0 { + t.Errorf("read queue not empty") + } + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + sh.recycle() +} + +// TestSessionHealthCheck tests healthchecking cases. +func TestSessionHealthCheck(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + sp, sc, cancel := setup(t, SessionPoolConfig{MaxSessionAge: 2 * time.Second}) + defer cancel() + // Test pinging sessions. + sh, err := sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + <-time.After(time.Second) + pings := sc.DumpPings() + if len(pings) == 0 || pings[0] != sh.getID() { + t.Errorf("healthchecker didn't send any ping to session %v", sh.getID()) + } + // Test expiring sessions. + s := sh.session + sh.recycle() + // Sleep enough long for session in idle list to expire. + <-time.After(2 * time.Second) + if s.isValid() { + t.Errorf("session(%v) is still alive, want it to expire", s) + } + // Test broken session detection. + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + sc.InjectError("GetSession", grpc.Errorf(codes.NotFound, "Session not found:")) + // Wait for healthcheck workers to find the broken session and tear it down. + <-time.After(1 * time.Second) + if sh.session.isValid() { + t.Errorf("session(%v) is still alive, want it to be dropped by healthcheck workers", s) + } + sc.InjectError("GetSession", nil) + // Test garbage collection. + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + sp.close() + if sh.session.isValid() { + t.Errorf("session(%v) is still alive, want it to be garbage collected", s) + } + // Test session id refresh. + // Recreate the session pool with min open sessions constraint. + sp, err = newSessionPool("mockdb", SessionPoolConfig{ + MaxSessionAge: time.Second, + MinOpened: 1, + getRPCClient: func() (sppb.SpannerClient, error) { + return sc, nil + }, + HealthCheckInterval: 50 * time.Millisecond, + }, nil) + sh, err = sp.take(context.Background()) + if err != nil { + t.Errorf("cannot get session from session pool: %v", err) + } + oid := sh.getID() + s = sh.session + sh.recycle() + <-time.After(2 * time.Second) + nid := s.getID() + if nid == "" || nid == oid { + t.Errorf("healthcheck workers failed to refresh session: oid=%v, nid=%v", oid, nid) + } + if gotDs, wantDs := sc.DumpSessions(), (map[string]bool{nid: true}); !reflect.DeepEqual(gotDs, wantDs) { + t.Errorf("sessions in mockclient: %v, want %v", gotDs, wantDs) + } +} + +// TestStressSessionPool does stress test on session pool by the following concurrent operations: +// 1) Test worker gets a session from the pool. +// 2) Test worker turns a session back into the pool. +// 3) Test worker destroys a session got from the pool. +// 4) Healthcheck retires an old session from the pool's idlelist by refreshing its session id. +// 5) Healthcheck destroys a broken session (because a worker has already destroyed it). +// 6) Test worker closes the session pool. +// +// During the test, it is expected that all sessions that are taken from session pool remains valid and +// when all test workers and healthcheck workers exit, mockclient, session pool and healthchecker should be in consistent state. +func TestStressSessionPool(t *testing.T) { + // Use concurrent workers to test different session pool built from different configurations. + if testing.Short() { + t.SkipNow() + } + for ti, cfg := range []SessionPoolConfig{ + SessionPoolConfig{}, + SessionPoolConfig{MaxSessionAge: 20 * time.Millisecond}, + SessionPoolConfig{MinOpened: 10, MaxOpened: 100}, + SessionPoolConfig{MaxBurst: 50}, + SessionPoolConfig{MaxSessionAge: 20 * time.Millisecond, MinOpened: 10, MaxOpened: 200, MaxBurst: 5}, + SessionPoolConfig{MaxSessionAge: 20 * time.Millisecond, MinOpened: 10, MaxOpened: 200, MaxBurst: 5, WriteSessions: 0.2}, + } { + var wg sync.WaitGroup + // Create a more aggressive session healthchecker to increase test concurrency. + cfg.HealthCheckInterval = 50 * time.Millisecond + cfg.HealthCheckWorkers = 50 + sc := testutil.NewMockCloudSpannerClient(t) + sc.MakeNice() + cfg.getRPCClient = func() (sppb.SpannerClient, error) { + return sc, nil + } + sp, _ := newSessionPool("mockdb", cfg, nil) + for i := 0; i < 100; i++ { + wg.Add(1) + // Schedule a test worker. + go func(idx int, pool *sessionPool, client sppb.SpannerClient) { + defer wg.Done() + // Test worker iterates 1K times and tries different session / session pool operations. + for j := 0; j < 1000; j++ { + if idx%10 == 0 && j >= 900 { + // Close the pool in selected set of workers during the middle of the test. + pool.close() + } + // Take a write sessions ~ 20% of the times. + takeWrite := rand.Intn(5) == 4 + var ( + sh *sessionHandle + gotErr error + ) + if takeWrite { + sh, gotErr = pool.takeWriteSession(context.Background()) + } else { + sh, gotErr = pool.take(context.Background()) + } + if gotErr != nil { + if pool.isValid() { + t.Errorf("%v.%v: pool.take returns error when pool is still valid: %v", ti, idx, gotErr) + } + if wantErr := errInvalidSessionPool(); !reflect.DeepEqual(gotErr, wantErr) { + t.Errorf("%v.%v: got error when pool is closed: %v, want %v", ti, idx, gotErr, wantErr) + } + continue + } + // Verify if session is valid when session pool is valid. Note that if session pool is invalid after sh is taken, + // then sh might be invalidated by healthcheck workers. + if (sh.getID() == "" || sh.session == nil || !sh.session.isValid()) && pool.isValid() { + t.Errorf("%v.%v.%v: pool.take returns invalid session %v", ti, idx, takeWrite, sh.session) + } + if takeWrite && sh.getTransactionID() == nil { + t.Errorf("%v.%v: pool.takeWriteSession returns session %v without transaction", ti, idx, sh.session) + } + if int64(cfg.MaxSessionAge) > 0 && rand.Intn(100) < idx { + // Random sleep before destroying/recycling the session, to give healthcheck worker a chance to step in. + <-time.After(time.Duration(rand.Int63n(int64(cfg.MaxSessionAge)))) + } + if rand.Intn(100) < idx { + // destroy the session. + sh.destroy() + continue + } + // recycle the session. + sh.recycle() + } + }(i, sp, sc) + } + wg.Wait() + sp.hc.close() + // Here the states of healthchecker, session pool and mockclient are stable. + idleSessions := map[string]bool{} + hcSessions := map[string]bool{} + mockSessions := sc.DumpSessions() + // Dump session pool's idle list. + for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() { + s := sl.Value.(*session) + if idleSessions[s.getID()] { + t.Errorf("%v: found duplicated session in idle list: %v", ti, s.getID()) + } + idleSessions[s.getID()] = true + } + for sl := sp.idleWriteList.Front(); sl != nil; sl = sl.Next() { + s := sl.Value.(*session) + if idleSessions[s.getID()] { + t.Errorf("%v: found duplicated session in idle write list: %v", ti, s.getID()) + } + idleSessions[s.getID()] = true + } + if int(sp.numOpened) != len(idleSessions) { + t.Errorf("%v: number of opened sessions (%v) != number of idle sessions (%v)", ti, sp.numOpened, len(idleSessions)) + } + if sp.createReqs != 0 { + t.Errorf("%v: number of pending session creations = %v, want 0", ti, sp.createReqs) + } + // Dump healthcheck queue. + for _, s := range sp.hc.queue.sessions { + if hcSessions[s.getID()] { + t.Errorf("%v: found duplicated session in healthcheck queue: %v", ti, s.getID()) + } + hcSessions[s.getID()] = true + } + // Verify that idleSessions == hcSessions == mockSessions. + if !reflect.DeepEqual(idleSessions, hcSessions) { + t.Errorf("%v: sessions in idle list (%v) != sessions in healthcheck queue (%v)", ti, idleSessions, hcSessions) + } + if !reflect.DeepEqual(hcSessions, mockSessions) { + t.Errorf("%v: sessions in healthcheck queue (%v) != sessions in mockclient (%v)", ti, hcSessions, mockSessions) + } + sp.close() + mockSessions = sc.DumpSessions() + if len(mockSessions) != 0 { + t.Errorf("Found live sessions: %v", mockSessions) + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/spanner_test.go b/vendor/cloud.google.com/go/spanner/spanner_test.go new file mode 100644 index 000000000..ede4e80d2 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/spanner_test.go @@ -0,0 +1,1234 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + "math" + "reflect" + "strings" + "sync" + "testing" + "time" + + "cloud.google.com/go/civil" + "cloud.google.com/go/internal/testutil" + database "cloud.google.com/go/spanner/admin/database/apiv1" + "golang.org/x/net/context" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/grpc/codes" + + adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" +) + +var ( + // testProjectID specifies the project used for testing. + // It can be changed by setting environment variable GCLOUD_TESTS_GOLANG_PROJECT_ID. + testProjectID = testutil.ProjID() + // testInstanceID specifies the Cloud Spanner instance used for testing. + testInstanceID = "go-integration-test" + + // client is a spanner.Client. + client *Client + // admin is a spanner.DatabaseAdminClient. + admin *database.DatabaseAdminClient + // db is the path of the testing database. + db string + // dbName is the short name of the testing database. + dbName string +) + +var ( + singerDBStatements = []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`, + `CREATE TABLE Accounts ( + AccountId INT64 NOT NULL, + Nickname STRING(100), + Balance INT64 NOT NULL, + ) PRIMARY KEY (AccountId)`, + `CREATE INDEX AccountByNickname ON Accounts(Nickname) STORING (Balance)`, + `CREATE TABLE Types ( + RowID INT64 NOT NULL, + String STRING(MAX), + StringArray ARRAY<STRING(MAX)>, + Bytes BYTES(MAX), + BytesArray ARRAY<BYTES(MAX)>, + Int64a INT64, + Int64Array ARRAY<INT64>, + Bool BOOL, + BoolArray ARRAY<BOOL>, + Float64 FLOAT64, + Float64Array ARRAY<FLOAT64>, + Date DATE, + DateArray ARRAY<DATE>, + Timestamp TIMESTAMP, + TimestampArray ARRAY<TIMESTAMP>, + ) PRIMARY KEY (RowID)`, + } + + readDBStatements = []string{ + `CREATE TABLE TestTable ( + Key STRING(MAX) NOT NULL, + StringValue STRING(MAX) + ) PRIMARY KEY (Key)`, + `CREATE INDEX TestTableByValue ON TestTable(StringValue)`, + `CREATE INDEX TestTableByValueDesc ON TestTable(StringValue DESC)`, + } +) + +type testTableRow struct{ Key, StringValue string } + +// prepare initializes Cloud Spanner testing DB and clients. +func prepare(ctx context.Context, t *testing.T, statements []string) error { + if testing.Short() { + t.Skip("Integration tests skipped in short mode") + } + if testProjectID == "" { + t.Skip("Integration tests skipped: GCLOUD_TESTS_GOLANG_PROJECT_ID is missing") + } + ts := testutil.TokenSource(ctx, AdminScope, Scope) + if ts == nil { + t.Skip("Integration test skipped: cannot get service account credential from environment variable %v", "GCLOUD_TESTS_GOLANG_KEY") + } + var err error + // Create Admin client and Data client. + // TODO: Remove the EndPoint option once this is the default. + admin, err = database.NewDatabaseAdminClient(ctx, option.WithTokenSource(ts), option.WithEndpoint("spanner.googleapis.com:443")) + if err != nil { + t.Errorf("cannot create admin client: %v", err) + return err + } + // Construct test DB name. + dbName = fmt.Sprintf("gotest_%v", time.Now().UnixNano()) + db = fmt.Sprintf("projects/%v/instances/%v/databases/%v", testProjectID, testInstanceID, dbName) + // Create database and tables. + op, err := admin.CreateDatabase(ctx, &adminpb.CreateDatabaseRequest{ + Parent: fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID), + CreateStatement: "CREATE DATABASE " + dbName, + ExtraStatements: statements, + }) + if err != nil { + t.Errorf("cannot create testing DB %v: %v", db, err) + return err + } + if _, err := op.Wait(ctx); err != nil { + t.Errorf("cannot create testing DB %v: %v", db, err) + return err + } + client, err = NewClientWithConfig(ctx, db, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + WriteSessions: 0.2, + }, + }, option.WithTokenSource(ts)) + if err != nil { + t.Errorf("cannot create data client on DB %v: %v", db, err) + return err + } + return nil +} + +// tearDown tears down the testing environment created by prepare(). +func tearDown(ctx context.Context, t *testing.T) { + if admin != nil { + if err := admin.DropDatabase(ctx, &adminpb.DropDatabaseRequest{db}); err != nil { + t.Logf("failed to drop testing database: %v, might need a manual removal", db) + } + admin.Close() + } + if client != nil { + client.Close() + } + admin = nil + client = nil + db = "" +} + +// Test SingleUse transaction. +func TestSingleUse(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + // Set up testing environment. + if err := prepare(ctx, t, singerDBStatements); err != nil { + // If prepare() fails, tear down whatever that's already up. + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + // After all tests, tear down testing environment. + defer tearDown(ctx, t) + + writes := []struct { + row []interface{} + ts time.Time + }{ + {row: []interface{}{1, "Marc", "Foo"}}, + {row: []interface{}{2, "Tars", "Bar"}}, + {row: []interface{}{3, "Alpha", "Beta"}}, + {row: []interface{}{4, "Last", "End"}}, + } + // Try to write four rows through the Apply API. + for i, w := range writes { + var err error + m := InsertOrUpdate("Singers", + []string{"SingerId", "FirstName", "LastName"}, + w.row) + if writes[i].ts, err = client.Apply(ctx, []*Mutation{m}, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + } + + // For testing timestamp bound staleness. + <-time.After(time.Second) + + // Test reading rows with different timestamp bounds. + for i, test := range []struct { + want [][]interface{} + tb TimestampBound + checkTs func(time.Time) error + }{ + { + // strong + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}}, + StrongRead(), + func(ts time.Time) error { + // writes[3] is the last write, all subsequent strong read should have a timestamp larger than that. + if ts.Before(writes[3].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no later than %v", ts, writes[3].ts) + } + return nil + }, + }, + { + // min_read_timestamp + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}}, + MinReadTimestamp(writes[3].ts), + func(ts time.Time) error { + if ts.Before(writes[3].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no later than %v", ts, writes[3].ts) + } + return nil + }, + }, + { + // max_staleness + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}}, + MaxStaleness(time.Second), + func(ts time.Time) error { + if ts.Before(writes[3].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no later than %v", ts, writes[3].ts) + } + return nil + }, + }, + { + // read_timestamp + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}}, + ReadTimestamp(writes[2].ts), + func(ts time.Time) error { + if ts != writes[2].ts { + return fmt.Errorf("read got timestamp %v, expect %v", ts, writes[2].ts) + } + return nil + }, + }, + { + // exact_staleness + nil, + // Specify a staleness which should be already before this test because + // context timeout is set to be 10s. + ExactStaleness(11 * time.Second), + func(ts time.Time) error { + if ts.After(writes[0].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no earlier than %v", ts, writes[0].ts) + } + return nil + }, + }, + } { + // SingleUse.Query + su := client.Single().WithTimestampBound(test.tb) + got, err := readAll(su.Query( + ctx, + Statement{ + "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@id1, @id3, @id4)", + map[string]interface{}{"id1": int64(1), "id3": int64(3), "id4": int64(4)}, + })) + if err != nil { + t.Errorf("%d: SingleUse.Query returns error %v, want nil", i, err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected result from SingleUse.Query: %v, want %v", i, got, test.want) + } + rts, err := su.Timestamp() + if err != nil { + t.Errorf("%d: SingleUse.Query doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: SingleUse.Query doesn't return expected timestamp: %v", i, err) + } + // SingleUse.Read + su = client.Single().WithTimestampBound(test.tb) + got, err = readAll(su.Read(ctx, "Singers", KeySets(Key{1}, Key{3}, Key{4}), []string{"SingerId", "FirstName", "LastName"})) + if err != nil { + t.Errorf("%d: SingleUse.Read returns error %v, want nil", i, err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected result from SingleUse.Read: %v, want %v", i, got, test.want) + } + rts, err = su.Timestamp() + if err != nil { + t.Errorf("%d: SingleUse.Read doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: SingleUse.Read doesn't return expected timestamp: %v", i, err) + } + // SingleUse.ReadRow + got = nil + for _, k := range []Key{Key{1}, Key{3}, Key{4}} { + su = client.Single().WithTimestampBound(test.tb) + r, err := su.ReadRow(ctx, "Singers", k, []string{"SingerId", "FirstName", "LastName"}) + if err != nil { + continue + } + v, err := rowToValues(r) + if err != nil { + continue + } + got = append(got, v) + rts, err = su.Timestamp() + if err != nil { + t.Errorf("%d: SingleUse.ReadRow(%v) doesn't return a timestamp, error: %v", i, k, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: SingleUse.ReadRow(%v) doesn't return expected timestamp: %v", i, k, err) + } + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected results from SingleUse.ReadRow: %v, want %v", i, got, test.want) + } + // SingleUse.ReadUsingIndex + su = client.Single().WithTimestampBound(test.tb) + got, err = readAll(su.ReadUsingIndex(ctx, "Singers", "SingerByName", KeySets(Key{"Marc", "Foo"}, Key{"Alpha", "Beta"}, Key{"Last", "End"}), []string{"SingerId", "FirstName", "LastName"})) + if err != nil { + t.Errorf("%d: SingleUse.ReadUsingIndex returns error %v, want nil", i, err) + } + // The results from ReadUsingIndex is sorted by the index rather than primary key. + if len(got) != len(test.want) { + t.Errorf("%d: got unexpected result from SingleUse.ReadUsingIndex: %v, want %v", i, got, test.want) + } + for j, g := range got { + if j > 0 { + prev := got[j-1][1].(string) + got[j-1][2].(string) + curr := got[j][1].(string) + got[j][2].(string) + if strings.Compare(prev, curr) > 0 { + t.Errorf("%d: SingleUse.ReadUsingIndex fails to order rows by index keys, %v should be after %v", i, got[j-1], got[j]) + } + } + found := false + for _, w := range test.want { + if reflect.DeepEqual(g, w) { + found = true + } + } + if !found { + t.Errorf("%d: got unexpected result from SingleUse.ReadUsingIndex: %v, want %v", i, got, test.want) + break + } + } + rts, err = su.Timestamp() + if err != nil { + t.Errorf("%d: SingleUse.ReadUsingIndex doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: SingleUse.ReadUsingIndex doesn't return expected timestamp: %v", i, err) + } + } +} + +// Test ReadOnlyTransaction. The testsuite is mostly like SingleUse, except it +// also tests for a single timestamp across multiple reads. +func TestReadOnlyTransaction(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + // Set up testing environment. + if err := prepare(ctx, t, singerDBStatements); err != nil { + // If prepare() fails, tear down whatever that's already up. + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + // After all tests, tear down testing environment. + defer tearDown(ctx, t) + + writes := []struct { + row []interface{} + ts time.Time + }{ + {row: []interface{}{1, "Marc", "Foo"}}, + {row: []interface{}{2, "Tars", "Bar"}}, + {row: []interface{}{3, "Alpha", "Beta"}}, + {row: []interface{}{4, "Last", "End"}}, + } + // Try to write four rows through the Apply API. + for i, w := range writes { + var err error + m := InsertOrUpdate("Singers", + []string{"SingerId", "FirstName", "LastName"}, + w.row) + if writes[i].ts, err = client.Apply(ctx, []*Mutation{m}, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + } + + // For testing timestamp bound staleness. + <-time.After(time.Second) + + // Test reading rows with different timestamp bounds. + for i, test := range []struct { + want [][]interface{} + tb TimestampBound + checkTs func(time.Time) error + }{ + // Note: min_read_timestamp and max_staleness are not supported by ReadOnlyTransaction. See + // API document for more details. + { + // strong + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}}, + StrongRead(), + func(ts time.Time) error { + if ts.Before(writes[3].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no later than %v", ts, writes[3].ts) + } + return nil + }, + }, + { + // read_timestamp + [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}}, + ReadTimestamp(writes[2].ts), + func(ts time.Time) error { + if ts != writes[2].ts { + return fmt.Errorf("read got timestamp %v, expect %v", ts, writes[2].ts) + } + return nil + }, + }, + { + // exact_staleness + nil, + // Specify a staleness which should be already before this test because + // context timeout is set to be 10s. + ExactStaleness(11 * time.Second), + func(ts time.Time) error { + if ts.After(writes[0].ts) { + return fmt.Errorf("read got timestamp %v, want it to be no earlier than %v", ts, writes[0].ts) + } + return nil + }, + }, + } { + // ReadOnlyTransaction.Query + ro := client.ReadOnlyTransaction().WithTimestampBound(test.tb) + got, err := readAll(ro.Query( + ctx, + Statement{ + "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@id1, @id3, @id4)", + map[string]interface{}{"id1": int64(1), "id3": int64(3), "id4": int64(4)}, + })) + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.Query returns error %v, want nil", i, err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected result from ReadOnlyTransaction.Query: %v, want %v", i, got, test.want) + } + rts, err := ro.Timestamp() + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.Query doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: ReadOnlyTransaction.Query doesn't return expected timestamp: %v", i, err) + } + roTs := rts + // ReadOnlyTransaction.Read + got, err = readAll(ro.Read(ctx, "Singers", KeySets(Key{1}, Key{3}, Key{4}), []string{"SingerId", "FirstName", "LastName"})) + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.Read returns error %v, want nil", i, err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected result from ReadOnlyTransaction.Read: %v, want %v", i, got, test.want) + } + rts, err = ro.Timestamp() + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.Read doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: ReadOnlyTransaction.Read doesn't return expected timestamp: %v", i, err) + } + if roTs != rts { + t.Errorf("%d: got two read timestamps: %v, %v, want ReadOnlyTransaction to return always the same read timestamp", i, roTs, rts) + } + // ReadOnlyTransaction.ReadRow + got = nil + for _, k := range []Key{Key{1}, Key{3}, Key{4}} { + r, err := ro.ReadRow(ctx, "Singers", k, []string{"SingerId", "FirstName", "LastName"}) + if err != nil { + continue + } + v, err := rowToValues(r) + if err != nil { + continue + } + got = append(got, v) + rts, err = ro.Timestamp() + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.ReadRow(%v) doesn't return a timestamp, error: %v", i, k, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: ReadOnlyTransaction.ReadRow(%v) doesn't return expected timestamp: %v", i, k, err) + } + if roTs != rts { + t.Errorf("%d: got two read timestamps: %v, %v, want ReadOnlyTransaction to return always the same read timestamp", i, roTs, rts) + } + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: got unexpected results from ReadOnlyTransaction.ReadRow: %v, want %v", i, got, test.want) + } + // SingleUse.ReadUsingIndex + got, err = readAll(ro.ReadUsingIndex(ctx, "Singers", "SingerByName", KeySets(Key{"Marc", "Foo"}, Key{"Alpha", "Beta"}, Key{"Last", "End"}), []string{"SingerId", "FirstName", "LastName"})) + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.ReadUsingIndex returns error %v, want nil", i, err) + } + // The results from ReadUsingIndex is sorted by the index rather than primary key. + if len(got) != len(test.want) { + t.Errorf("%d: got unexpected result from ReadOnlyTransaction.ReadUsingIndex: %v, want %v", i, got, test.want) + } + for j, g := range got { + if j > 0 { + prev := got[j-1][1].(string) + got[j-1][2].(string) + curr := got[j][1].(string) + got[j][2].(string) + if strings.Compare(prev, curr) > 0 { + t.Errorf("%d: ReadOnlyTransaction.ReadUsingIndex fails to order rows by index keys, %v should be after %v", i, got[j-1], got[j]) + } + } + found := false + for _, w := range test.want { + if reflect.DeepEqual(g, w) { + found = true + } + } + if !found { + t.Errorf("%d: got unexpected result from ReadOnlyTransaction.ReadUsingIndex: %v, want %v", i, got, test.want) + break + } + } + rts, err = ro.Timestamp() + if err != nil { + t.Errorf("%d: ReadOnlyTransaction.ReadUsingIndex doesn't return a timestamp, error: %v", i, err) + } + if err := test.checkTs(rts); err != nil { + t.Errorf("%d: ReadOnlyTransaction.ReadUsingIndex doesn't return expected timestamp: %v", i, err) + } + if roTs != rts { + t.Errorf("%d: got two read timestamps: %v, %v, want ReadOnlyTransaction to return always the same read timestamp", i, roTs, rts) + } + ro.Close() + } +} + +// Test ReadWriteTransaction. +func TestReadWriteTransaction(t *testing.T) { + // Give a longer deadline because of transaction backoffs. + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + if err := prepare(ctx, t, singerDBStatements); err != nil { + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + defer tearDown(ctx, t) + + // Set up two accounts + accounts := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), + } + if _, err := client.Apply(ctx, accounts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + wg := sync.WaitGroup{} + + readBalance := func(iter *RowIterator) (int64, error) { + defer iter.Stop() + var bal int64 + for { + row, err := iter.Next() + if err == iterator.Done { + return bal, nil + } + if err != nil { + return 0, err + } + if err := row.Column(0, &bal); err != nil { + return 0, err + } + } + } + + for i := 0; i < 20; i++ { + wg.Add(1) + go func(iter int) { + defer wg.Done() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + // Query Foo's balance and Bar's balance. + bf, e := readBalance(tx.Query(ctx, + Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}})) + if e != nil { + return e + } + bb, e := readBalance(tx.Read(ctx, "Accounts", KeySets(Key{int64(2)}), []string{"Balance"})) + if e != nil { + return e + } + if bf <= 0 { + return nil + } + bf-- + bb++ + tx.BufferWrite([]*Mutation{ + Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(1), bf}), + Update("Accounts", []string{"AccountId", "Balance"}, []interface{}{int64(2), bb}), + }) + return nil + }) + if err != nil { + t.Fatalf("%d: failed to execute transaction: %v", iter, err) + } + }(i) + } + // Because of context timeout, all goroutines will eventually return. + wg.Wait() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + var bf, bb int64 + r, e := tx.ReadRow(ctx, "Accounts", Key{int64(1)}, []string{"Balance"}) + if e != nil { + return e + } + if ce := r.Column(0, &bf); ce != nil { + return ce + } + bb, e = readBalance(tx.ReadUsingIndex(ctx, "Accounts", "AccountByNickname", KeySets(Key{"Bar"}), []string{"Balance"})) + if e != nil { + return e + } + if bf != 30 || bb != 21 { + t.Errorf("Foo's balance is now %v and Bar's balance is now %v, want %v and %v", bf, bb, 30, 21) + } + return nil + }) + if err != nil { + t.Errorf("failed to check balances: %v", err) + } +} + +const ( + testTable = "TestTable" + testTableIndex = "TestTableByValue" +) + +var testTableColumns = []string{"Key", "StringValue"} + +func TestReads(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + // Set up testing environment. + if err := prepare(ctx, t, readDBStatements); err != nil { + // If prepare() fails, tear down whatever that's already up. + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + // After all tests, tear down testing environment. + defer tearDown(ctx, t) + + // Includes k0..k14. Strings sort lexically, eg "k1" < "k10" < "k2". + var ms []*Mutation + for i := 0; i < 15; i++ { + ms = append(ms, InsertOrUpdate(testTable, + testTableColumns, + []interface{}{fmt.Sprintf("k%d", i), fmt.Sprintf("v%d", i)})) + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + // Empty read. + rows, err := readAllTestTable(client.Single().Read(ctx, testTable, + KeyRange{Start: Key{"k99"}, End: Key{"z"}}, testTableColumns)) + if err != nil { + t.Fatal(err) + } + if got, want := len(rows), 0; got != want { + t.Errorf("got %d, want %d", got, want) + } + + // Index empty read. + rows, err = readAllTestTable(client.Single().ReadUsingIndex(ctx, testTable, testTableIndex, + KeyRange{Start: Key{"v99"}, End: Key{"z"}}, testTableColumns)) + if err != nil { + t.Fatal(err) + } + if got, want := len(rows), 0; got != want { + t.Errorf("got %d, want %d", got, want) + } + + // Point read. + row, err := client.Single().ReadRow(ctx, testTable, Key{"k1"}, testTableColumns) + if err != nil { + t.Fatal(err) + } + var got testTableRow + if err := row.ToStruct(&got); err != nil { + t.Fatal(err) + } + if want := (testTableRow{"k1", "v1"}); got != want { + t.Errorf("got %v, want %v", got, want) + } + + // Point read not found. + _, err = client.Single().ReadRow(ctx, testTable, Key{"k999"}, testTableColumns) + if ErrCode(err) != codes.NotFound { + t.Fatalf("got %v, want NotFound", err) + } + + // No index point read not found, because Go does not have ReadRowUsingIndex. + + rangeReads(ctx, t) + indexRangeReads(ctx, t) +} + +func rangeReads(ctx context.Context, t *testing.T) { + checkRange := func(ks KeySet, wantNums ...int) { + if msg, ok := compareRows(client.Single().Read(ctx, testTable, ks, testTableColumns), wantNums); !ok { + t.Errorf("key set %+v: %s", ks, msg) + } + } + + checkRange(Key{"k1"}, 1) + checkRange(KeyRange{Key{"k3"}, Key{"k5"}, ClosedOpen}, 3, 4) + checkRange(KeyRange{Key{"k3"}, Key{"k5"}, ClosedClosed}, 3, 4, 5) + checkRange(KeyRange{Key{"k3"}, Key{"k5"}, OpenClosed}, 4, 5) + checkRange(KeyRange{Key{"k3"}, Key{"k5"}, OpenOpen}, 4) + + // Partial key specification. + checkRange(KeyRange{Key{"k7"}, Key{}, ClosedClosed}, 7, 8, 9) + checkRange(KeyRange{Key{"k7"}, Key{}, OpenClosed}, 8, 9) + checkRange(KeyRange{Key{}, Key{"k11"}, ClosedOpen}, 0, 1, 10) + checkRange(KeyRange{Key{}, Key{"k11"}, ClosedClosed}, 0, 1, 10, 11) + + // The following produce empty ranges. + // TODO(jba): Consider a multi-part key to illustrate partial key behavior. + // checkRange(KeyRange{Key{"k7"}, Key{}, ClosedOpen}) + // checkRange(KeyRange{Key{"k7"}, Key{}, OpenOpen}) + // checkRange(KeyRange{Key{}, Key{"k11"}, OpenOpen}) + // checkRange(KeyRange{Key{}, Key{"k11"}, OpenClosed}) + + // Prefix is component-wise, not string prefix. + checkRange(Key{"k1"}.AsPrefix(), 1) + checkRange(KeyRange{Key{"k1"}, Key{"k2"}, ClosedOpen}, 1, 10, 11, 12, 13, 14) + + checkRange(AllKeys(), 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) +} + +func indexRangeReads(ctx context.Context, t *testing.T) { + checkRange := func(ks KeySet, wantNums ...int) { + if msg, ok := compareRows(client.Single().ReadUsingIndex(ctx, testTable, testTableIndex, ks, testTableColumns), + wantNums); !ok { + t.Errorf("key set %+v: %s", ks, msg) + } + } + + checkRange(Key{"v1"}, 1) + checkRange(KeyRange{Key{"v3"}, Key{"v5"}, ClosedOpen}, 3, 4) + checkRange(KeyRange{Key{"v3"}, Key{"v5"}, ClosedClosed}, 3, 4, 5) + checkRange(KeyRange{Key{"v3"}, Key{"v5"}, OpenClosed}, 4, 5) + checkRange(KeyRange{Key{"v3"}, Key{"v5"}, OpenOpen}, 4) + + // // Partial key specification. + checkRange(KeyRange{Key{"v7"}, Key{}, ClosedClosed}, 7, 8, 9) + checkRange(KeyRange{Key{"v7"}, Key{}, OpenClosed}, 8, 9) + checkRange(KeyRange{Key{}, Key{"v11"}, ClosedOpen}, 0, 1, 10) + checkRange(KeyRange{Key{}, Key{"v11"}, ClosedClosed}, 0, 1, 10, 11) + + // // The following produce empty ranges. + // checkRange(KeyRange{Key{"v7"}, Key{}, ClosedOpen}) + // checkRange(KeyRange{Key{"v7"}, Key{}, OpenOpen}) + // checkRange(KeyRange{Key{}, Key{"v11"}, OpenOpen}) + // checkRange(KeyRange{Key{}, Key{"v11"}, OpenClosed}) + + // // Prefix is component-wise, not string prefix. + checkRange(Key{"v1"}.AsPrefix(), 1) + checkRange(KeyRange{Key{"v1"}, Key{"v2"}, ClosedOpen}, 1, 10, 11, 12, 13, 14) + checkRange(AllKeys(), 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) + + // Read from an index with DESC ordering. + wantNums := []int{14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} + if msg, ok := compareRows(client.Single().ReadUsingIndex(ctx, testTable, "TestTableByValueDesc", AllKeys(), testTableColumns), + wantNums); !ok { + t.Errorf("desc: %s", msg) + } +} + +func compareRows(iter *RowIterator, wantNums []int) (string, bool) { + rows, err := readAllTestTable(iter) + if err != nil { + return err.Error(), false + } + want := map[string]string{} + for _, n := range wantNums { + want[fmt.Sprintf("k%d", n)] = fmt.Sprintf("v%d", n) + } + got := map[string]string{} + for _, r := range rows { + got[r.Key] = r.StringValue + } + if !reflect.DeepEqual(got, want) { + return fmt.Sprintf("got %v, want %v", got, want), false + } + return "", true +} + +func TestEarlyTimestamp(t *testing.T) { + // Test that we can get the timestamp from a read-only transaction as + // soon as we have read at least one row. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + // Set up testing environment. + if err := prepare(ctx, t, readDBStatements); err != nil { + // If prepare() fails, tear down whatever that's already up. + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + // After all tests, tear down testing environment. + defer tearDown(ctx, t) + + var ms []*Mutation + for i := 0; i < 3; i++ { + ms = append(ms, InsertOrUpdate(testTable, + testTableColumns, + []interface{}{fmt.Sprintf("k%d", i), fmt.Sprintf("v%d", i)})) + } + if _, err := client.Apply(ctx, ms, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + txn := client.Single() + iter := txn.Read(ctx, testTable, AllKeys(), testTableColumns) + defer iter.Stop() + // In single-use transaction, we should get an error before reading anything. + if _, err := txn.Timestamp(); err == nil { + t.Error("wanted error, got nil") + } + // After reading one row, the timestamp should be available. + _, err := iter.Next() + if err != nil { + t.Fatal(err) + } + if _, err := txn.Timestamp(); err != nil { + t.Errorf("got %v, want nil", err) + } + + txn = client.ReadOnlyTransaction() + defer txn.Close() + iter = txn.Read(ctx, testTable, AllKeys(), testTableColumns) + defer iter.Stop() + // In an ordinary read-only transaction, the timestamp should be + // available immediately. + if _, err := txn.Timestamp(); err != nil { + t.Errorf("got %v, want nil", err) + } +} + +func TestNestedTransaction(t *testing.T) { + // You cannot use a transaction from inside a read-write transaction. + ctx := context.Background() + if err := prepare(ctx, t, singerDBStatements); err != nil { + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + defer tearDown(ctx, t) + client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + _, err := client.ReadWriteTransaction(ctx, + func(context.Context, *ReadWriteTransaction) error { return nil }) + if ErrCode(err) != codes.FailedPrecondition { + t.Fatalf("got %v, want FailedPrecondition", err) + } + _, err = client.Single().ReadRow(ctx, "Singers", Key{1}, []string{"SingerId"}) + if ErrCode(err) != codes.FailedPrecondition { + t.Fatalf("got %v, want FailedPrecondition", err) + } + rot := client.ReadOnlyTransaction() + defer rot.Close() + _, err = rot.ReadRow(ctx, "Singers", Key{1}, []string{"SingerId"}) + if ErrCode(err) != codes.FailedPrecondition { + t.Fatalf("got %v, want FailedPrecondition", err) + } + return nil + }) +} + +// Test client recovery on database recreation. +func TestDbRemovalRecovery(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + if err := prepare(ctx, t, singerDBStatements); err != nil { + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + defer tearDown(ctx, t) + + // Drop the testing database. + if err := admin.DropDatabase(ctx, &adminpb.DropDatabaseRequest{db}); err != nil { + t.Fatalf("failed to drop testing database %v: %v", db, err) + } + + // Now, send the query. + iter := client.Single().Query(ctx, Statement{SQL: "SELECT SingerId FROM Singers"}) + defer iter.Stop() + if _, err := iter.Next(); err == nil { + t.Errorf("client sends query to removed database successfully, want it to fail") + } + + // Recreate database and table. + op, err := admin.CreateDatabase(ctx, &adminpb.CreateDatabaseRequest{ + Parent: fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID), + CreateStatement: "CREATE DATABASE " + dbName, + ExtraStatements: []string{ + `CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)`, + }, + }) + if _, err := op.Wait(ctx); err != nil { + t.Errorf("cannot recreate testing DB %v: %v", db, err) + } + + // Now, send the query again. + iter = client.Single().Query(ctx, Statement{SQL: "SELECT SingerId FROM Singers"}) + defer iter.Stop() + _, err = iter.Next() + if err != nil && err != iterator.Done { + t.Fatalf("failed to send query to database %v: %v", db, err) + } +} + +// Test encoding/decoding non-struct Cloud Spanner types. +func TestBasicTypes(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := prepare(ctx, t, singerDBStatements); err != nil { + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + defer tearDown(ctx, t) + t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z") + // Boundaries + t2, _ := time.Parse(time.RFC3339Nano, "0001-01-01T00:00:00.000000000Z") + t3, _ := time.Parse(time.RFC3339Nano, "9999-12-31T23:59:59.999999999Z") + d1, _ := civil.ParseDate("2016-11-15") + // Boundaries + d2, _ := civil.ParseDate("0001-01-01") + d3, _ := civil.ParseDate("9999-12-31") + + tests := []struct { + col string + val interface{} + want interface{} + }{ + {col: "String", val: ""}, + {col: "String", val: "", want: NullString{"", true}}, + {col: "String", val: "foo"}, + {col: "String", val: "foo", want: NullString{"foo", true}}, + {col: "String", val: NullString{"bar", true}, want: "bar"}, + {col: "String", val: NullString{"bar", false}, want: NullString{"", false}}, + {col: "StringArray", val: []string(nil), want: []NullString(nil)}, + {col: "StringArray", val: []string{}, want: []NullString{}}, + {col: "StringArray", val: []string{"foo", "bar"}, want: []NullString{{"foo", true}, {"bar", true}}}, + {col: "StringArray", val: []NullString(nil)}, + {col: "StringArray", val: []NullString{}}, + {col: "StringArray", val: []NullString{{"foo", true}, {}}}, + {col: "Bytes", val: []byte{}}, + {col: "Bytes", val: []byte{1, 2, 3}}, + {col: "Bytes", val: []byte(nil)}, + {col: "BytesArray", val: [][]byte(nil)}, + {col: "BytesArray", val: [][]byte{}}, + {col: "BytesArray", val: [][]byte{[]byte{1}, []byte{2, 3}}}, + {col: "Int64a", val: 0, want: int64(0)}, + {col: "Int64a", val: -1, want: int64(-1)}, + {col: "Int64a", val: 2, want: int64(2)}, + {col: "Int64a", val: int64(3)}, + {col: "Int64a", val: 4, want: NullInt64{4, true}}, + {col: "Int64a", val: NullInt64{5, true}, want: int64(5)}, + {col: "Int64a", val: NullInt64{6, true}, want: int64(6)}, + {col: "Int64a", val: NullInt64{7, false}, want: NullInt64{0, false}}, + {col: "Int64Array", val: []int(nil), want: []NullInt64(nil)}, + {col: "Int64Array", val: []int{}, want: []NullInt64{}}, + {col: "Int64Array", val: []int{1, 2}, want: []NullInt64{{1, true}, {2, true}}}, + {col: "Int64Array", val: []int64(nil), want: []NullInt64(nil)}, + {col: "Int64Array", val: []int64{}, want: []NullInt64{}}, + {col: "Int64Array", val: []int64{1, 2}, want: []NullInt64{{1, true}, {2, true}}}, + {col: "Int64Array", val: []NullInt64(nil)}, + {col: "Int64Array", val: []NullInt64{}}, + {col: "Int64Array", val: []NullInt64{{1, true}, {}}}, + {col: "Bool", val: false}, + {col: "Bool", val: true}, + {col: "Bool", val: false, want: NullBool{false, true}}, + {col: "Bool", val: true, want: NullBool{true, true}}, + {col: "Bool", val: NullBool{true, true}}, + {col: "Bool", val: NullBool{false, false}}, + {col: "BoolArray", val: []bool(nil), want: []NullBool(nil)}, + {col: "BoolArray", val: []bool{}, want: []NullBool{}}, + {col: "BoolArray", val: []bool{true, false}, want: []NullBool{{true, true}, {false, true}}}, + {col: "BoolArray", val: []NullBool(nil)}, + {col: "BoolArray", val: []NullBool{}}, + {col: "BoolArray", val: []NullBool{{false, true}, {true, true}, {}}}, + {col: "Float64", val: 0.0}, + {col: "Float64", val: 3.14}, + {col: "Float64", val: math.NaN()}, + {col: "Float64", val: math.Inf(1)}, + {col: "Float64", val: math.Inf(-1)}, + {col: "Float64", val: 2.78, want: NullFloat64{2.78, true}}, + {col: "Float64", val: NullFloat64{2.71, true}, want: 2.71}, + {col: "Float64", val: NullFloat64{1.41, true}, want: NullFloat64{1.41, true}}, + {col: "Float64", val: NullFloat64{0, false}}, + {col: "Float64Array", val: []float64(nil), want: []NullFloat64(nil)}, + {col: "Float64Array", val: []float64{}, want: []NullFloat64{}}, + {col: "Float64Array", val: []float64{2.72, 3.14, math.Inf(1)}, want: []NullFloat64{{2.72, true}, {3.14, true}, {math.Inf(1), true}}}, + {col: "Float64Array", val: []NullFloat64(nil)}, + {col: "Float64Array", val: []NullFloat64{}}, + {col: "Float64Array", val: []NullFloat64{{2.72, true}, {math.Inf(1), true}, {}}}, + {col: "Date", val: d1}, + {col: "Date", val: d1, want: NullDate{d1, true}}, + {col: "Date", val: NullDate{d1, true}}, + {col: "Date", val: NullDate{d1, true}, want: d1}, + {col: "Date", val: NullDate{civil.Date{}, false}}, + {col: "DateArray", val: []civil.Date(nil), want: []NullDate(nil)}, + {col: "DateArray", val: []civil.Date{}, want: []NullDate{}}, + {col: "DateArray", val: []civil.Date{d1, d2, d3}, want: []NullDate{{d1, true}, {d2, true}, {d3, true}}}, + {col: "Timestamp", val: t1}, + {col: "Timestamp", val: t1, want: NullTime{t1, true}}, + {col: "Timestamp", val: NullTime{t1, true}}, + {col: "Timestamp", val: NullTime{t1, true}, want: t1}, + {col: "Timestamp", val: NullTime{}}, + {col: "TimestampArray", val: []time.Time(nil), want: []NullTime(nil)}, + {col: "TimestampArray", val: []time.Time{}, want: []NullTime{}}, + {col: "TimestampArray", val: []time.Time{t1, t2, t3}, want: []NullTime{{t1, true}, {t2, true}, {t3, true}}}, + } + + // Write rows into table first. + var muts []*Mutation + for i, test := range tests { + muts = append(muts, InsertOrUpdate("Types", []string{"RowID", test.col}, []interface{}{i, test.val})) + } + if _, err := client.Apply(ctx, muts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + for i, test := range tests { + row, err := client.Single().ReadRow(ctx, "Types", []interface{}{i}, []string{test.col}) + if err != nil { + t.Fatalf("Unable to fetch row %v: %v", i, err) + } + // Create new instance of type of test.want. + want := test.want + if want == nil { + want = test.val + } + gotp := reflect.New(reflect.TypeOf(want)) + if err := row.Column(0, gotp.Interface()); err != nil { + t.Errorf("%d: col:%v val:%#v, %v", i, test.col, test.val, err) + continue + } + got := reflect.Indirect(gotp).Interface() + + // One of the test cases is checking NaN handling. Given + // NaN!=NaN, we can't use reflect to test for it. + isNaN := func(t interface{}) bool { + f, ok := t.(float64) + if !ok { + return false + } + return math.IsNaN(f) + } + if isNaN(got) && isNaN(want) { + continue + } + + // Check non-NaN cases. + if !reflect.DeepEqual(got, want) { + t.Errorf("%d: col:%v val:%#v, got %#v, want %#v", i, test.col, test.val, got, want) + continue + } + } +} + +// Test decoding Cloud Spanner STRUCT type. +func TestStructTypes(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + if err := prepare(ctx, t, singerDBStatements); err != nil { + tearDown(ctx, t) + t.Fatalf("cannot set up testing environment: %v", err) + } + defer tearDown(ctx, t) + + tests := []struct { + q Statement + want func(r *Row) error + }{ + { + q: Statement{SQL: `SELECT ARRAY(SELECT STRUCT(1, 2))`}, + want: func(r *Row) error { + // Test STRUCT ARRAY decoding to []NullRow. + var rows []NullRow + if err := r.Column(0, &rows); err != nil { + return err + } + if len(rows) != 1 { + return fmt.Errorf("len(rows) = %d; want 1", len(rows)) + } + if !rows[0].Valid { + return fmt.Errorf("rows[0] is NULL") + } + var i, j int64 + if err := rows[0].Row.Columns(&i, &j); err != nil { + return err + } + if i != 1 || j != 2 { + return fmt.Errorf("got (%d,%d), want (1,2)", i, j) + } + return nil + }, + }, + { + q: Statement{SQL: `SELECT ARRAY(SELECT STRUCT(1 as foo, 2 as bar)) as col1`}, + want: func(r *Row) error { + // Test Row.ToStruct. + s := struct { + Col1 []*struct { + Foo int64 `spanner:"foo"` + Bar int64 `spanner:"bar"` + } `spanner:"col1"` + }{} + if err := r.ToStruct(&s); err != nil { + return err + } + want := struct { + Col1 []*struct { + Foo int64 `spanner:"foo"` + Bar int64 `spanner:"bar"` + } `spanner:"col1"` + }{ + Col1: []*struct { + Foo int64 `spanner:"foo"` + Bar int64 `spanner:"bar"` + }{ + { + Foo: 1, + Bar: 2, + }, + }, + } + if !reflect.DeepEqual(want, s) { + return fmt.Errorf("unexpected decoding result: %v, want %v", s, want) + } + return nil + }, + }, + } + for i, test := range tests { + iter := client.Single().Query(ctx, test.q) + defer iter.Stop() + row, err := iter.Next() + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + if err := test.want(row); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + } +} + +func rowToValues(r *Row) ([]interface{}, error) { + var x int64 + var y, z string + if err := r.Column(0, &x); err != nil { + return nil, err + } + if err := r.Column(1, &y); err != nil { + return nil, err + } + if err := r.Column(2, &z); err != nil { + return nil, err + } + return []interface{}{x, y, z}, nil +} + +func readAll(iter *RowIterator) ([][]interface{}, error) { + defer iter.Stop() + var vals [][]interface{} + for { + row, err := iter.Next() + if err == iterator.Done { + return vals, nil + } + if err != nil { + return nil, err + } + v, err := rowToValues(row) + if err != nil { + return nil, err + } + vals = append(vals, v) + } +} + +func readAllTestTable(iter *RowIterator) ([]testTableRow, error) { + defer iter.Stop() + var vals []testTableRow + for { + row, err := iter.Next() + if err == iterator.Done { + return vals, nil + } + if err != nil { + return nil, err + } + var ttr testTableRow + if err := row.ToStruct(&ttr); err != nil { + return nil, err + } + vals = append(vals, ttr) + } +} diff --git a/vendor/cloud.google.com/go/spanner/statement.go b/vendor/cloud.google.com/go/spanner/statement.go new file mode 100644 index 000000000..8e422b09c --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/statement.go @@ -0,0 +1,78 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + + proto3 "github.com/golang/protobuf/ptypes/struct" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +// A Statement is a SQL query with named parameters. +// +// A parameter placeholder consists of '@' followed by the parameter name. +// Parameter names consist of any combination of letters, numbers, and +// underscores. Names may be entirely numeric (e.g., "WHERE m.id = @5"). +// Parameters may appear anywhere that a literal value is expected. The same +// parameter name may be used more than once. It is an error to execute a +// statement with unbound parameters. On the other hand, it is allowable to +// bind parameter names that are not used. +// +// See the documentation of the Row type for how Go types are mapped to Cloud +// Spanner types. +type Statement struct { + SQL string + Params map[string]interface{} +} + +// NewStatement returns a Statement with the given SQL and an empty Params map. +func NewStatement(sql string) Statement { + return Statement{SQL: sql, Params: map[string]interface{}{}} +} + +// errBindParam returns error for not being able to bind parameter to query request. +func errBindParam(k string, v interface{}, err error) error { + if err == nil { + return nil + } + se, ok := toSpannerError(err).(*Error) + if !ok { + return spannerErrorf(codes.InvalidArgument, "failed to bind query parameter(name: %q, value: %q), error = <%v>", k, v, err) + } + se.decorate(fmt.Sprintf("failed to bind query parameter(name: %q, value: %q)", k, v)) + return se +} + +// bindParams binds parameters in a Statement to a sppb.ExecuteSqlRequest. +func (s *Statement) bindParams(r *sppb.ExecuteSqlRequest) error { + r.Params = &proto3.Struct{ + Fields: map[string]*proto3.Value{}, + } + r.ParamTypes = map[string]*sppb.Type{} + for k, v := range s.Params { + val, t, err := encodeValue(v) + if err != nil { + return errBindParam(k, v, err) + } + r.Params.Fields[k] = val + r.ParamTypes[k] = t + } + return nil +} diff --git a/vendor/cloud.google.com/go/spanner/statement_test.go b/vendor/cloud.google.com/go/spanner/statement_test.go new file mode 100644 index 000000000..a441e0e82 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/statement_test.go @@ -0,0 +1,64 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "reflect" + "testing" + + proto3 "github.com/golang/protobuf/ptypes/struct" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// Test Statement.bindParams. +func TestBindParams(t *testing.T) { + // Verify Statement.bindParams generates correct values and types. + want := sppb.ExecuteSqlRequest{ + Params: &proto3.Struct{ + Fields: map[string]*proto3.Value{ + "var1": stringProto("abc"), + "var2": intProto(1), + }, + }, + ParamTypes: map[string]*sppb.Type{ + "var1": stringType(), + "var2": intType(), + }, + } + st := Statement{ + SQL: "SELECT id from t_foo WHERE col1 = @var1 AND col2 = @var2", + Params: map[string]interface{}{"var1": "abc", "var2": int64(1)}, + } + got := sppb.ExecuteSqlRequest{} + if err := st.bindParams(&got); err != nil || !reflect.DeepEqual(got, want) { + t.Errorf("bind result: \n(%v, %v)\nwant\n(%v, %v)\n", got, err, want, nil) + } + // Verify type error reporting. + st.Params["var2"] = struct{}{} + wantErr := errBindParam("var2", struct{}{}, errEncoderUnsupportedType(struct{}{})) + if err := st.bindParams(&got); !reflect.DeepEqual(err, wantErr) { + t.Errorf("got unexpected error: %v, want: %v", err, wantErr) + } +} + +func TestNewStatement(t *testing.T) { + s := NewStatement("query") + if got, want := s.SQL, "query"; got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/vendor/cloud.google.com/go/spanner/timestampbound.go b/vendor/cloud.google.com/go/spanner/timestampbound.go new file mode 100644 index 000000000..068d96600 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/timestampbound.go @@ -0,0 +1,245 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "fmt" + "time" + + pbd "github.com/golang/protobuf/ptypes/duration" + pbt "github.com/golang/protobuf/ptypes/timestamp" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// timestampBoundType specifies the timestamp bound mode. +type timestampBoundType int + +const ( + strong timestampBoundType = iota // strong reads + exactStaleness // read with exact staleness + maxStaleness // read with max staleness + minReadTimestamp // read with min freshness + readTimestamp // read data at exact timestamp +) + +// TimestampBound defines how Cloud Spanner will choose a timestamp for a single +// read/query or read-only transaction. +// +// The types of timestamp bound are: +// +// - Strong (the default). +// - Bounded staleness. +// - Exact staleness. +// +// If the Cloud Spanner database to be read is geographically distributed, stale +// read-only transactions can execute more quickly than strong or read-write +// transactions, because they are able to execute far from the leader replica. +// +// Each type of timestamp bound is discussed in detail below. A TimestampBound +// can be specified when creating transactions, see the documentation of +// spanner.Client for an example. +// +// Strong reads +// +// Strong reads are guaranteed to see the effects of all transactions that have +// committed before the start of the read. Furthermore, all rows yielded by a +// single read are consistent with each other - if any part of the read +// observes a transaction, all parts of the read see the transaction. +// +// Strong reads are not repeatable: two consecutive strong read-only +// transactions might return inconsistent results if there are concurrent +// writes. If consistency across reads is required, the reads should be +// executed within a transaction or at an exact read timestamp. +// +// Use StrongRead() to create a bound of this type. +// +// Exact staleness +// +// These timestamp bounds execute reads at a user-specified timestamp. Reads at +// a timestamp are guaranteed to see a consistent prefix of the global +// transaction history: they observe modifications done by all transactions +// with a commit timestamp less than or equal to the read timestamp, and +// observe none of the modifications done by transactions with a larger commit +// timestamp. They will block until all conflicting transactions that may be +// assigned commit timestamps less than or equal to the read timestamp have +// finished. +// +// The timestamp can either be expressed as an absolute Cloud Spanner commit +// timestamp or a staleness relative to the current time. +// +// These modes do not require a "negotiation phase" to pick a timestamp. As a +// result, they execute slightly faster than the equivalent boundedly stale +// concurrency modes. On the other hand, boundedly stale reads usually return +// fresher results. +// +// Use ReadTimestamp() and ExactStaleness() to create a bound of this type. +// +// Bounded staleness +// +// Bounded staleness modes allow Cloud Spanner to pick the read timestamp, subject to +// a user-provided staleness bound. Cloud Spanner chooses the newest timestamp within +// the staleness bound that allows execution of the reads at the closest +// available replica without blocking. +// +// All rows yielded are consistent with each other -- if any part of the read +// observes a transaction, all parts of the read see the transaction. Boundedly +// stale reads are not repeatable: two stale reads, even if they use the same +// staleness bound, can execute at different timestamps and thus return +// inconsistent results. +// +// Boundedly stale reads execute in two phases: the first phase negotiates a +// timestamp among all replicas needed to serve the read. In the second phase, +// reads are executed at the negotiated timestamp. +// +// As a result of the two phase execution, bounded staleness reads are usually +// a little slower than comparable exact staleness reads. However, they are +// typically able to return fresher results, and are more likely to execute at +// the closest replica. +// +// Because the timestamp negotiation requires up-front knowledge of which rows +// will be read, it can only be used with single-use reads and single-use +// read-only transactions. +// +// Use MinReadTimestamp() and MaxStaleness() to create a bound of this type. +// +// Old read timestamps and garbage collection +// +// Cloud Spanner continuously garbage collects deleted and overwritten data in the +// background to reclaim storage space. This process is known as "version +// GC". By default, version GC reclaims versions after they are four hours +// old. Because of this, Cloud Spanner cannot perform reads at read timestamps more +// than four hours in the past. This restriction also applies to in-progress +// reads and/or SQL queries whose timestamp become too old while +// executing. Reads and SQL queries with too-old read timestamps fail with the +// error ErrorCode.FAILED_PRECONDITION. +type TimestampBound struct { + mode timestampBoundType + d time.Duration + t time.Time +} + +// StrongRead returns a TimestampBound that will perform reads and queries at a +// timestamp where all previously committed transactions are visible. +func StrongRead() TimestampBound { + return TimestampBound{mode: strong} +} + +// ExactStaleness returns a TimestampBound that will perform reads and queries +// at an exact staleness. +func ExactStaleness(d time.Duration) TimestampBound { + return TimestampBound{ + mode: exactStaleness, + d: d, + } +} + +// MaxStaleness returns a TimestampBound that will perform reads and queries at +// a time chosen to be at most "d" stale. +func MaxStaleness(d time.Duration) TimestampBound { + return TimestampBound{ + mode: maxStaleness, + d: d, + } +} + +// MinReadTimestamp returns a TimestampBound that bound that will perform reads +// and queries at a time chosen to be at least "t". +func MinReadTimestamp(t time.Time) TimestampBound { + return TimestampBound{ + mode: minReadTimestamp, + t: t, + } +} + +// ReadTimestamp returns a TimestampBound that will peform reads and queries at +// the given time. +func ReadTimestamp(t time.Time) TimestampBound { + return TimestampBound{ + mode: readTimestamp, + t: t, + } +} + +// String implements fmt.Stringer. +func (tb TimestampBound) String() string { + switch tb.mode { + case strong: + return fmt.Sprintf("(strong)") + case exactStaleness: + return fmt.Sprintf("(exactStaleness: %s)", tb.d) + case maxStaleness: + return fmt.Sprintf("(maxStaleness: %s)", tb.d) + case minReadTimestamp: + return fmt.Sprintf("(minReadTimestamp: %s)", tb.t) + case readTimestamp: + return fmt.Sprintf("(readTimestamp: %s)", tb.t) + default: + return fmt.Sprintf("{mode=%v, d=%v, t=%v}", tb.mode, tb.d, tb.t) + } +} + +// durationProto takes a time.Duration and converts it into pdb.Duration for +// calling gRPC APIs. +func durationProto(d time.Duration) *pbd.Duration { + n := d.Nanoseconds() + return &pbd.Duration{ + Seconds: n / int64(time.Second), + Nanos: int32(n % int64(time.Second)), + } +} + +// timestampProto takes a time.Time and converts it into pbt.Timestamp for calling +// gRPC APIs. +func timestampProto(t time.Time) *pbt.Timestamp { + return &pbt.Timestamp{ + Seconds: t.Unix(), + Nanos: int32(t.Nanosecond()), + } +} + +// buildTransactionOptionsReadOnly converts a spanner.TimestampBound into a sppb.TransactionOptions_ReadOnly +// transaction option, which is then used in transactional reads. +func buildTransactionOptionsReadOnly(tb TimestampBound, returnReadTimestamp bool) *sppb.TransactionOptions_ReadOnly { + pb := &sppb.TransactionOptions_ReadOnly{ + ReturnReadTimestamp: returnReadTimestamp, + } + switch tb.mode { + case strong: + pb.TimestampBound = &sppb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + } + case exactStaleness: + pb.TimestampBound = &sppb.TransactionOptions_ReadOnly_ExactStaleness{ + ExactStaleness: durationProto(tb.d), + } + case maxStaleness: + pb.TimestampBound = &sppb.TransactionOptions_ReadOnly_MaxStaleness{ + MaxStaleness: durationProto(tb.d), + } + case minReadTimestamp: + pb.TimestampBound = &sppb.TransactionOptions_ReadOnly_MinReadTimestamp{ + MinReadTimestamp: timestampProto(tb.t), + } + case readTimestamp: + pb.TimestampBound = &sppb.TransactionOptions_ReadOnly_ReadTimestamp{ + ReadTimestamp: timestampProto(tb.t), + } + default: + panic(fmt.Sprintf("buildTransactionOptionsReadOnly(%v,%v)", tb, returnReadTimestamp)) + } + return pb +} diff --git a/vendor/cloud.google.com/go/spanner/timestampbound_test.go b/vendor/cloud.google.com/go/spanner/timestampbound_test.go new file mode 100644 index 000000000..47fb481db --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/timestampbound_test.go @@ -0,0 +1,208 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "reflect" + "testing" + "time" + + pbd "github.com/golang/protobuf/ptypes/duration" + pbt "github.com/golang/protobuf/ptypes/timestamp" + + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +// Test generating TimestampBound for strong reads. +func TestStrong(t *testing.T) { + got := StrongRead() + want := TimestampBound{mode: strong} + if !reflect.DeepEqual(got, want) { + t.Errorf("Strong() = %v; want %v", got, want) + } +} + +// Test generating TimestampBound for reads with exact staleness. +func TestExactStaleness(t *testing.T) { + got := ExactStaleness(10 * time.Second) + want := TimestampBound{mode: exactStaleness, d: 10 * time.Second} + if !reflect.DeepEqual(got, want) { + t.Errorf("ExactStaleness(10*time.Second) = %v; want %v", got, want) + } +} + +// Test generating TimestampBound for reads with max staleness. +func TestMaxStaleness(t *testing.T) { + got := MaxStaleness(10 * time.Second) + want := TimestampBound{mode: maxStaleness, d: 10 * time.Second} + if !reflect.DeepEqual(got, want) { + t.Errorf("MaxStaleness(10*time.Second) = %v; want %v", got, want) + } +} + +// Test generating TimestampBound for reads with minimum freshness requirement. +func TestMinReadTimestamp(t *testing.T) { + ts := time.Now() + got := MinReadTimestamp(ts) + want := TimestampBound{mode: minReadTimestamp, t: ts} + if !reflect.DeepEqual(got, want) { + t.Errorf("MinReadTimestamp(%v) = %v; want %v", ts, got, want) + } +} + +// Test generating TimestampBound for reads requesting data at a exact timestamp. +func TestReadTimestamp(t *testing.T) { + ts := time.Now() + got := ReadTimestamp(ts) + want := TimestampBound{mode: readTimestamp, t: ts} + if !reflect.DeepEqual(got, want) { + t.Errorf("ReadTimestamp(%v) = %v; want %v", ts, got, want) + } +} + +// Test TimestampBound.String. +func TestTimestampBoundString(t *testing.T) { + ts := time.Unix(1136239445, 0).UTC() + var tests = []struct { + tb TimestampBound + want string + }{ + { + tb: TimestampBound{mode: strong}, + want: "(strong)", + }, + { + tb: TimestampBound{mode: exactStaleness, d: 10 * time.Second}, + want: "(exactStaleness: 10s)", + }, + { + tb: TimestampBound{mode: maxStaleness, d: 10 * time.Second}, + want: "(maxStaleness: 10s)", + }, + { + tb: TimestampBound{mode: minReadTimestamp, t: ts}, + want: "(minReadTimestamp: 2006-01-02 22:04:05 +0000 UTC)", + }, + { + tb: TimestampBound{mode: readTimestamp, t: ts}, + want: "(readTimestamp: 2006-01-02 22:04:05 +0000 UTC)", + }, + } + for _, test := range tests { + got := test.tb.String() + if got != test.want { + t.Errorf("%#v.String():\ngot %q\nwant %q", test.tb, got, test.want) + } + } +} + +// Test time.Duration to pdb.Duration conversion. +func TestDurationProto(t *testing.T) { + var tests = []struct { + d time.Duration + want pbd.Duration + }{ + {time.Duration(0), pbd.Duration{Seconds: 0, Nanos: 0}}, + {time.Second, pbd.Duration{Seconds: 1, Nanos: 0}}, + {time.Millisecond, pbd.Duration{Seconds: 0, Nanos: 1e6}}, + {15 * time.Nanosecond, pbd.Duration{Seconds: 0, Nanos: 15}}, + {42 * time.Hour, pbd.Duration{Seconds: 151200}}, + {-(1*time.Hour + 4*time.Millisecond), pbd.Duration{Seconds: -3600, Nanos: -4e6}}, + } + for _, test := range tests { + got := durationProto(test.d) + if !reflect.DeepEqual(got, &test.want) { + t.Errorf("durationProto(%v) = %v; want %v", test.d, got, test.want) + } + } +} + +// Test time.Time to pbt.Timestamp conversion. +func TestTimeProto(t *testing.T) { + var tests = []struct { + t time.Time + want pbt.Timestamp + }{ + {time.Unix(0, 0), pbt.Timestamp{}}, + {time.Unix(1136239445, 12345), pbt.Timestamp{Seconds: 1136239445, Nanos: 12345}}, + {time.Unix(-1000, 12345), pbt.Timestamp{Seconds: -1000, Nanos: 12345}}, + } + for _, test := range tests { + got := timestampProto(test.t) + if !reflect.DeepEqual(got, &test.want) { + t.Errorf("timestampProto(%v) = %v; want %v", test.t, got, test.want) + } + } +} + +// Test readonly transaction option builder. +func TestBuildTransactionOptionsReadOnly(t *testing.T) { + ts := time.Unix(1136239445, 12345) + var tests = []struct { + tb TimestampBound + ts bool + want sppb.TransactionOptions_ReadOnly + }{ + { + StrongRead(), false, + sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_Strong{ + Strong: true}, + ReturnReadTimestamp: false, + }, + }, + { + ExactStaleness(10 * time.Second), true, + sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_ExactStaleness{ + ExactStaleness: &pbd.Duration{Seconds: 10}}, + ReturnReadTimestamp: true, + }, + }, + { + MaxStaleness(10 * time.Second), true, + sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_MaxStaleness{ + MaxStaleness: &pbd.Duration{Seconds: 10}}, + ReturnReadTimestamp: true, + }, + }, + + { + MinReadTimestamp(ts), true, + sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_MinReadTimestamp{ + MinReadTimestamp: &pbt.Timestamp{Seconds: 1136239445, Nanos: 12345}}, + ReturnReadTimestamp: true, + }, + }, + { + ReadTimestamp(ts), true, + sppb.TransactionOptions_ReadOnly{ + TimestampBound: &sppb.TransactionOptions_ReadOnly_ReadTimestamp{ + ReadTimestamp: &pbt.Timestamp{Seconds: 1136239445, Nanos: 12345}}, + ReturnReadTimestamp: true, + }, + }, + } + for _, test := range tests { + got := buildTransactionOptionsReadOnly(test.tb, test.ts) + if !reflect.DeepEqual(got, &test.want) { + t.Errorf("buildTransactionOptionsReadOnly(%v,%v) = %v; want %v", test.tb, test.ts, got, test.want) + } + } +} diff --git a/vendor/cloud.google.com/go/spanner/transaction.go b/vendor/cloud.google.com/go/spanner/transaction.go new file mode 100644 index 000000000..ff1e96fe9 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/transaction.go @@ -0,0 +1,821 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "sync" + "time" + + "golang.org/x/net/context" + + "google.golang.org/api/iterator" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +// transactionID stores a transaction ID which uniquely identifies a transaction in Cloud Spanner. +type transactionID []byte + +// txReadEnv manages a read-transaction environment consisting of a session handle and a transaction selector. +type txReadEnv interface { + // acquire returns a read-transaction environment that can be used to perform a transactional read. + acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) + // sets the transaction's read timestamp + setTimestamp(time.Time) + // release should be called at the end of every transactional read to deal with session recycling. + release(error) +} + +// txReadOnly contains methods for doing transactional reads. +type txReadOnly struct { + // read-transaction environment for performing transactional read operations. + txReadEnv +} + +// errSessionClosed returns error for using a recycled/destroyed session +func errSessionClosed(sh *sessionHandle) error { + return spannerErrorf(codes.FailedPrecondition, + "session is already recycled / destroyed: session_id = %q, rpc_client = %v", sh.getID(), sh.getClient()) +} + +// Read returns a RowIterator for reading multiple rows from the database. +func (t *txReadOnly) Read(ctx context.Context, table string, keys KeySet, columns []string) *RowIterator { + // ReadUsingIndex will use primary index if an empty index name is provided. + return t.ReadUsingIndex(ctx, table, "", keys, columns) +} + +// ReadUsingIndex returns a RowIterator for reading multiple rows from the database +// using an index. +// +// Currently, this function can only read columns that are part of the index +// key, part of the primary key, or stored in the index due to a STORING clause +// in the index definition. +func (t *txReadOnly) ReadUsingIndex(ctx context.Context, table, index string, keys KeySet, columns []string) *RowIterator { + var ( + sh *sessionHandle + ts *sppb.TransactionSelector + err error + ) + kset, err := keys.keySetProto() + if err != nil { + return &RowIterator{err: err} + } + if sh, ts, err = t.acquire(ctx); err != nil { + return &RowIterator{err: err} + } + // Cloud Spanner will return "Session not found" on bad sessions. + sid, client := sh.getID(), sh.getClient() + if sid == "" || client == nil { + // Might happen if transaction is closed in the middle of a API call. + return &RowIterator{err: errSessionClosed(sh)} + } + return stream( + contextWithOutgoingMetadata(ctx, sh.getMetadata()), + func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + return client.StreamingRead(ctx, + &sppb.ReadRequest{ + Session: sid, + Transaction: ts, + Table: table, + Index: index, + Columns: columns, + KeySet: kset, + ResumeToken: resumeToken, + }) + }, + t.setTimestamp, + t.release, + ) +} + +// errRowNotFound returns error for not being able to read the row identified by key. +func errRowNotFound(table string, key Key) error { + return spannerErrorf(codes.NotFound, "row not found(Table: %v, PrimaryKey: %v)", table, key) +} + +// ReadRow reads a single row from the database. +// +// If no row is present with the given key, then ReadRow returns an error where +// spanner.ErrCode(err) is codes.NotFound. +func (t *txReadOnly) ReadRow(ctx context.Context, table string, key Key, columns []string) (*Row, error) { + iter := t.Read(ctx, table, key, columns) + defer iter.Stop() + row, err := iter.Next() + switch err { + case iterator.Done: + return nil, errRowNotFound(table, key) + case nil: + return row, nil + default: + return nil, err + } +} + +// Query executes a query against the database. It returns a RowIterator +// for retrieving the resulting rows. +func (t *txReadOnly) Query(ctx context.Context, statement Statement) *RowIterator { + var ( + sh *sessionHandle + ts *sppb.TransactionSelector + err error + ) + if sh, ts, err = t.acquire(ctx); err != nil { + return &RowIterator{err: err} + } + // Cloud Spanner will return "Session not found" on bad sessions. + sid, client := sh.getID(), sh.getClient() + if sid == "" || client == nil { + // Might happen if transaction is closed in the middle of a API call. + return &RowIterator{err: errSessionClosed(sh)} + } + req := &sppb.ExecuteSqlRequest{ + Session: sid, + Transaction: ts, + Sql: statement.SQL, + } + if err := statement.bindParams(req); err != nil { + return &RowIterator{err: err} + } + return stream( + contextWithOutgoingMetadata(ctx, sh.getMetadata()), + func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + req.ResumeToken = resumeToken + return client.ExecuteStreamingSql(ctx, req) + }, + t.setTimestamp, + t.release) +} + +// txState is the status of a transaction. +type txState int + +const ( + // transaction is new, waiting to be initialized. + txNew txState = iota + // transaction is being initialized. + txInit + // transaction is active and can perform read/write. + txActive + // transaction is closed, cannot be used anymore. + txClosed +) + +// errRtsUnavailable returns error for read transaction's read timestamp being unavailable. +func errRtsUnavailable() error { + return spannerErrorf(codes.Internal, "read timestamp is unavailable") +} + +// errTxNotInitialized returns error for using an uninitialized transaction. +func errTxNotInitialized() error { + return spannerErrorf(codes.InvalidArgument, "cannot use a uninitialized transaction") +} + +// errTxClosed returns error for using a closed transaction. +func errTxClosed() error { + return spannerErrorf(codes.InvalidArgument, "cannot use a closed transaction") +} + +// errUnexpectedTxState returns error for transaction enters an unexpected state. +func errUnexpectedTxState(ts txState) error { + return spannerErrorf(codes.FailedPrecondition, "unexpected transaction state: %v", ts) +} + +// ReadOnlyTransaction provides a snapshot transaction with guaranteed +// consistency across reads, but does not allow writes. Read-only +// transactions can be configured to read at timestamps in the past. +// +// Read-only transactions do not take locks. Instead, they work by choosing a +// Cloud Spanner timestamp, then executing all reads at that timestamp. Since they do +// not acquire locks, they do not block concurrent read-write transactions. +// +// Unlike locking read-write transactions, read-only transactions never +// abort. They can fail if the chosen read timestamp is garbage collected; +// however, the default garbage collection policy is generous enough that most +// applications do not need to worry about this in practice. See the +// documentation of TimestampBound for more details. +// +// A ReadOnlyTransaction consumes resources on the server until Close() is +// called. +type ReadOnlyTransaction struct { + // txReadOnly contains methods for performing transactional reads. + txReadOnly + + // singleUse indicates that the transaction can be used for only one read. + singleUse bool + + // sp is the session pool for allocating a session to execute the read-only transaction. It is set only once during initialization of the ReadOnlyTransaction. + sp *sessionPool + // mu protects concurrent access to the internal states of ReadOnlyTransaction. + mu sync.Mutex + // tx is the transaction ID in Cloud Spanner that uniquely identifies the ReadOnlyTransaction. + tx transactionID + // txReadyOrClosed is for broadcasting that transaction ID has been returned by Cloud Spanner or that transaction is closed. + txReadyOrClosed chan struct{} + // state is the current transaction status of the ReadOnly transaction. + state txState + // sh is the sessionHandle allocated from sp. + sh *sessionHandle + // rts is the read timestamp returned by transactional reads. + rts time.Time + // tb is the read staleness bound specification for transactional reads. + tb TimestampBound +} + +// errTxInitTimeout returns error for timeout in waiting for initialization of the transaction. +func errTxInitTimeout() error { + return spannerErrorf(codes.Canceled, "timeout/context canceled in waiting for transaction's initialization") +} + +// getTimestampBound returns the read staleness bound specified for the ReadOnlyTransaction. +func (t *ReadOnlyTransaction) getTimestampBound() TimestampBound { + t.mu.Lock() + defer t.mu.Unlock() + return t.tb +} + +// begin starts a snapshot read-only Transaction on Cloud Spanner. +func (t *ReadOnlyTransaction) begin(ctx context.Context) error { + var ( + locked bool + tx transactionID + rts time.Time + sh *sessionHandle + err error + ) + defer func() { + if !locked { + t.mu.Lock() + // Not necessary, just to make it clear that t.mu is being held when locked == true. + locked = true + } + if t.state != txClosed { + // Signal other initialization routines. + close(t.txReadyOrClosed) + t.txReadyOrClosed = make(chan struct{}) + } + t.mu.Unlock() + if err != nil && sh != nil { + // Got a valid session handle, but failed to initalize transaction on Cloud Spanner. + if shouldDropSession(err) { + sh.destroy() + } + // If sh.destroy was already executed, this becomes a noop. + sh.recycle() + } + }() + sh, err = t.sp.take(ctx) + if err != nil { + return err + } + err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error { + res, e := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), + }, + }, + }) + if e != nil { + return e + } + tx = res.Id + if res.ReadTimestamp != nil { + rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) + } + return nil + }) + t.mu.Lock() + locked = true // defer function will be executed with t.mu being held. + if t.state == txClosed { // During the execution of t.begin(), t.Close() was invoked. + return errSessionClosed(sh) + } + // If begin() fails, this allows other queries to take over the initialization. + t.tx = nil + if err == nil { + t.tx = tx + t.rts = rts + t.sh = sh + // State transite to txActive. + t.state = txActive + } + return err +} + +// acquire implements txReadEnv.acquire. +func (t *ReadOnlyTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) { + if err := checkNestedTxn(ctx); err != nil { + return nil, nil, err + } + if t.singleUse { + return t.acquireSingleUse(ctx) + } + return t.acquireMultiUse(ctx) +} + +func (t *ReadOnlyTransaction) acquireSingleUse(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) { + t.mu.Lock() + defer t.mu.Unlock() + switch t.state { + case txClosed: + // A closed single-use transaction can never be reused. + return nil, nil, errTxClosed() + case txNew: + t.state = txClosed + ts := &sppb.TransactionSelector{ + Selector: &sppb.TransactionSelector_SingleUse{ + SingleUse: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: buildTransactionOptionsReadOnly(t.tb, true), + }, + }, + }, + } + sh, err := t.sp.take(ctx) + if err != nil { + return nil, nil, err + } + // Install session handle into t, which can be used for readonly operations later. + t.sh = sh + return sh, ts, nil + } + us := t.state + // SingleUse transaction should only be in either txNew state or txClosed state. + return nil, nil, errUnexpectedTxState(us) +} + +func (t *ReadOnlyTransaction) acquireMultiUse(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) { + for { + t.mu.Lock() + switch t.state { + case txClosed: + t.mu.Unlock() + return nil, nil, errTxClosed() + case txNew: + // State transit to txInit so that no further TimestampBound change is accepted. + t.state = txInit + t.mu.Unlock() + continue + case txInit: + if t.tx != nil { + // Wait for a transaction ID to become ready. + txReadyOrClosed := t.txReadyOrClosed + t.mu.Unlock() + select { + case <-txReadyOrClosed: + // Need to check transaction state again. + continue + case <-ctx.Done(): + // The waiting for initialization is timeout, return error directly. + return nil, nil, errTxInitTimeout() + } + } + // Take the ownership of initializing the transaction. + t.tx = transactionID{} + t.mu.Unlock() + // Begin a read-only transaction. + // TODO: consider adding a transaction option which allow queries to initiate transactions by themselves. Note that this option might not be + // always good because the ID of the new transaction won't be ready till the query returns some data or completes. + if err := t.begin(ctx); err != nil { + return nil, nil, err + } + // If t.begin() succeeded, t.state should have been changed to txActive, so we can just continue here. + continue + case txActive: + sh := t.sh + ts := &sppb.TransactionSelector{ + Selector: &sppb.TransactionSelector_Id{ + Id: t.tx, + }, + } + t.mu.Unlock() + return sh, ts, nil + } + state := t.state + t.mu.Unlock() + return nil, nil, errUnexpectedTxState(state) + } +} + +func (t *ReadOnlyTransaction) setTimestamp(ts time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + if t.rts.IsZero() { + t.rts = ts + } +} + +// release implements txReadEnv.release. +func (t *ReadOnlyTransaction) release(err error) { + t.mu.Lock() + sh := t.sh + t.mu.Unlock() + if sh != nil { // sh could be nil if t.acquire() fails. + if shouldDropSession(err) { + sh.destroy() + } + if t.singleUse { + // If session handle is already destroyed, this becomes a noop. + sh.recycle() + } + } +} + +// Close closes a ReadOnlyTransaction, the transaction cannot perform any reads after being closed. +func (t *ReadOnlyTransaction) Close() { + if t.singleUse { + return + } + t.mu.Lock() + if t.state != txClosed { + t.state = txClosed + close(t.txReadyOrClosed) + } + sh := t.sh + t.mu.Unlock() + if sh == nil { + return + } + // If session handle is already destroyed, this becomes a noop. + // If there are still active queries and if the recycled session is reused before they complete, Cloud Spanner will cancel them + // on behalf of the new transaction on the session. + if sh != nil { + sh.recycle() + } +} + +// Timestamp returns the timestamp chosen to perform reads and +// queries in this transaction. The value can only be read after some +// read or query has either returned some data or completed without +// returning any data. +func (t *ReadOnlyTransaction) Timestamp() (time.Time, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.rts.IsZero() { + return t.rts, errRtsUnavailable() + } + return t.rts, nil +} + +// WithTimestampBound specifies the TimestampBound to use for read or query. +// This can only be used before the first read or query is invoked. Note: +// bounded staleness is not available with general ReadOnlyTransactions; use a +// single-use ReadOnlyTransaction instead. +// +// The returned value is the ReadOnlyTransaction so calls can be chained. +func (t *ReadOnlyTransaction) WithTimestampBound(tb TimestampBound) *ReadOnlyTransaction { + t.mu.Lock() + defer t.mu.Unlock() + if t.state == txNew { + // Only allow to set TimestampBound before the first query. + t.tb = tb + } + return t +} + +// ReadWriteTransaction provides a locking read-write transaction. +// +// This type of transaction is the only way to write data into Cloud Spanner; +// (*Client).Apply and (*Client).ApplyAtLeastOnce use transactions +// internally. These transactions rely on pessimistic locking and, if +// necessary, two-phase commit. Locking read-write transactions may abort, +// requiring the application to retry. However, the interface exposed by +// (*Client).ReadWriteTransaction eliminates the need for applications to write +// retry loops explicitly. +// +// Locking transactions may be used to atomically read-modify-write data +// anywhere in a database. This type of transaction is externally consistent. +// +// Clients should attempt to minimize the amount of time a transaction is +// active. Faster transactions commit with higher probability and cause less +// contention. Cloud Spanner attempts to keep read locks active as long as the +// transaction continues to do reads. Long periods of inactivity at the client +// may cause Cloud Spanner to release a transaction's locks and abort it. +// +// Reads performed within a transaction acquire locks on the data being +// read. Writes can only be done at commit time, after all reads have been +// completed. Conceptually, a read-write transaction consists of zero or more +// reads or SQL queries followed by a commit. +// +// See (*Client).ReadWriteTransaction for an example. +// +// Semantics +// +// Cloud Spanner can commit the transaction if all read locks it acquired are still +// valid at commit time, and it is able to acquire write locks for all +// writes. Cloud Spanner can abort the transaction for any reason. If a commit +// attempt returns ABORTED, Cloud Spanner guarantees that the transaction has not +// modified any user data in Cloud Spanner. +// +// Unless the transaction commits, Cloud Spanner makes no guarantees about how long +// the transaction's locks were held for. It is an error to use Cloud Spanner locks +// for any sort of mutual exclusion other than between Cloud Spanner transactions +// themselves. +// +// Aborted transactions +// +// Application code does not need to retry explicitly; RunInTransaction will +// automatically retry a transaction if an attempt results in an abort. The +// lock priority of a transaction increases after each prior aborted +// transaction, meaning that the next attempt has a slightly better chance of +// success than before. +// +// Under some circumstances (e.g., many transactions attempting to modify the +// same row(s)), a transaction can abort many times in a short period before +// successfully committing. Thus, it is not a good idea to cap the number of +// retries a transaction can attempt; instead, it is better to limit the total +// amount of wall time spent retrying. +// +// Idle transactions +// +// A transaction is considered idle if it has no outstanding reads or SQL +// queries and has not started a read or SQL query within the last 10 +// seconds. Idle transactions can be aborted by Cloud Spanner so that they don't hold +// on to locks indefinitely. In that case, the commit will fail with error +// ABORTED. +// +// If this behavior is undesirable, periodically executing a simple SQL query +// in the transaction (e.g., SELECT 1) prevents the transaction from becoming +// idle. +type ReadWriteTransaction struct { + // txReadOnly contains methods for performing transactional reads. + txReadOnly + // sh is the sessionHandle allocated from sp. It is set only once during the initialization of ReadWriteTransaction. + sh *sessionHandle + // tx is the transaction ID in Cloud Spanner that uniquely identifies the ReadWriteTransaction. + // It is set only once in ReadWriteTransaction.begin() during the initialization of ReadWriteTransaction. + tx transactionID + // mu protects concurrent access to the internal states of ReadWriteTransaction. + mu sync.Mutex + // state is the current transaction status of the read-write transaction. + state txState + // wb is the set of buffered mutations waiting to be commited. + wb []*Mutation +} + +// BufferWrite adds a list of mutations to the set of updates that will be +// applied when the transaction is committed. It does not actually apply the +// write until the transaction is committed, so the operation does not +// block. The effects of the write won't be visible to any reads (including +// reads done in the same transaction) until the transaction commits. +// +// See the example for Client.ReadWriteTransaction. +func (t *ReadWriteTransaction) BufferWrite(ms []*Mutation) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.state == txClosed { + return errTxClosed() + } + if t.state != txActive { + return errUnexpectedTxState(t.state) + } + t.wb = append(t.wb, ms...) + return nil +} + +// acquire implements txReadEnv.acquire. +func (t *ReadWriteTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) { + ts := &sppb.TransactionSelector{ + Selector: &sppb.TransactionSelector_Id{ + Id: t.tx, + }, + } + t.mu.Lock() + defer t.mu.Unlock() + switch t.state { + case txClosed: + return nil, nil, errTxClosed() + case txActive: + return t.sh, ts, nil + } + return nil, nil, errUnexpectedTxState(t.state) +} + +// release implements txReadEnv.release. +func (t *ReadWriteTransaction) release(err error) { + t.mu.Lock() + sh := t.sh + t.mu.Unlock() + if sh != nil && shouldDropSession(err) { + sh.destroy() + } +} + +func beginTransaction(ctx context.Context, sid string, client sppb.SpannerClient) (transactionID, error) { + var tx transactionID + err := runRetryable(ctx, func(ctx context.Context) error { + res, e := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + Session: sid, + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadWrite_{ + ReadWrite: &sppb.TransactionOptions_ReadWrite{}, + }, + }, + }) + if e != nil { + return e + } + tx = res.Id + return nil + }) + if err != nil { + return nil, err + } + return tx, nil +} + +// begin starts a read-write transacton on Cloud Spanner, it is always called before any of the public APIs. +func (t *ReadWriteTransaction) begin(ctx context.Context) error { + if t.tx != nil { + t.state = txActive + return nil + } + tx, err := beginTransaction(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), t.sh.getID(), t.sh.getClient()) + if err == nil { + t.tx = tx + t.state = txActive + return nil + } + if shouldDropSession(err) { + t.sh.destroy() + } + return err +} + +// commit tries to commit a readwrite transaction to Cloud Spanner. It also returns the commit timestamp for the transactions. +func (t *ReadWriteTransaction) commit(ctx context.Context) (time.Time, error) { + var ts time.Time + t.mu.Lock() + t.state = txClosed // No futher operations after commit. + mPb, err := mutationsProto(t.wb) + t.mu.Unlock() + if err != nil { + return ts, err + } + // In case that sessionHandle was destroyed but transaction body fails to report it. + sid, client := t.sh.getID(), t.sh.getClient() + if sid == "" || client == nil { + return ts, errSessionClosed(t.sh) + } + err = runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { + var trailer metadata.MD + res, e := client.Commit(ctx, &sppb.CommitRequest{ + Session: sid, + Transaction: &sppb.CommitRequest_TransactionId{ + TransactionId: t.tx, + }, + Mutations: mPb, + }, grpc.Trailer(&trailer)) + if e != nil { + return toSpannerErrorWithMetadata(e, trailer) + } + if tstamp := res.GetCommitTimestamp(); tstamp != nil { + ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + } + return nil + }) + if shouldDropSession(err) { + t.sh.destroy() + } + return ts, err +} + +// rollback is called when a commit is aborted or the transaction body runs into error. +func (t *ReadWriteTransaction) rollback(ctx context.Context) { + t.mu.Lock() + // Forbid further operations on rollbacked transaction. + t.state = txClosed + t.mu.Unlock() + // In case that sessionHandle was destroyed but transaction body fails to report it. + sid, client := t.sh.getID(), t.sh.getClient() + if sid == "" || client == nil { + return + } + err := runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { + _, e := client.Rollback(ctx, &sppb.RollbackRequest{ + Session: sid, + TransactionId: t.tx, + }) + return e + }) + if shouldDropSession(err) { + t.sh.destroy() + } + return +} + +// runInTransaction executes f under a read-write transaction context. +func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(context.Context, *ReadWriteTransaction) error) (time.Time, error) { + var ( + ts time.Time + err error + ) + if err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t); err == nil { + // Try to commit if transaction body returns no error. + ts, err = t.commit(ctx) + } + if err != nil { + if isAbortErr(err) { + // Retry the transaction using the same session on ABORT error. + // Cloud Spanner will create the new transaction with the previous one's wound-wait priority. + err = errRetry(err) + return ts, err + } + // Not going to commit, according to API spec, should rollback the transaction. + t.rollback(ctx) + return ts, err + } + // err == nil, return commit timestamp. + return ts, err +} + +// writeOnlyTransaction provides the most efficient way of doing write-only transactions. It essentially does blind writes to Cloud Spanner. +type writeOnlyTransaction struct { + // sp is the session pool which writeOnlyTransaction uses to get Cloud Spanner sessions for blind writes. + sp *sessionPool +} + +// applyAtLeastOnce commits a list of mutations to Cloud Spanner for at least once, unless one of the following happends: +// 1) Context is timeout. +// 2) An unretryable error(e.g. database not found) occurs. +// 3) There is a malformed Mutation object. +func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Mutation) (time.Time, error) { + var ( + ts time.Time + sh *sessionHandle + ) + mPb, err := mutationsProto(ms) + if err != nil { + // Malformed mutation found, just return the error. + return ts, err + } + err = runRetryable(ctx, func(ct context.Context) error { + var e error + var trailers metadata.MD + if sh == nil || sh.getID() == "" || sh.getClient() == nil { + // No usable session for doing the commit, take one from pool. + sh, e = t.sp.take(ctx) + if e != nil { + // sessionPool.Take already retries for session creations/retrivals. + return e + } + } + res, e := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.CommitRequest{ + Session: sh.getID(), + Transaction: &sppb.CommitRequest_SingleUseTransaction{ + SingleUseTransaction: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadWrite_{ + ReadWrite: &sppb.TransactionOptions_ReadWrite{}, + }, + }, + }, + Mutations: mPb, + }, grpc.Trailer(&trailers)) + if e != nil { + if isAbortErr(e) { + // Mask ABORT error as retryable, because aborted transactions are allowed to be retried. + return errRetry(toSpannerErrorWithMetadata(e, trailers)) + } + if shouldDropSession(e) { + // Discard the bad session. + sh.destroy() + } + return e + } + if tstamp := res.GetCommitTimestamp(); tstamp != nil { + ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + } + return nil + }) + if sh != nil { + sh.recycle() + } + return ts, err +} + +// isAbortedErr returns true if the error indicates that an gRPC call is aborted on the server side. +func isAbortErr(err error) bool { + if err == nil { + return false + } + if ErrCode(err) == codes.Aborted { + return true + } + return false +} diff --git a/vendor/cloud.google.com/go/spanner/value.go b/vendor/cloud.google.com/go/spanner/value.go new file mode 100644 index 000000000..c09713264 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/value.go @@ -0,0 +1,1244 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "encoding/base64" + "fmt" + "math" + "reflect" + "strconv" + "strings" + "time" + + "cloud.google.com/go/civil" + "cloud.google.com/go/internal/fields" + proto "github.com/golang/protobuf/proto" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +// NullInt64 represents a Cloud Spanner INT64 that may be NULL. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL. +} + +// String implements Stringer.String for NullInt64 +func (n NullInt64) String() string { + if !n.Valid { + return fmt.Sprintf("%v", "<null>") + } + return fmt.Sprintf("%v", n.Int64) +} + +// NullString represents a Cloud Spanner STRING that may be NULL. +type NullString struct { + StringVal string + Valid bool // Valid is true if StringVal is not NULL. +} + +// String implements Stringer.String for NullString +func (n NullString) String() string { + if !n.Valid { + return fmt.Sprintf("%v", "<null>") + } + return fmt.Sprintf("%q", n.StringVal) +} + +// NullFloat64 represents a Cloud Spanner FLOAT64 that may be NULL. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL. +} + +// String implements Stringer.String for NullFloat64 +func (n NullFloat64) String() string { + if !n.Valid { + return fmt.Sprintf("%v", "<null>") + } + return fmt.Sprintf("%v", n.Float64) +} + +// NullBool represents a Cloud Spanner BOOL that may be NULL. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL. +} + +// String implements Stringer.String for NullBool +func (n NullBool) String() string { + if !n.Valid { + return fmt.Sprintf("%v", "<null>") + } + return fmt.Sprintf("%v", n.Bool) +} + +// NullTime represents a Cloud Spanner TIMESTAMP that may be null. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL. +} + +// String implements Stringer.String for NullTime +func (n NullTime) String() string { + if !n.Valid { + return fmt.Sprintf("%s", "<null>") + } + return fmt.Sprintf("%q", n.Time.Format(time.RFC3339Nano)) +} + +// NullDate represents a Cloud Spanner DATE that may be null. +type NullDate struct { + Date civil.Date + Valid bool // Valid is true if Date is not NULL. +} + +// String implements Stringer.String for NullDate +func (n NullDate) String() string { + if !n.Valid { + return fmt.Sprintf("%s", "<null>") + } + return fmt.Sprintf("%q", n.Date) +} + +// NullRow represents a Cloud Spanner STRUCT that may be NULL. +// See also the document for Row. +// Note that NullRow is not a valid Cloud Spanner column Type. +type NullRow struct { + Row Row + Valid bool // Valid is true if Row is not NULL. +} + +// GenericColumnValue represents the generic encoded value and type of the +// column. See google.spanner.v1.ResultSet proto for details. This can be +// useful for proxying query results when the result types are not known in +// advance. +type GenericColumnValue struct { + Type *sppb.Type + Value *proto3.Value +} + +// Decode decodes a GenericColumnValue. The ptr argument should be a pointer +// to a Go value that can accept v. +func (v GenericColumnValue) Decode(ptr interface{}) error { + return decodeValue(v.Value, v.Type, ptr) +} + +// NewGenericColumnValue creates a GenericColumnValue from Go value that is +// valid for Cloud Spanner. +func newGenericColumnValue(v interface{}) (*GenericColumnValue, error) { + value, typ, err := encodeValue(v) + if err != nil { + return nil, err + } + return &GenericColumnValue{Value: value, Type: typ}, nil +} + +// errTypeMismatch returns error for destination not having a compatible type +// with source Cloud Spanner type. +func errTypeMismatch(srcType sppb.TypeCode, isArray bool, dst interface{}) error { + usage := srcType.String() + if isArray { + usage = fmt.Sprintf("%v[%v]", sppb.TypeCode_ARRAY, srcType) + } + return spannerErrorf(codes.InvalidArgument, "type %T cannot be used for decoding %v", dst, usage) +} + +// errNilSpannerType returns error for nil Cloud Spanner type in decoding. +func errNilSpannerType() error { + return spannerErrorf(codes.FailedPrecondition, "unexpected nil Cloud Spanner data type in decoding") +} + +// errNilSrc returns error for decoding from nil proto value. +func errNilSrc() error { + return spannerErrorf(codes.FailedPrecondition, "unexpected nil Cloud Spanner value in decoding") +} + +// errNilDst returns error for decoding into nil interface{}. +func errNilDst(dst interface{}) error { + return spannerErrorf(codes.InvalidArgument, "cannot decode into nil type %T", dst) +} + +// errNilArrElemType returns error for input Cloud Spanner data type being a array but without a +// non-nil array element type. +func errNilArrElemType(t *sppb.Type) error { + return spannerErrorf(codes.FailedPrecondition, "array type %v is with nil array element type", t) +} + +// errDstNotForNull returns error for decoding a SQL NULL value into a destination which doesn't +// support NULL values. +func errDstNotForNull(dst interface{}) error { + return spannerErrorf(codes.InvalidArgument, "destination %T cannot support NULL SQL values", dst) +} + +// errBadEncoding returns error for decoding wrongly encoded types. +func errBadEncoding(v *proto3.Value, err error) error { + return spannerErrorf(codes.FailedPrecondition, "%v wasn't correctly encoded: <%v>", v, err) +} + +func parseNullTime(v *proto3.Value, p *NullTime, code sppb.TypeCode, isNull bool) error { + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_TIMESTAMP { + return errTypeMismatch(code, false, p) + } + if isNull { + *p = NullTime{} + return nil + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := time.Parse(time.RFC3339Nano, x) + if err != nil { + return errBadEncoding(v, err) + } + p.Valid = true + p.Time = y + return nil +} + +// decodeValue decodes a protobuf Value into a pointer to a Go value, as +// specified by sppb.Type. +func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { + if v == nil { + return errNilSrc() + } + if t == nil { + return errNilSpannerType() + } + code := t.Code + acode := sppb.TypeCode_TYPE_CODE_UNSPECIFIED + if code == sppb.TypeCode_ARRAY { + if t.ArrayElementType == nil { + return errNilArrElemType(t) + } + acode = t.ArrayElementType.Code + } + typeErr := errTypeMismatch(code, false, ptr) + if code == sppb.TypeCode_ARRAY { + typeErr = errTypeMismatch(acode, true, ptr) + } + nullErr := errDstNotForNull(ptr) + _, isNull := v.Kind.(*proto3.Value_NullValue) + + // Do the decoding based on the type of ptr. + switch p := ptr.(type) { + case nil: + return errNilDst(nil) + case *string: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_STRING { + return typeErr + } + if isNull { + return nullErr + } + x, err := getStringValue(v) + if err != nil { + return err + } + *p = x + case *NullString: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_STRING { + return typeErr + } + if isNull { + *p = NullString{} + break + } + x, err := getStringValue(v) + if err != nil { + return err + } + p.Valid = true + p.StringVal = x + case *[]NullString: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_STRING { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeStringArray(x) + if err != nil { + return err + } + *p = y + case *[]byte: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_BYTES { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := base64.StdEncoding.DecodeString(x) + if err != nil { + return errBadEncoding(v, err) + } + *p = y + case *[][]byte: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_BYTES { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeByteArray(x) + if err != nil { + return err + } + *p = y + case *int64: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_INT64 { + return typeErr + } + if isNull { + return nullErr + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := strconv.ParseInt(x, 10, 64) + if err != nil { + return errBadEncoding(v, err) + } + *p = y + case *NullInt64: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_INT64 { + return typeErr + } + if isNull { + *p = NullInt64{} + break + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := strconv.ParseInt(x, 10, 64) + if err != nil { + return errBadEncoding(v, err) + } + p.Valid = true + p.Int64 = y + case *[]NullInt64: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_INT64 { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeIntArray(x) + if err != nil { + return err + } + *p = y + case *bool: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_BOOL { + return typeErr + } + if isNull { + return nullErr + } + x, err := getBoolValue(v) + if err != nil { + return err + } + *p = x + case *NullBool: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_BOOL { + return typeErr + } + if isNull { + *p = NullBool{} + break + } + x, err := getBoolValue(v) + if err != nil { + return err + } + p.Valid = true + p.Bool = x + case *[]NullBool: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_BOOL { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeBoolArray(x) + if err != nil { + return err + } + *p = y + case *float64: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_FLOAT64 { + return typeErr + } + if isNull { + return nullErr + } + x, err := getFloat64Value(v) + if err != nil { + return err + } + *p = x + case *NullFloat64: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_FLOAT64 { + return typeErr + } + if isNull { + *p = NullFloat64{} + break + } + x, err := getFloat64Value(v) + if err != nil { + return err + } + p.Valid = true + p.Float64 = x + case *[]NullFloat64: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_FLOAT64 { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeFloat64Array(x) + if err != nil { + return err + } + *p = y + case *time.Time: + var nt NullTime + if isNull { + return nullErr + } + err := parseNullTime(v, &nt, code, isNull) + if err != nil { + return nil + } + *p = nt.Time + case *NullTime: + err := parseNullTime(v, p, code, isNull) + if err != nil { + return err + } + case *[]NullTime: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_TIMESTAMP { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeTimeArray(x) + if err != nil { + return err + } + *p = y + case *civil.Date: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_DATE { + return typeErr + } + if isNull { + return nullErr + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := civil.ParseDate(x) + if err != nil { + return errBadEncoding(v, err) + } + *p = y + case *NullDate: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_DATE { + return typeErr + } + if isNull { + *p = NullDate{} + break + } + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := civil.ParseDate(x) + if err != nil { + return errBadEncoding(v, err) + } + p.Valid = true + p.Date = y + case *[]NullDate: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_DATE { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeDateArray(x) + if err != nil { + return err + } + *p = y + case *[]NullRow: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_STRUCT { + return typeErr + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeRowArray(t.ArrayElementType.StructType, x) + if err != nil { + return err + } + *p = y + case *GenericColumnValue: + *p = GenericColumnValue{ + // Deep clone to ensure subsequent changes to t or v + // don't affect our decoded value. + Type: proto.Clone(t).(*sppb.Type), + Value: proto.Clone(v).(*proto3.Value), + } + default: + // Check if the proto encoding is for an array of structs. + if !(code == sppb.TypeCode_ARRAY && acode == sppb.TypeCode_STRUCT) { + return typeErr + } + vp := reflect.ValueOf(p) + if !vp.IsValid() { + return errNilDst(p) + } + if !isPtrStructPtrSlice(vp.Type()) { + // The container is not a pointer to a struct pointer slice. + return typeErr + } + // Only use reflection for nil detection on slow path. + // Also, IsNil panics on many types, so check it after the type check. + if vp.IsNil() { + return errNilDst(p) + } + if isNull { + // The proto Value is encoding NULL, set the pointer to struct + // slice to nil as well. + vp.Elem().Set(reflect.Zero(vp.Elem().Type())) + break + } + x, err := getListValue(v) + if err != nil { + return err + } + if err = decodeStructArray(t.ArrayElementType.StructType, x, p); err != nil { + return err + } + } + return nil +} + +// errSrvVal returns an error for getting a wrong source protobuf value in decoding. +func errSrcVal(v *proto3.Value, want string) error { + return spannerErrorf(codes.FailedPrecondition, "cannot use %v(Kind: %T) as %s Value", + v, v.GetKind(), want) +} + +// getStringValue returns the string value encoded in proto3.Value v whose +// kind is proto3.Value_StringValue. +func getStringValue(v *proto3.Value) (string, error) { + if x, ok := v.GetKind().(*proto3.Value_StringValue); ok && x != nil { + return x.StringValue, nil + } + return "", errSrcVal(v, "String") +} + +// getBoolValue returns the bool value encoded in proto3.Value v whose +// kind is proto3.Value_BoolValue. +func getBoolValue(v *proto3.Value) (bool, error) { + if x, ok := v.GetKind().(*proto3.Value_BoolValue); ok && x != nil { + return x.BoolValue, nil + } + return false, errSrcVal(v, "Bool") +} + +// getListValue returns the proto3.ListValue contained in proto3.Value v whose +// kind is proto3.Value_ListValue. +func getListValue(v *proto3.Value) (*proto3.ListValue, error) { + if x, ok := v.GetKind().(*proto3.Value_ListValue); ok && x != nil { + return x.ListValue, nil + } + return nil, errSrcVal(v, "List") +} + +// errUnexpectedNumStr returns error for decoder getting a unexpected string for +// representing special float values. +func errUnexpectedNumStr(s string) error { + return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for number", s) +} + +// getFloat64Value returns the float64 value encoded in proto3.Value v whose +// kind is proto3.Value_NumberValue / proto3.Value_StringValue. +// Cloud Spanner uses string to encode NaN, Infinity and -Infinity. +func getFloat64Value(v *proto3.Value) (float64, error) { + switch x := v.GetKind().(type) { + case *proto3.Value_NumberValue: + if x == nil { + break + } + return x.NumberValue, nil + case *proto3.Value_StringValue: + if x == nil { + break + } + switch x.StringValue { + case "NaN": + return math.NaN(), nil + case "Infinity": + return math.Inf(1), nil + case "-Infinity": + return math.Inf(-1), nil + default: + return 0, errUnexpectedNumStr(x.StringValue) + } + } + return 0, errSrcVal(v, "Number") +} + +// errNilListValue returns error for unexpected nil ListValue in decoding Cloud Spanner ARRAYs. +func errNilListValue(sqlType string) error { + return spannerErrorf(codes.FailedPrecondition, "unexpected nil ListValue in decoding %v array", sqlType) +} + +// errDecodeArrayElement returns error for failure in decoding single array element. +func errDecodeArrayElement(i int, v proto.Message, sqlType string, err error) error { + se, ok := toSpannerError(err).(*Error) + if !ok { + return spannerErrorf(codes.Unknown, + "cannot decode %v(array element %v) as %v, error = <%v>", v, i, sqlType, err) + } + se.decorate(fmt.Sprintf("cannot decode %v(array element %v) as %v", v, i, sqlType)) + return se +} + +// decodeStringArray decodes proto3.ListValue pb into a NullString slice. +func decodeStringArray(pb *proto3.ListValue) ([]NullString, error) { + if pb == nil { + return nil, errNilListValue("STRING") + } + a := make([]NullString, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, stringType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "STRING", err) + } + } + return a, nil +} + +// decodeIntArray decodes proto3.ListValue pb into a NullInt64 slice. +func decodeIntArray(pb *proto3.ListValue) ([]NullInt64, error) { + if pb == nil { + return nil, errNilListValue("INT64") + } + a := make([]NullInt64, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, intType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "INT64", err) + } + } + return a, nil +} + +// decodeBoolArray decodes proto3.ListValue pb into a NullBool slice. +func decodeBoolArray(pb *proto3.ListValue) ([]NullBool, error) { + if pb == nil { + return nil, errNilListValue("BOOL") + } + a := make([]NullBool, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, boolType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "BOOL", err) + } + } + return a, nil +} + +// decodeFloat64Array decodes proto3.ListValue pb into a NullFloat64 slice. +func decodeFloat64Array(pb *proto3.ListValue) ([]NullFloat64, error) { + if pb == nil { + return nil, errNilListValue("FLOAT64") + } + a := make([]NullFloat64, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, floatType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "FLOAT64", err) + } + } + return a, nil +} + +// decodeByteArray decodes proto3.ListValue pb into a slice of byte slice. +func decodeByteArray(pb *proto3.ListValue) ([][]byte, error) { + if pb == nil { + return nil, errNilListValue("BYTES") + } + a := make([][]byte, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, bytesType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "BYTES", err) + } + } + return a, nil +} + +// decodeTimeArray decodes proto3.ListValue pb into a NullTime slice. +func decodeTimeArray(pb *proto3.ListValue) ([]NullTime, error) { + if pb == nil { + return nil, errNilListValue("TIMESTAMP") + } + a := make([]NullTime, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, timeType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "TIMESTAMP", err) + } + } + return a, nil +} + +// decodeDateArray decodes proto3.ListValue pb into a NullDate slice. +func decodeDateArray(pb *proto3.ListValue) ([]NullDate, error) { + if pb == nil { + return nil, errNilListValue("DATE") + } + a := make([]NullDate, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, dateType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "DATE", err) + } + } + return a, nil +} + +func errNotStructElement(i int, v *proto3.Value) error { + return errDecodeArrayElement(i, v, "STRUCT", + spannerErrorf(codes.FailedPrecondition, "%v(type: %T) doesn't encode Cloud Spanner STRUCT", v, v)) +} + +// decodeRowArray decodes proto3.ListValue pb into a NullRow slice according to +// the structual information given in sppb.StructType ty. +func decodeRowArray(ty *sppb.StructType, pb *proto3.ListValue) ([]NullRow, error) { + if pb == nil { + return nil, errNilListValue("STRUCT") + } + a := make([]NullRow, len(pb.Values)) + for i := range pb.Values { + switch v := pb.Values[i].GetKind().(type) { + case *proto3.Value_ListValue: + a[i] = NullRow{ + Row: Row{ + fields: ty.Fields, + vals: v.ListValue.Values, + }, + Valid: true, + } + // Null elements not currently supported by the server, see + // https://cloud.google.com/spanner/docs/query-syntax#using-structs-with-select + case *proto3.Value_NullValue: + // no-op, a[i] is NullRow{} already + default: + return nil, errNotStructElement(i, pb.Values[i]) + } + } + return a, nil +} + +// structFieldColumn returns the name of i-th field of struct type typ if the field +// is untagged; otherwise, it returns the tagged name of the field. +func structFieldColumn(typ reflect.Type, i int) (col string, ok bool) { + desc := typ.Field(i) + if desc.PkgPath != "" || desc.Anonymous { + // Skip unexported or anonymous fields. + return "", false + } + col = desc.Name + if tag := desc.Tag.Get("spanner"); tag != "" { + if tag == "-" { + // Skip fields tagged "-" to match encoding/json and others. + return "", false + } + col = tag + if idx := strings.Index(tag, ","); idx != -1 { + col = tag[:idx] + } + } + return col, true +} + +// errNilSpannerStructType returns error for unexpected nil Cloud Spanner STRUCT schema type in decoding. +func errNilSpannerStructType() error { + return spannerErrorf(codes.FailedPrecondition, "unexpected nil StructType in decoding Cloud Spanner STRUCT") +} + +// errUnnamedField returns error for decoding a Cloud Spanner STRUCT with unnamed field into a Go struct. +func errUnnamedField(ty *sppb.StructType, i int) error { + return spannerErrorf(codes.InvalidArgument, "unnamed field %v in Cloud Spanner STRUCT %+v", i, ty) +} + +// errNoOrDupGoField returns error for decoding a Cloud Spanner +// STRUCT into a Go struct which is either missing a field, or has duplicate fields. +func errNoOrDupGoField(s interface{}, f string) error { + return spannerErrorf(codes.InvalidArgument, "Go struct %+v(type %T) has no or duplicate fields for Cloud Spanner STRUCT field %v", s, s, f) +} + +// errDupColNames returns error for duplicated Cloud Spanner STRUCT field names found in decoding a Cloud Spanner STRUCT into a Go struct. +func errDupSpannerField(f string, ty *sppb.StructType) error { + return spannerErrorf(codes.InvalidArgument, "duplicated field name %q in Cloud Spanner STRUCT %+v", f, ty) +} + +// errDecodeStructField returns error for failure in decoding a single field of a Cloud Spanner STRUCT. +func errDecodeStructField(ty *sppb.StructType, f string, err error) error { + se, ok := toSpannerError(err).(*Error) + if !ok { + return spannerErrorf(codes.Unknown, + "cannot decode field %v of Cloud Spanner STRUCT %+v, error = <%v>", f, ty, err) + } + se.decorate(fmt.Sprintf("cannot decode field %v of Cloud Spanner STRUCT %+v", f, ty)) + return se +} + +// decodeStruct decodes proto3.ListValue pb into struct referenced by pointer ptr, according to +// the structual information given in sppb.StructType ty. +func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) error { + if reflect.ValueOf(ptr).IsNil() { + return errNilDst(ptr) + } + if ty == nil { + return errNilSpannerStructType() + } + // t holds the structual information of ptr. + t := reflect.TypeOf(ptr).Elem() + // v is the actual value that ptr points to. + v := reflect.ValueOf(ptr).Elem() + + fields, err := fieldCache.Fields(t) + if err != nil { + return toSpannerError(err) + } + seen := map[string]bool{} + for i, f := range ty.Fields { + if f.Name == "" { + return errUnnamedField(ty, i) + } + sf := fields.Match(f.Name) + if sf == nil { + return errNoOrDupGoField(ptr, f.Name) + } + if seen[f.Name] { + // We don't allow duplicated field name. + return errDupSpannerField(f.Name, ty) + } + // Try to decode a single field. + if err := decodeValue(pb.Values[i], f.Type, v.FieldByIndex(sf.Index).Addr().Interface()); err != nil { + return errDecodeStructField(ty, f.Name, err) + } + // Mark field f.Name as processed. + seen[f.Name] = true + } + return nil +} + +// isPtrStructPtrSlice returns true if ptr is a pointer to a slice of struct pointers. +func isPtrStructPtrSlice(t reflect.Type) bool { + if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Slice { + // t is not a pointer to a slice. + return false + } + if t = t.Elem(); t.Elem().Kind() != reflect.Ptr || t.Elem().Elem().Kind() != reflect.Struct { + // the slice that t points to is not a slice of struct pointers. + return false + } + return true +} + +// decodeStructArray decodes proto3.ListValue pb into struct slice referenced by pointer ptr, according to the +// structual information given in a sppb.StructType. +func decodeStructArray(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) error { + if pb == nil { + return errNilListValue("STRUCT") + } + // Type of the struct pointers stored in the slice that ptr points to. + ts := reflect.TypeOf(ptr).Elem().Elem() + // The slice that ptr points to, might be nil at this point. + v := reflect.ValueOf(ptr).Elem() + // Allocate empty slice. + v.Set(reflect.MakeSlice(v.Type(), 0, len(pb.Values))) + // Decode every struct in pb.Values. + for i, pv := range pb.Values { + // Check if pv is a NULL value. + if _, isNull := pv.Kind.(*proto3.Value_NullValue); isNull { + // Append a nil pointer to the slice. + v.Set(reflect.Append(v, reflect.New(ts).Elem())) + continue + } + // Allocate empty struct. + s := reflect.New(ts.Elem()) + // Get proto3.ListValue l from proto3.Value pv. + l, err := getListValue(pv) + if err != nil { + return errDecodeArrayElement(i, pv, "STRUCT", err) + } + // Decode proto3.ListValue l into struct referenced by s.Interface(). + if err = decodeStruct(ty, l, s.Interface()); err != nil { + return errDecodeArrayElement(i, pv, "STRUCT", err) + } + // Append the decoded struct back into the slice. + v.Set(reflect.Append(v, s)) + } + return nil +} + +// errEncoderUnsupportedType returns error for not being able to encode a value of +// certain type. +func errEncoderUnsupportedType(v interface{}) error { + return spannerErrorf(codes.InvalidArgument, "client doesn't support type %T", v) +} + +// encodeValue encodes a Go native type into a proto3.Value. +func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { + pb := &proto3.Value{ + Kind: &proto3.Value_NullValue{NullValue: proto3.NullValue_NULL_VALUE}, + } + var pt *sppb.Type + var err error + switch v := v.(type) { + case nil: + case string: + pb.Kind = stringKind(v) + pt = stringType() + case NullString: + if v.Valid { + return encodeValue(v.StringVal) + } + case []string: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(stringType()) + } + case []NullString: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(stringType()) + } + case []byte: + if v != nil { + pb.Kind = stringKind(base64.StdEncoding.EncodeToString(v)) + pt = bytesType() + } + case [][]byte: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(bytesType()) + } + case int: + pb.Kind = stringKind(strconv.FormatInt(int64(v), 10)) + pt = intType() + case []int: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(intType()) + } + case int64: + pb.Kind = stringKind(strconv.FormatInt(v, 10)) + pt = intType() + case []int64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(intType()) + } + case NullInt64: + if v.Valid { + return encodeValue(v.Int64) + } + case []NullInt64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(intType()) + } + case bool: + pb.Kind = &proto3.Value_BoolValue{BoolValue: v} + pt = boolType() + case []bool: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(boolType()) + } + case NullBool: + if v.Valid { + return encodeValue(v.Bool) + } + case []NullBool: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(boolType()) + } + case float64: + pb.Kind = &proto3.Value_NumberValue{NumberValue: v} + pt = floatType() + case []float64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(floatType()) + } + case NullFloat64: + if v.Valid { + return encodeValue(v.Float64) + } + case []NullFloat64: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(floatType()) + } + case time.Time: + pb.Kind = stringKind(v.UTC().Format(time.RFC3339Nano)) + pt = timeType() + case []time.Time: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(timeType()) + } + case NullTime: + if v.Valid { + return encodeValue(v.Time) + } + case []NullTime: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(timeType()) + } + case civil.Date: + pb.Kind = stringKind(v.String()) + pt = dateType() + case []civil.Date: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(dateType()) + } + case NullDate: + if v.Valid { + return encodeValue(v.Date) + } + case []NullDate: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + pt = listType(dateType()) + } + case GenericColumnValue: + // Deep clone to ensure subsequent changes to v before + // transmission don't affect our encoded value. + pb = proto.Clone(v.Value).(*proto3.Value) + pt = proto.Clone(v.Type).(*sppb.Type) + default: + return nil, nil, errEncoderUnsupportedType(v) + } + return pb, pt, nil +} + +// encodeValueArray encodes a Value array into a proto3.ListValue. +func encodeValueArray(vs []interface{}) (*proto3.ListValue, error) { + lv := &proto3.ListValue{} + lv.Values = make([]*proto3.Value, 0, len(vs)) + for _, v := range vs { + pb, _, err := encodeValue(v) + if err != nil { + return nil, err + } + lv.Values = append(lv.Values, pb) + } + return lv, nil +} + +// encodeArray assumes that all values of the array element type encode without error. +func encodeArray(len int, at func(int) interface{}) (*proto3.Value, error) { + vs := make([]*proto3.Value, len) + var err error + for i := 0; i < len; i++ { + vs[i], _, err = encodeValue(at(i)) + if err != nil { + return nil, err + } + } + return listProto(vs...), nil +} + +func spannerTagParser(t reflect.StructTag) (name string, keep bool, other interface{}, err error) { + if s := t.Get("spanner"); s != "" { + if s == "-" { + return "", false, nil, nil + } + return s, true, nil, nil + } + return "", true, nil, nil +} + +var fieldCache = fields.NewCache(spannerTagParser, nil, nil) diff --git a/vendor/cloud.google.com/go/spanner/value_test.go b/vendor/cloud.google.com/go/spanner/value_test.go new file mode 100644 index 000000000..00748b0a1 --- /dev/null +++ b/vendor/cloud.google.com/go/spanner/value_test.go @@ -0,0 +1,611 @@ +/* +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "math" + "reflect" + "testing" + "time" + + "cloud.google.com/go/civil" + "github.com/golang/protobuf/proto" + proto3 "github.com/golang/protobuf/ptypes/struct" + sppb "google.golang.org/genproto/googleapis/spanner/v1" +) + +var ( + t1 = mustParseTime("2016-11-15T15:04:05.999999999Z") + // Boundaries + t2 = mustParseTime("0000-01-01T00:00:00.000000000Z") + t3 = mustParseTime("9999-12-31T23:59:59.999999999Z") + // Local timezone + t4 = time.Now() + d1 = mustParseDate("2016-11-15") + d2 = mustParseDate("1678-01-01") +) + +func mustParseTime(s string) time.Time { + t, err := time.Parse(time.RFC3339Nano, s) + if err != nil { + panic(err) + } + return t +} + +func mustParseDate(s string) civil.Date { + d, err := civil.ParseDate(s) + if err != nil { + panic(err) + } + return d +} + +// Test encoding Values. +func TestEncodeValue(t *testing.T) { + var ( + tString = stringType() + tInt = intType() + tBool = boolType() + tFloat = floatType() + tBytes = bytesType() + tTime = timeType() + tDate = dateType() + ) + for i, test := range []struct { + in interface{} + want *proto3.Value + wantType *sppb.Type + }{ + // STRING / STRING ARRAY + {"abc", stringProto("abc"), tString}, + {NullString{"abc", true}, stringProto("abc"), tString}, + {NullString{"abc", false}, nullProto(), nil}, + {[]string{"abc", "bcd"}, listProto(stringProto("abc"), stringProto("bcd")), listType(tString)}, + {[]NullString{{"abcd", true}, {"xyz", false}}, listProto(stringProto("abcd"), nullProto()), listType(tString)}, + // BYTES / BYTES ARRAY + {[]byte("foo"), bytesProto([]byte("foo")), tBytes}, + {[]byte(nil), nullProto(), nil}, + {[][]byte{nil, []byte("ab")}, listProto(nullProto(), bytesProto([]byte("ab"))), listType(tBytes)}, + {[][]byte(nil), nullProto(), nil}, + // INT64 / INT64 ARRAY + {7, intProto(7), tInt}, + {[]int{31, 127}, listProto(intProto(31), intProto(127)), listType(tInt)}, + {int64(81), intProto(81), tInt}, + {[]int64{33, 129}, listProto(intProto(33), intProto(129)), listType(tInt)}, + {NullInt64{11, true}, intProto(11), tInt}, + {NullInt64{11, false}, nullProto(), nil}, + {[]NullInt64{{35, true}, {131, false}}, listProto(intProto(35), nullProto()), listType(tInt)}, + // BOOL / BOOL ARRAY + {true, boolProto(true), tBool}, + {NullBool{true, true}, boolProto(true), tBool}, + {NullBool{true, false}, nullProto(), nil}, + {[]bool{true, false}, listProto(boolProto(true), boolProto(false)), listType(tBool)}, + {[]NullBool{{true, true}, {true, false}}, listProto(boolProto(true), nullProto()), listType(tBool)}, + // FLOAT64 / FLOAT64 ARRAY + {3.14, floatProto(3.14), tFloat}, + {NullFloat64{3.1415, true}, floatProto(3.1415), tFloat}, + {NullFloat64{math.Inf(1), true}, floatProto(math.Inf(1)), tFloat}, + {NullFloat64{3.14159, false}, nullProto(), nil}, + {[]float64{3.141, 0.618, math.Inf(-1)}, listProto(floatProto(3.141), floatProto(0.618), floatProto(math.Inf(-1))), listType(tFloat)}, + {[]NullFloat64{{3.141, true}, {0.618, false}}, listProto(floatProto(3.141), nullProto()), listType(tFloat)}, + // TIMESTAMP / TIMESTAMP ARRAY + {t1, timeProto(t1), tTime}, + {NullTime{t1, true}, timeProto(t1), tTime}, + {NullTime{t1, false}, nullProto(), nil}, + {[]time.Time{t1, t2, t3, t4}, listProto(timeProto(t1), timeProto(t2), timeProto(t3), timeProto(t4)), listType(tTime)}, + {[]NullTime{{t1, true}, {t1, false}}, listProto(timeProto(t1), nullProto()), listType(tTime)}, + // DATE / DATE ARRAY + {d1, dateProto(d1), tDate}, + {NullDate{d1, true}, dateProto(d1), tDate}, + {NullDate{civil.Date{}, false}, nullProto(), nil}, + {[]civil.Date{d1, d2}, listProto(dateProto(d1), dateProto(d2)), listType(tDate)}, + {[]NullDate{{d1, true}, {civil.Date{}, false}}, listProto(dateProto(d1), nullProto()), listType(tDate)}, + // GenericColumnValue + {GenericColumnValue{tString, stringProto("abc")}, stringProto("abc"), tString}, + {GenericColumnValue{tString, nullProto()}, nullProto(), tString}, + // not actually valid (stringProto inside int list), but demonstrates pass-through. + { + GenericColumnValue{ + Type: listType(tInt), + Value: listProto(intProto(5), nullProto(), stringProto("bcd")), + }, + listProto(intProto(5), nullProto(), stringProto("bcd")), + listType(tInt), + }, + } { + got, gotType, err := encodeValue(test.in) + if err != nil { + t.Fatalf("#%d: got error during encoding: %v, want nil", i, err) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("#%d: got encode result: %v, want %v", i, got, test.want) + } + if !reflect.DeepEqual(gotType, test.wantType) { + t.Errorf("#%d: got encode type: %v, want %v", i, gotType, test.wantType) + } + } +} + +// Test decoding Values. +func TestDecodeValue(t *testing.T) { + for i, test := range []struct { + in *proto3.Value + t *sppb.Type + want interface{} + fail bool + }{ + // STRING + {stringProto("abc"), stringType(), "abc", false}, + {nullProto(), stringType(), "abc", true}, + {stringProto("abc"), stringType(), NullString{"abc", true}, false}, + {nullProto(), stringType(), NullString{}, false}, + // STRING ARRAY + { + listProto(stringProto("abc"), nullProto(), stringProto("bcd")), + listType(stringType()), + []NullString{{"abc", true}, {}, {"bcd", true}}, + false, + }, + {nullProto(), listType(stringType()), []NullString(nil), false}, + // BYTES + {bytesProto([]byte("ab")), bytesType(), []byte("ab"), false}, + {nullProto(), bytesType(), []byte(nil), false}, + // BYTES ARRAY + {listProto(bytesProto([]byte("ab")), nullProto()), listType(bytesType()), [][]byte{[]byte("ab"), nil}, false}, + {nullProto(), listType(bytesType()), [][]byte(nil), false}, + //INT64 + {intProto(15), intType(), int64(15), false}, + {nullProto(), intType(), int64(0), true}, + {intProto(15), intType(), NullInt64{15, true}, false}, + {nullProto(), intType(), NullInt64{}, false}, + // INT64 ARRAY + {listProto(intProto(91), nullProto(), intProto(87)), listType(intType()), []NullInt64{{91, true}, {}, {87, true}}, false}, + {nullProto(), listType(intType()), []NullInt64(nil), false}, + // BOOL + {boolProto(true), boolType(), true, false}, + {nullProto(), boolType(), true, true}, + {boolProto(true), boolType(), NullBool{true, true}, false}, + {nullProto(), boolType(), NullBool{}, false}, + // BOOL ARRAY + {listProto(boolProto(true), boolProto(false), nullProto()), listType(boolType()), []NullBool{{true, true}, {false, true}, {}}, false}, + {nullProto(), listType(boolType()), []NullBool(nil), false}, + // FLOAT64 + {floatProto(3.14), floatType(), 3.14, false}, + {nullProto(), floatType(), 0.00, true}, + {floatProto(3.14), floatType(), NullFloat64{3.14, true}, false}, + {nullProto(), floatType(), NullFloat64{}, false}, + // FLOAT64 ARRAY + { + listProto(floatProto(math.Inf(1)), floatProto(math.Inf(-1)), nullProto(), floatProto(3.1)), + listType(floatType()), + []NullFloat64{{math.Inf(1), true}, {math.Inf(-1), true}, {}, {3.1, true}}, + false, + }, + {nullProto(), listType(floatType()), []NullFloat64(nil), false}, + // TIMESTAMP + {timeProto(t1), timeType(), t1, false}, + {timeProto(t1), timeType(), NullTime{t1, true}, false}, + {nullProto(), timeType(), NullTime{}, false}, + // TIMESTAMP ARRAY + {listProto(timeProto(t1), timeProto(t2), timeProto(t3), nullProto()), listType(timeType()), []NullTime{{t1, true}, {t2, true}, {t3, true}, {}}, false}, + {nullProto(), listType(timeType()), []NullTime(nil), false}, + // DATE + {dateProto(d1), dateType(), d1, false}, + {dateProto(d1), dateType(), NullDate{d1, true}, false}, + {nullProto(), dateType(), NullDate{}, false}, + // DATE ARRAY + {listProto(dateProto(d1), dateProto(d2), nullProto()), listType(dateType()), []NullDate{{d1, true}, {d2, true}, {}}, false}, + {nullProto(), listType(dateType()), []NullDate(nil), false}, + // STRUCT ARRAY + // STRUCT schema is equal to the following Go struct: + // type s struct { + // Col1 NullInt64 + // Col2 []struct { + // SubCol1 float64 + // SubCol2 string + // } + // } + { + in: listProto( + listProto( + intProto(3), + listProto( + listProto(floatProto(3.14), stringProto("this")), + listProto(floatProto(0.57), stringProto("siht")), + ), + ), + listProto( + nullProto(), + nullProto(), + ), + nullProto(), + ), + t: listType( + structType( + mkField("Col1", intType()), + mkField( + "Col2", + listType( + structType( + mkField("SubCol1", floatType()), + mkField("SubCol2", stringType()), + ), + ), + ), + ), + ), + want: []NullRow{ + { + Row: Row{ + fields: []*sppb.StructType_Field{ + mkField("Col1", intType()), + mkField( + "Col2", + listType( + structType( + mkField("SubCol1", floatType()), + mkField("SubCol2", stringType()), + ), + ), + ), + }, + vals: []*proto3.Value{ + intProto(3), + listProto( + listProto(floatProto(3.14), stringProto("this")), + listProto(floatProto(0.57), stringProto("siht")), + ), + }, + }, + Valid: true, + }, + { + Row: Row{ + fields: []*sppb.StructType_Field{ + mkField("Col1", intType()), + mkField( + "Col2", + listType( + structType( + mkField("SubCol1", floatType()), + mkField("SubCol2", stringType()), + ), + ), + ), + }, + vals: []*proto3.Value{ + nullProto(), + nullProto(), + }, + }, + Valid: true, + }, + {}, + }, + fail: false, + }, + { + in: listProto( + listProto( + intProto(3), + listProto( + listProto(floatProto(3.14), stringProto("this")), + listProto(floatProto(0.57), stringProto("siht")), + ), + ), + listProto( + nullProto(), + nullProto(), + ), + nullProto(), + ), + t: listType( + structType( + mkField("Col1", intType()), + mkField( + "Col2", + listType( + structType( + mkField("SubCol1", floatType()), + mkField("SubCol2", stringType()), + ), + ), + ), + ), + ), + want: []*struct { + Col1 NullInt64 + StructCol []*struct { + SubCol1 NullFloat64 + SubCol2 string + } `spanner:"Col2"` + }{ + { + Col1: NullInt64{3, true}, + StructCol: []*struct { + SubCol1 NullFloat64 + SubCol2 string + }{ + { + SubCol1: NullFloat64{3.14, true}, + SubCol2: "this", + }, + { + SubCol1: NullFloat64{0.57, true}, + SubCol2: "siht", + }, + }, + }, + { + Col1: NullInt64{}, + StructCol: []*struct { + SubCol1 NullFloat64 + SubCol2 string + }(nil), + }, + nil, + }, + fail: false, + }, + // GenericColumnValue + {stringProto("abc"), stringType(), GenericColumnValue{stringType(), stringProto("abc")}, false}, + {nullProto(), stringType(), GenericColumnValue{stringType(), nullProto()}, false}, + // not actually valid (stringProto inside int list), but demonstrates pass-through. + { + in: listProto(intProto(5), nullProto(), stringProto("bcd")), + t: listType(intType()), + want: GenericColumnValue{ + Type: listType(intType()), + Value: listProto(intProto(5), nullProto(), stringProto("bcd")), + }, + fail: false, + }, + } { + gotp := reflect.New(reflect.TypeOf(test.want)) + if err := decodeValue(test.in, test.t, gotp.Interface()); err != nil { + if !test.fail { + t.Errorf("%d: cannot decode %v(%v): %v", i, test.in, test.t, err) + } + continue + } + if test.fail { + t.Errorf("%d: decoding %v(%v) succeeds unexpectedly, want error", i, test.in, test.t) + continue + } + got := reflect.Indirect(gotp).Interface() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("%d: unexpected decoding result - got %v, want %v", i, got, test.want) + continue + } + } +} + +// Test error cases for decodeValue. +func TestDecodeValueErrors(t *testing.T) { + for i, test := range []struct { + in *proto3.Value + t *sppb.Type + v interface{} + }{ + {nullProto(), stringType(), nil}, + {nullProto(), stringType(), 1}, + } { + err := decodeValue(test.in, test.t, test.v) + if err == nil { + t.Errorf("#%d: want error, got nil", i) + } + } +} + +// Test NaN encoding/decoding. +func TestNaN(t *testing.T) { + // Decode NaN value. + f := 0.0 + nf := NullFloat64{} + // To float64 + if err := decodeValue(floatProto(math.NaN()), floatType(), &f); err != nil { + t.Errorf("decodeValue returns %q for %v, want nil", err, floatProto(math.NaN())) + } + if !math.IsNaN(f) { + t.Errorf("f = %v, want %v", f, math.NaN()) + } + // To NullFloat64 + if err := decodeValue(floatProto(math.NaN()), floatType(), &nf); err != nil { + t.Errorf("decodeValue returns %q for %v, want nil", err, floatProto(math.NaN())) + } + if !math.IsNaN(nf.Float64) || !nf.Valid { + t.Errorf("f = %v, want %v", f, NullFloat64{math.NaN(), true}) + } + // Encode NaN value + // From float64 + v, _, err := encodeValue(math.NaN()) + if err != nil { + t.Errorf("encodeValue returns %q for NaN, want nil", err) + } + x, ok := v.GetKind().(*proto3.Value_NumberValue) + if !ok { + t.Errorf("incorrect type for v.GetKind(): %T, want *proto3.Value_NumberValue", v.GetKind()) + } + if !math.IsNaN(x.NumberValue) { + t.Errorf("x.NumberValue = %v, want %v", x.NumberValue, math.NaN()) + } + // From NullFloat64 + v, _, err = encodeValue(NullFloat64{math.NaN(), true}) + if err != nil { + t.Errorf("encodeValue returns %q for NaN, want nil", err) + } + x, ok = v.GetKind().(*proto3.Value_NumberValue) + if !ok { + t.Errorf("incorrect type for v.GetKind(): %T, want *proto3.Value_NumberValue", v.GetKind()) + } + if !math.IsNaN(x.NumberValue) { + t.Errorf("x.NumberValue = %v, want %v", x.NumberValue, math.NaN()) + } +} + +func TestGenericColumnValue(t *testing.T) { + for _, test := range []struct { + in GenericColumnValue + want interface{} + fail bool + }{ + {GenericColumnValue{stringType(), stringProto("abc")}, "abc", false}, + {GenericColumnValue{stringType(), stringProto("abc")}, 5, true}, + {GenericColumnValue{listType(intType()), listProto(intProto(91), nullProto(), intProto(87))}, []NullInt64{{91, true}, {}, {87, true}}, false}, + {GenericColumnValue{intType(), intProto(42)}, GenericColumnValue{intType(), intProto(42)}, false}, // trippy! :-) + } { + // We take a copy and mutate because we're paranoid about immutability. + inCopy := GenericColumnValue{ + Type: proto.Clone(test.in.Type).(*sppb.Type), + Value: proto.Clone(test.in.Value).(*proto3.Value), + } + gotp := reflect.New(reflect.TypeOf(test.want)) + if err := inCopy.Decode(gotp.Interface()); err != nil { + if !test.fail { + t.Errorf("cannot decode %v to %v: %v", test.in, test.want, err) + } + continue + } + if test.fail { + t.Errorf("decoding %v to %v succeeds unexpectedly", test.in, test.want) + } + // mutations to inCopy should be invisible to gotp. + inCopy.Type.Code = sppb.TypeCode_TIMESTAMP + inCopy.Value.Kind = &proto3.Value_NumberValue{NumberValue: 999} + got := reflect.Indirect(gotp).Interface() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("unexpected decode result - got %v, want %v", got, test.want) + } + + // Test we can go backwards as well. + v, err := newGenericColumnValue(test.want) + if err != nil { + t.Errorf("NewGenericColumnValue failed: %v", err) + continue + } + if !reflect.DeepEqual(*v, test.in) { + t.Errorf("unexpected encode result - got %v, want %v", v, test.in) + } + // If want is a GenericColumnValue, mutate its underlying value to validate + // we have taken a deep copy. + if gcv, ok := test.want.(GenericColumnValue); ok { + gcv.Type.Code = sppb.TypeCode_TIMESTAMP + gcv.Value.Kind = &proto3.Value_NumberValue{NumberValue: 999} + if !reflect.DeepEqual(*v, test.in) { + t.Errorf("expected deep copy - got %v, want %v", v, test.in) + } + } + } +} + +func runBench(b *testing.B, size int, f func(a []int) (*proto3.Value, *sppb.Type, error)) { + a := make([]int, size) + for i := 0; i < b.N; i++ { + f(a) + } +} + +func BenchmarkEncodeIntArrayOrig1(b *testing.B) { + runBench(b, 1, encodeIntArrayOrig) +} + +func BenchmarkEncodeIntArrayOrig10(b *testing.B) { + runBench(b, 10, encodeIntArrayOrig) +} + +func BenchmarkEncodeIntArrayOrig100(b *testing.B) { + runBench(b, 100, encodeIntArrayOrig) +} + +func BenchmarkEncodeIntArrayOrig1000(b *testing.B) { + runBench(b, 1000, encodeIntArrayOrig) +} + +func BenchmarkEncodeIntArrayFunc1(b *testing.B) { + runBench(b, 1, encodeIntArrayFunc) +} + +func BenchmarkEncodeIntArrayFunc10(b *testing.B) { + runBench(b, 10, encodeIntArrayFunc) +} + +func BenchmarkEncodeIntArrayFunc100(b *testing.B) { + runBench(b, 100, encodeIntArrayFunc) +} + +func BenchmarkEncodeIntArrayFunc1000(b *testing.B) { + runBench(b, 1000, encodeIntArrayFunc) +} + +func BenchmarkEncodeIntArrayReflect1(b *testing.B) { + runBench(b, 1, encodeIntArrayReflect) +} + +func BenchmarkEncodeIntArrayReflect10(b *testing.B) { + runBench(b, 10, encodeIntArrayReflect) +} + +func BenchmarkEncodeIntArrayReflect100(b *testing.B) { + runBench(b, 100, encodeIntArrayReflect) +} + +func BenchmarkEncodeIntArrayReflect1000(b *testing.B) { + runBench(b, 1000, encodeIntArrayReflect) +} + +func encodeIntArrayOrig(a []int) (*proto3.Value, *sppb.Type, error) { + vs := make([]*proto3.Value, len(a)) + var err error + for i := range a { + vs[i], _, err = encodeValue(a[i]) + if err != nil { + return nil, nil, err + } + } + return listProto(vs...), listType(intType()), nil +} + +func encodeIntArrayFunc(a []int) (*proto3.Value, *sppb.Type, error) { + v, err := encodeArray(len(a), func(i int) interface{} { return a[i] }) + if err != nil { + return nil, nil, err + } + return v, listType(intType()), nil +} + +func encodeIntArrayReflect(a []int) (*proto3.Value, *sppb.Type, error) { + v, err := encodeArrayReflect(a) + if err != nil { + return nil, nil, err + } + return v, listType(intType()), nil +} + +func encodeArrayReflect(a interface{}) (*proto3.Value, error) { + va := reflect.ValueOf(a) + len := va.Len() + vs := make([]*proto3.Value, len) + var err error + for i := 0; i < len; i++ { + vs[i], _, err = encodeValue(va.Index(i).Interface()) + if err != nil { + return nil, err + } + } + return listProto(vs...), nil +} |
