Skip to content
This repository was archived by the owner on Sep 26, 2023. It is now read-only.

Commit b863041

Browse files
feat: add mtls feature to http and grpc transport provider (#1249)
* feat: add mtls support to grpc and http transport
1 parent 3b1859e commit b863041

File tree

16 files changed

+1045
-15
lines changed

16 files changed

+1045
-15
lines changed

‎gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import com.google.api.gax.rpc.HeaderProvider;
3939
import com.google.api.gax.rpc.TransportChannel;
4040
import com.google.api.gax.rpc.TransportChannelProvider;
41+
import com.google.api.gax.rpc.mtls.MtlsProvider;
4142
import com.google.auth.Credentials;
4243
import com.google.auth.oauth2.ComputeEngineCredentials;
4344
import com.google.common.annotations.VisibleForTesting;
@@ -46,16 +47,22 @@
4647
import com.google.common.collect.ImmutableList;
4748
import com.google.common.collect.ImmutableMap;
4849
import com.google.common.io.CharStreams;
50+
import io.grpc.ChannelCredentials;
51+
import io.grpc.Grpc;
4952
import io.grpc.ManagedChannel;
5053
import io.grpc.ManagedChannelBuilder;
54+
import io.grpc.TlsChannelCredentials;
5155
import io.grpc.alts.ComputeEngineChannelBuilder;
5256
import java.io.IOException;
5357
import java.io.InputStreamReader;
58+
import java.security.GeneralSecurityException;
59+
import java.security.KeyStore;
5460
import java.util.Map;
5561
import java.util.concurrent.Executor;
5662
import java.util.concurrent.ScheduledExecutorService;
5763
import java.util.concurrent.TimeUnit;
5864
import javax.annotation.Nullable;
65+
import javax.net.ssl.KeyManagerFactory;
5966
import org.threeten.bp.Duration;
6067

6168
/**
@@ -96,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
96103
@Nullable private final ChannelPrimer channelPrimer;
97104
@Nullable private final Boolean attemptDirectPath;
98105
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
106+
@Nullable private final MtlsProvider mtlsProvider;
99107

100108
@Nullable
101109
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
@@ -105,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
105113
this.executor = builder.executor;
106114
this.headerProvider = builder.headerProvider;
107115
this.endpoint = builder.endpoint;
116+
this.mtlsProvider = builder.mtlsProvider;
108117
this.envProvider = builder.envProvider;
109118
this.interceptorProvider = builder.interceptorProvider;
110119
this.maxInboundMessageSize = builder.maxInboundMessageSize;
@@ -216,8 +225,13 @@ private TransportChannel createChannel() throws IOException {
216225
int realPoolSize = MoreObjects.firstNonNull(poolSize, 1);
217226
ChannelFactory channelFactory =
218227
new ChannelFactory() {
228+
@Override
219229
public ManagedChannel createSingleChannel() throws IOException {
220-
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
230+
try {
231+
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
232+
} catch (GeneralSecurityException e) {
233+
throw new IOException(e);
234+
}
221235
}
222236
};
223237
ManagedChannel outerChannel;
@@ -264,7 +278,21 @@ static boolean isOnComputeEngine() {
264278
return false;
265279
}
266280

267-
private ManagedChannel createSingleChannel() throws IOException {
281+
@VisibleForTesting
282+
ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSecurityException {
283+
if (mtlsProvider.useMtlsClientCertificate()) {
284+
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
285+
if (mtlsKeyStore != null) {
286+
KeyManagerFactory factory =
287+
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
288+
factory.init(mtlsKeyStore, new char[] {});
289+
return TlsChannelCredentials.newBuilder().keyManager(factory.getKeyManagers()).build();
290+
}
291+
}
292+
return null;
293+
}
294+
295+
private ManagedChannel createSingleChannel() throws IOException, GeneralSecurityException {
268296
GrpcHeaderInterceptor headerInterceptor =
269297
new GrpcHeaderInterceptor(headerProvider.getHeaders());
270298
GrpcMetadataHandlerInterceptor metadataHandlerInterceptor =
@@ -290,7 +318,12 @@ && isOnComputeEngine()) {
290318
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
291319
builder.defaultServiceConfig(directPathServiceConfig);
292320
} else {
293-
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
321+
ChannelCredentials channelCredentials = createMtlsChannelCredentials();
322+
if (channelCredentials != null) {
323+
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
324+
} else {
325+
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
326+
}
294327
}
295328
builder =
296329
builder
@@ -376,6 +409,7 @@ public static final class Builder {
376409
private HeaderProvider headerProvider;
377410
private String endpoint;
378411
private EnvironmentProvider envProvider;
412+
private MtlsProvider mtlsProvider = new MtlsProvider();
379413
@Nullable private GrpcInterceptorProvider interceptorProvider;
380414
@Nullable private Integer maxInboundMessageSize;
381415
@Nullable private Integer maxInboundMetadataSize;
@@ -412,6 +446,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
412446
this.channelPrimer = provider.channelPrimer;
413447
this.attemptDirectPath = provider.attemptDirectPath;
414448
this.directPathServiceConfig = provider.directPathServiceConfig;
449+
this.mtlsProvider = provider.mtlsProvider;
415450
}
416451

417452
/** Sets the number of available CPUs, used internally for testing. */
@@ -458,6 +493,12 @@ public Builder setEndpoint(String endpoint) {
458493
return this;
459494
}
460495

496+
@VisibleForTesting
497+
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
498+
this.mtlsProvider = mtlsProvider;
499+
return this;
500+
}
501+
461502
/**
462503
* Sets the GrpcInterceptorProvider for this TransportChannelProvider.
463504
*

‎gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
3939
import com.google.api.gax.rpc.HeaderProvider;
4040
import com.google.api.gax.rpc.TransportChannelProvider;
41+
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
42+
import com.google.api.gax.rpc.mtls.MtlsProvider;
4143
import com.google.auth.oauth2.CloudShellCredentials;
4244
import com.google.auth.oauth2.ComputeEngineCredentials;
4345
import com.google.common.collect.ImmutableList;
@@ -46,6 +48,7 @@
4648
import io.grpc.ManagedChannelBuilder;
4749
import io.grpc.alts.ComputeEngineChannelBuilder;
4850
import java.io.IOException;
51+
import java.security.GeneralSecurityException;
4952
import java.util.ArrayList;
5053
import java.util.Collections;
5154
import java.util.HashMap;
@@ -63,8 +66,7 @@
6366
import org.threeten.bp.Duration;
6467

6568
@RunWith(JUnit4.class)
66-
public class InstantiatingGrpcChannelProviderTest {
67-
69+
public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
6870
@Test
6971
public void testEndpoint() {
7072
String endpoint = "localhost:8080";
@@ -499,4 +501,17 @@ public void testWithCustomDirectPathServiceConfig() {
499501
ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
500502
assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
501503
}
504+
505+
@Override
506+
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
507+
throws IOException, GeneralSecurityException {
508+
InstantiatingGrpcChannelProvider channelProvider =
509+
InstantiatingGrpcChannelProvider.newBuilder()
510+
.setEndpoint("localhost:8080")
511+
.setMtlsProvider(provider)
512+
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
513+
.setExecutor(Mockito.mock(Executor.class))
514+
.build();
515+
return channelProvider.createMtlsChannelCredentials();
516+
}
502517
}

‎gax-grpc/src/test/java/com/google/api/gax/grpc/SettingsTest.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import com.google.api.gax.rpc.StubSettings;
5151
import com.google.api.gax.rpc.TransportChannelProvider;
5252
import com.google.api.gax.rpc.UnaryCallSettings;
53+
import com.google.api.gax.rpc.mtls.MtlsProvider;
5354
import com.google.auth.Credentials;
5455
import com.google.common.collect.ImmutableList;
5556
import com.google.common.collect.ImmutableMap;
@@ -83,6 +84,7 @@ private static class FakeStubSettings extends StubSettings<FakeStubSettings> {
8384
public static final int DEFAULT_SERVICE_PORT = 443;
8485
public static final String DEFAULT_SERVICE_ENDPOINT =
8586
DEFAULT_SERVICE_ADDRESS + ':' + DEFAULT_SERVICE_PORT;
87+
public static final MtlsProvider DEFAULT_MTLS_PROVIDER = new MtlsProvider();
8688
public static final ImmutableList<String> DEFAULT_SERVICE_SCOPES =
8789
ImmutableList.<String>builder()
8890
.add("https://www.googleapis.com/auth/pubsub")
@@ -148,7 +150,9 @@ public static InstantiatingExecutorProvider.Builder defaultExecutorProviderBuild
148150

149151
/** Returns a builder for the default TransportChannelProvider for this service. */
150152
public static InstantiatingGrpcChannelProvider.Builder defaultGrpcChannelProviderBuilder() {
151-
return InstantiatingGrpcChannelProvider.newBuilder().setEndpoint(DEFAULT_SERVICE_ENDPOINT);
153+
return InstantiatingGrpcChannelProvider.newBuilder()
154+
.setEndpoint(DEFAULT_SERVICE_ENDPOINT)
155+
.setMtlsProvider(DEFAULT_MTLS_PROVIDER);
152156
}
153157

