Skip to content

Commit fd918a5

Browse files
carabasdanielactions-user
authored andcommitted
Add gcp big query model adapter (#2093)
closes: https://linear.app/overmind/issue/ENG-806/support-google-cloud-bigquery-model GitOrigin-RevId: 96af7ad08418adb54bd049bcdec19c21a03b40a6
1 parent 1d2f9bf commit fd918a5

File tree

7 files changed

+518
-2
lines changed

7 files changed

+518
-2
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package integrationtests
2+
3+
import (
4+
"context"
5+
"os"
6+
"strings"
7+
"testing"
8+
9+
"cloud.google.com/go/bigquery"
10+
"go.uber.org/mock/gomock"
11+
12+
"github.com/overmindtech/cli/sources/gcp/manual"
13+
gcpshared "github.com/overmindtech/cli/sources/gcp/shared"
14+
)
15+
16+
func TestBigQueryModel(t *testing.T) {
17+
projectID := os.Getenv("GCP_PROJECT_ID")
18+
if projectID == "" {
19+
t.Skip("GCP_PROJECT_ID environment variable is not set, skipping BigQuery model tests")
20+
}
21+
22+
dataSet := "test_dataset"
23+
model := "test_model"
24+
25+
ctx := context.Background()
26+
ctrl := gomock.NewController(t)
27+
client, err := bigquery.NewClient(ctx, projectID)
28+
if err != nil {
29+
t.Fatalf("Failed to create BigQuery client: %v", err)
30+
}
31+
defer client.Close()
32+
33+
defer ctrl.Finish()
34+
t.Run("Setup", func(t *testing.T) {
35+
datasetItem := client.Dataset(dataSet)
36+
err := datasetItem.Create(ctx, &bigquery.DatasetMetadata{
37+
Name: dataSet,
38+
Description: "Test dataset for model integration tests",
39+
})
40+
if err != nil && !strings.Contains(err.Error(), "Already Exists") {
41+
t.Fatalf("Failed to create dataset %s: %v", dataSet, err)
42+
}
43+
t.Logf("Dataset %s created successfully", dataSet)
44+
45+
query := "CREATE OR REPLACE MODEL `" + projectID + "." + dataSet + "." + model + "` OPTIONS " +
46+
`(model_type='LOGISTIC_REG',
47+
labels=['animal_label']
48+
) AS
49+
SELECT
50+
1 AS feature_dummy, -- A dummy feature for 'cats'
51+
'cats' AS animal_label -- The primary label we want to output
52+
UNION ALL
53+
SELECT
54+
2 AS feature_dummy, -- A different dummy feature for the second label
55+
'dogs' AS animal_label; -- A second, dummy label to satisfy the classification requirement`
56+
57+
op, err := client.Query(query).Run(ctx)
58+
if err != nil {
59+
t.Fatalf("Failed to create model: %v", err)
60+
}
61+
if _, err := op.Wait(ctx); err != nil {
62+
t.Fatalf("Failed to wait for model creation: %v", err)
63+
}
64+
modelItem := client.Dataset(dataSet).Model(model)
65+
modelMetadata, err := modelItem.Update(ctx, bigquery.ModelMetadataToUpdate{
66+
Name: model,
67+
Description: "Test model description",
68+
}, "")
69+
if err != nil {
70+
t.Fatalf("Failed to create model: %v", err)
71+
}
72+
t.Logf("Model created: %s", modelMetadata.Name)
73+
})
74+
t.Run("Get", func(t *testing.T) {
75+
bigqueryClient := gcpshared.NewBigQueryModelClient(client, dataSet)
76+
adapter := manual.NewBigQueryModel(bigqueryClient, projectID)
77+
sdpItem, err := adapter.Get(ctx, dataSet, model)
78+
if err != nil {
79+
t.Fatalf("Failed to get item: %v", err)
80+
}
81+
if sdpItem == nil {
82+
t.Fatal("Expected an item, got nil")
83+
}
84+
uniqueAttrKey := sdpItem.GetUniqueAttribute()
85+
86+
uniqueAttrValue, attrErr := sdpItem.GetAttributes().Get(uniqueAttrKey)
87+
if attrErr != nil {
88+
t.Fatalf("Failed to get unique attribute: %v", err)
89+
}
90+
91+
if uniqueAttrValue != model {
92+
t.Fatalf("Expected unique attribute value to be %s, got %s", model, uniqueAttrValue)
93+
}
94+
95+
sdpItems, err := adapter.Search(ctx, dataSet)
96+
if err != nil {
97+
t.Fatalf("Failed to search items: %v", err)
98+
}
99+
if len(sdpItems) < 1 {
100+
t.Fatalf("Expected at least one model in dataset, got %d", len(sdpItems))
101+
}
102+
103+
var found bool
104+
for _, item := range sdpItems {
105+
if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == model {
106+
found = true
107+
break
108+
}
109+
}
110+
111+
if !found {
112+
t.Fatalf("Expected to find model %s in the list of dataset models", model)
113+
}
114+
})
115+
t.Run("Teardown", func(t *testing.T) {
116+
// Cleanup resources if needed
117+
err := client.Dataset(dataSet).DeleteWithContents(ctx)
118+
if err != nil {
119+
t.Fatalf("Failed to delete dataset %s: %v", dataSet, err)
120+
} else {
121+
t.Logf("Dataset %s deleted successfully", dataSet)
122+
}
123+
})
124+
}

‎sources/gcp/integration-tests/spanner-database_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ import (
88
"os"
99
"testing"
1010

11+
12+
"github.com/googleapis/gax-go/v2/apierror"
13+
"google.golang.org/grpc/codes"
1114
database "cloud.google.com/go/spanner/admin/database/apiv1"
1215
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
1316
instance "cloud.google.com/go/spanner/admin/instance/apiv1"
14-
"github.com/googleapis/gax-go/v2/apierror"
15-
"google.golang.org/grpc/codes"
1617

1718
"github.com/overmindtech/cli/sources/gcp/dynamic"
1819
gcpshared "github.com/overmindtech/cli/sources/gcp/shared"

‎sources/gcp/manual/big-query-model.go

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package manual
2+
3+
import (
4+
"context"
5+
6+
"cloud.google.com/go/bigquery"
7+
8+
"github.com/overmindtech/cli/sdp-go"
9+
"github.com/overmindtech/cli/sources"
10+
gcpshared "github.com/overmindtech/cli/sources/gcp/shared"
11+
"github.com/overmindtech/cli/sources/shared"
12+
)
13+
14+
var (
15+
BigQueryModelLookupById = shared.NewItemTypeLookup("id", gcpshared.BigQueryModel)
16+
)
17+
18+
// BigQueryModelWrapper is a wrapper for the BigQueryModelClient that implements the sources.SearchableWrapper interface
19+
type BigQueryModelWrapper struct {
20+
client gcpshared.BigQueryModelClient
21+
*gcpshared.ProjectBase
22+
}
23+
24+
// NewBigQueryModel creates a new BigQueryModelWrapper instance
25+
func NewBigQueryModel(client gcpshared.BigQueryModelClient, projectID string) sources.SearchableWrapper {
26+
return &BigQueryModelWrapper{
27+
client: client,
28+
ProjectBase: gcpshared.NewProjectBase(
29+
projectID,
30+
sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE,
31+
gcpshared.BigQueryModel,
32+
),
33+
}
34+
}
35+
36+
func (m BigQueryModelWrapper) GetLookups() sources.ItemTypeLookups {
37+
return sources.ItemTypeLookups{
38+
BigQueryDatasetLookupByID,
39+
BigQueryModelLookupById,
40+
}
41+
}
42+
43+
func (m BigQueryModelWrapper) Get(ctx context.Context, queryParts ...string) (*sdp.Item, *sdp.QueryError) {
44+
metadata, err := m.client.Get(ctx, m.ProjectBase.ProjectID(), queryParts[0], queryParts[1])
45+
if err != nil {
46+
return nil, gcpshared.QueryError(err)
47+
}
48+
return m.GCPBigQueryMetadataToItem(ctx, queryParts[0], metadata)
49+
}
50+
51+
func (m BigQueryModelWrapper) GCPBigQueryMetadataToItem(ctx context.Context, dataSetId string, metadata *bigquery.ModelMetadata) (*sdp.Item, *sdp.QueryError) {
52+
attributes, err := shared.ToAttributesWithExclude(metadata, "labels")
53+
if err != nil {
54+
return nil, gcpshared.QueryError(err)
55+
}
56+
57+
sdpItem := &sdp.Item{
58+
Type: gcpshared.BigQueryModel.String(),
59+
UniqueAttribute: "Name",
60+
Attributes: attributes,
61+
Scope: m.DefaultScope(),
62+
Tags: metadata.Labels,
63+
}
64+
65+
sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{
66+
Query: &sdp.Query{
67+
Type: gcpshared.BigQueryDataset.String(),
68+
Method: sdp.QueryMethod_GET,
69+
Scope: m.DefaultScope(),
70+
Query: dataSetId,
71+
},
72+
// Model is in a dataset, if dataset is deleted, model is deleted.
73+
// If the model is deleted, the dataset is not deleted.
74+
BlastPropagation: &sdp.BlastPropagation{
75+
In: false,
76+
Out: true,
77+
},
78+
})
79+
80+
if metadata.EncryptionConfig != nil && metadata.EncryptionConfig.KMSKeyName != "" {
81+
values := gcpshared.ExtractPathParams(metadata.EncryptionConfig.KMSKeyName, "locations", "keyRings", "cryptoKeys")
82+
if len(values) == 3 && values[0] != "" && values[1] != "" && values[2] != "" {
83+
sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{
84+
85+
Query: &sdp.Query{
86+
Type: gcpshared.CloudKMSCryptoKey.String(),
87+
Method: sdp.QueryMethod_GET,
88+
Scope: m.ProjectID(),
89+
Query: shared.CompositeLookupKey(values...),
90+
},
91+
BlastPropagation: &sdp.BlastPropagation{
92+
In: true,
93+
Out: false,
94+
},
95+
})
96+
}
97+
}
98+
99+
for _, row := range metadata.RawTrainingRuns() {
100+
if row.DataSplitResult != nil && row.DataSplitResult.EvaluationTable.TableId != "" {
101+
sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{
102+
Query: &sdp.Query{
103+
Type: gcpshared.BigQueryTable.String(),
104+
Method: sdp.QueryMethod_GET,
105+
Scope: m.DefaultScope(),
106+
Query: shared.CompositeLookupKey(dataSetId, row.DataSplitResult.EvaluationTable.TableId),
107+
},
108+
BlastPropagation: &sdp.BlastPropagation{
109+
In: true,
110+
Out: false,
111+
},
112+
})
113+
}
114+
}
115+
116+
return sdpItem, nil
117+
}
118+
119+
func (m BigQueryModelWrapper) PotentialLinks() map[shared.ItemType]bool {
120+
return shared.NewItemTypesSet(
121+
gcpshared.CloudKMSCryptoKey,
122+
gcpshared.BigQueryDataset,
123+
gcpshared.BigQueryTable,
124+
)
125+
}
126+
127+
func (m BigQueryModelWrapper) SearchLookups() []sources.ItemTypeLookups {
128+
return []sources.ItemTypeLookups{
129+
{
130+
BigQueryModelLookupById,
131+
},
132+
}
133+
}
134+
135+
func (m BigQueryModelWrapper) Search(ctx context.Context, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) {
136+
items, err := m.client.List(ctx, m.ProjectBase.ProjectID(), queryParts[0], func(ctx context.Context, metadata *bigquery.ModelMetadata) (*sdp.Item, *sdp.QueryError) {
137+
// Convert the dataset metadata to an SDP item
138+
attributes, err := shared.ToAttributesWithExclude(metadata, "labels")
139+
if err != nil {
140+
return nil, gcpshared.QueryError(err)
141+
}
142+
143+
item := &sdp.Item{
144+
Type: gcpshared.BigQueryModel.String(),
145+
UniqueAttribute: "Name",
146+
Scope: m.DefaultScope(),
147+
Attributes: attributes,
148+
Tags: metadata.Labels,
149+
}
150+
return item, nil
151+
})
152+
if err != nil {
153+
return nil, gcpshared.QueryError(err)
154+
}
155+
return items, nil
156+
}

0 commit comments

Comments
 (0)