TorchModuleWrapper classkeras.layers.TorchModuleWrapper(module, name=None, output_shape=None, **kwargs)
Torch module wrapper layer.
TorchModuleWrapper is a wrapper class that can turn any
torch.nn.Module into a Keras layer, in particular by making its
parameters trackable by Keras.
TorchModuleWrapper is only compatible with the PyTorch backend and
cannot be used with the TensorFlow or JAX backends.
Arguments
torch.nn.Module instance. If it's a LazyModule
instance, then its parameters must be initialized before
passing the instance to TorchModuleWrapper (e.g. by calling
it once).Example
Here's an example of how the TorchModuleWrapper can be used with vanilla
PyTorch modules.
import torch
import torch.nn as nn
import torch.nn.functional as F
import keras
from keras.layers import TorchModuleWrapper
class Classifier(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Wrap `torch.nn.Module`s with `TorchModuleWrapper`
# if they contain parameters
self.conv1 = TorchModuleWrapper(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
)
self.pool = nn.MaxPool2d(kernel_size=(2, 2))
self.flatten = nn.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.fc = TorchModuleWrapper(nn.Linear(1600, 10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("# Output shape", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)