154158
public static ApiClientHeaderProvider.Builder defaultGoogleServiceHeaderProviderBuilder() {

‎gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,21 @@
3030
package com.google.api.gax.httpjson;
3131

3232
import com.google.api.client.http.HttpTransport;
33+
import com.google.api.client.http.javanet.NetHttpTransport;
3334
import com.google.api.core.BetaApi;
3435
import com.google.api.core.InternalExtensionOnly;
3536
import com.google.api.gax.core.ExecutorProvider;
3637
import com.google.api.gax.rpc.FixedHeaderProvider;
3738
import com.google.api.gax.rpc.HeaderProvider;
3839
import com.google.api.gax.rpc.TransportChannel;
3940
import com.google.api.gax.rpc.TransportChannelProvider;
41+
import com.google.api.gax.rpc.mtls.MtlsProvider;
4042
import com.google.auth.Credentials;
43+
import com.google.common.annotations.VisibleForTesting;
4144
import com.google.common.collect.Lists;
4245
import java.io.IOException;
46+
import java.security.GeneralSecurityException;
47+
import java.security.KeyStore;
4348
import java.util.List;
4449
import java.util.Map;
4550
import java.util.concurrent.Executor;
@@ -64,24 +69,28 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan
6469
private final HeaderProvider headerProvider;
6570
private final String endpoint;
6671
private final HttpTransport httpTransport;
72+
private final MtlsProvider mtlsProvider;
6773

6874
private InstantiatingHttpJsonChannelProvider(
6975
Executor executor, HeaderProvider headerProvider, String endpoint) {
7076
this.executor = executor;
7177
this.headerProvider = headerProvider;
7278
this.endpoint = endpoint;
7379
this.httpTransport = null;
80+
this.mtlsProvider = new MtlsProvider();
7481
}
7582

7683
private InstantiatingHttpJsonChannelProvider(
7784
Executor executor,
7885
HeaderProvider headerProvider,
7986
String endpoint,
80-
HttpTransport httpTransport) {
87+
HttpTransport httpTransport,
88+
MtlsProvider mtlsProvider) {
8189
this.executor = executor;
8290
this.headerProvider = headerProvider;
8391
this.endpoint = endpoint;
8492
this.httpTransport = httpTransport;
93+
this.mtlsProvider = mtlsProvider;
8594
}
8695

8796
@Override
@@ -145,7 +154,11 @@ public TransportChannel getTransportChannel() throws IOException {
145154
} else if (needsHeaders()) {
146155
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
147156
} else {
148-
return createChannel();
157+
try {
158+
return createChannel();
159+
} catch (GeneralSecurityException e) {
160+
throw new IOException(e);
161+
}
149162
}
150163
}
151164

@@ -160,20 +173,35 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
160173
"InstantiatingHttpJsonChannelProvider doesn't need credentials");
161174
}
162175

163-
private TransportChannel createChannel() throws IOException {
176+
HttpTransport createHttpTransport() throws IOException, GeneralSecurityException {
177+
if (mtlsProvider.useMtlsClientCertificate()) {
178+
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
179+
if (mtlsKeyStore != null) {
180+
return new NetHttpTransport.Builder().trustCertificates(null, mtlsKeyStore, "").build();
181+
}
182+
}
183+
return null;
184+
}
185+
186+
private TransportChannel createChannel() throws IOException, GeneralSecurityException {
164187
Map<String, String> headers = headerProvider.getHeaders();
165188

166189
List<HttpJsonHeaderEnhancer> headerEnhancers = Lists.newArrayList();
167190
for (Map.Entry<String, String> header : headers.entrySet()) {
168191
headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue()));
169192
}
170193

194+
HttpTransport httpTransportToUse = httpTransport;
195+
if (httpTransportToUse == null) {
196+
httpTransportToUse = createHttpTransport();
197+
}
198+
171199
ManagedHttpJsonChannel channel =
172200
ManagedHttpJsonChannel.newBuilder()
173201
.setEndpoint(endpoint)
174202
.setHeaderEnhancers(headerEnhancers)
175203
.setExecutor(executor)
176-
.setHttpTransport(httpTransport)
204+
.setHttpTransport(httpTransportToUse)
177205
.build();
178206

179207
return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build();
@@ -202,6 +230,7 @@ public static final class Builder {
202230
private HeaderProvider headerProvider;
203231
private String endpoint;
204232
private HttpTransport httpTransport;
233+
private MtlsProvider mtlsProvider = new MtlsProvider();
205234

206235
private Builder() {}
207236

@@ -210,6 +239,7 @@ private Builder(InstantiatingHttpJsonChannelProvider provider) {
210239
this.headerProvider = provider.headerProvider;
211240
this.endpoint = provider.endpoint;
212241
this.httpTransport = provider.httpTransport;
242+
this.mtlsProvider = provider.mtlsProvider;
213243
}
214244

215245
/**
@@ -259,9 +289,15 @@ public String getEndpoint() {
259289
return endpoint;
260290
}
261291

292+
@VisibleForTesting
293+
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
294+
this.mtlsProvider = mtlsProvider;
295+
return this;
296+
}
297+
262298
public InstantiatingHttpJsonChannelProvider build() {
263299
return new InstantiatingHttpJsonChannelProvider(
264-
executor, headerProvider, endpoint, httpTransport);
300+
executor, headerProvider, endpoint, httpTransport, mtlsProvider);
265301
}
266302
}
267303
}

‎gax-httpjson/src/test/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProviderTest.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,23 @@
3232
import static com.google.common.truth.Truth.assertThat;
3333
import static org.junit.Assert.assertEquals;
3434

35+
import com.google.api.gax.rpc.HeaderProvider;
3536
import com.google.api.gax.rpc.TransportChannelProvider;
37+
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
38+
import com.google.api.gax.rpc.mtls.MtlsProvider;
3639
import java.io.IOException;
40+
import java.security.GeneralSecurityException;
3741
import java.util.Collections;
3842
import java.util.concurrent.Executor;
3943
import java.util.concurrent.ScheduledExecutorService;
4044
import java.util.concurrent.ScheduledThreadPoolExecutor;
4145
import org.junit.Test;
4246
import org.junit.runner.RunWith;
4347
import org.junit.runners.JUnit4;
48+
import org.mockito.Mockito;
4449

4550
@RunWith(JUnit4.class)
46-
public class InstantiatingHttpJsonChannelProviderTest {
51+
public class InstantiatingHttpJsonChannelProviderTest extends AbstractMtlsTransportChannelTest {
4752

4853
@Test
4954
public void basicTest() throws IOException {
@@ -94,4 +99,17 @@ public void basicTest() throws IOException {
9499
// Make sure we can create channels OK.
95100
provider.getTransportChannel().shutdownNow();
96101
}
102+
103+
@Override
104+
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
105+
throws IOException, GeneralSecurityException {
106+
InstantiatingHttpJsonChannelProvider channelProvider =
107+
InstantiatingHttpJsonChannelProvider.newBuilder()
108+
.setEndpoint("localhost:8080")
109+
.setMtlsProvider(provider)
110+
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
111+
.setExecutor(Mockito.mock(Executor.class))
112+
.build();
113+
return channelProvider.createHttpTransport();
114+
}
97115
}

‎gax/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ java_library(
5151
srcs = glob(["src/test/java/**/*.java"]),
5252
javacopts = _JAVA_COPTS,
5353
plugins = ["//:auto_value_plugin"],
54+
resources = glob([
55+
"src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json",
56+
"src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem",
57+
]),
5458
visibility = ["//visibility:public"],
5559
deps = [":gax"] + _COMPILE_DEPS + _TEST_COMPILE_DEPS,
5660
)

0 commit comments

Comments
 (0)