Update timm universal (support transformer-style model) #1004


Merged
merged 14 commits into from Dec 19, 2024
Update timm_universal.py
brianhou0208 committed Dec 18, 2024
commit 8b0fece33cfabf89b75c12c6d929b53b3caa692e
133 changes: 79 additions & 54 deletions segmentation_models_pytorch/encoders/timm_universal.py
@@ -1,36 +1,29 @@
"""
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting various backbone architectures, including traditional
CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
(e.g., Swin Transformer, ConvNeXt).
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
models (e.g., Swin Transformer, ConvNeXt).

This encoder produces standardized multi-level feature maps, facilitating integration
with semantic segmentation tasks. It allows configuring the number of feature extraction
stages (`depth`) and adjusting `output_stride` when supported.
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
It allows configuring the number of feature extraction stages (`depth`) and adjusting
`output_stride` when supported.

Key Features:
- Flexible model selection through `timm.create_model`.
- A unified interface that outputs consistent, multi-level features even if the
underlying model differs in its feature hierarchy.
- Automatic alignment: If a model lacks certain early-stage features (for example,
modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder
inserts dummy features to maintain consistency with traditional CNN structures.
- Easy access to channel information: Use the `out_channels` property to retrieve
the number of channels at each feature stage.
- Flexible model selection using `timm.create_model`.
- Unified multi-level output across different model hierarchies.
- Automatic alignment for inconsistent feature scales:
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
- VGG-style models (include scale-1 features): Align outputs for compatibility.
- Easy access to feature scale information via the `reduction` property.

Feature Scale Differences:
- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
and 1/32 scales.
- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
feature. TimmUniversalEncoder compensates for this omission to ensure a unified
multi-stage output.
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
- VGG-style models: Include scale-1 features (input resolution).

Notes:
- Not all models support modifying `output_stride` (especially transformer-based or
transformer-like models).
- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
feature indexing.
- `output_stride` is unsupported in some models, especially transformer-based architectures.
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs.
"""

from typing import Any
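
For orientation, the docstring above implies usage along these lines (a sketch, not part of the diff; the model name, input size, and printed values are illustrative):

    import torch
    from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

    # Traditional-style backbone: features at 1/2, 1/4, 1/8, 1/16, 1/32,
    # with the raw input prepended as the scale-1 entry.
    encoder = TimmUniversalEncoder("resnet34", pretrained=False)
    features = encoder(torch.randn(1, 3, 256, 256))
    print(encoder.out_channels)              # e.g. [3, 64, 64, 128, 256, 512]
    print([f.shape[-1] for f in features])   # e.g. [256, 128, 64, 32, 16, 8]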
@@ -42,14 +35,13 @@

class TimmUniversalEncoder(nn.Module):
"""
-    A universal encoder built on the `timm` library, designed to adapt to a wide variety of
-    model architectures, including both traditional CNNs and those that follow a
-    transformer-like hierarchy.
+    A universal encoder leveraging the `timm` library for feature extraction from
+    various model architectures, including traditional-style and transformer-style models.

    Features:
-    - Supports flexible depth and output stride for feature extraction.
-    - Automatically adjusts to input/output channel structures based on the model type.
-    - Compatible with both convolutional and transformer-like encoders.
+    - Supports configurable depth and output stride.
+    - Ensures consistent multi-level feature extraction across diverse models.
+    - Compatible with convolutional and transformer-like backbones.
"""

def __init__(
@@ -65,15 +57,16 @@ def __init__(
Initialize the encoder.

Args:
-            name (str): Name of the model to be loaded from the `timm` library.
-            pretrained (bool): If True, loads pretrained weights.
+            name (str): Model name to load from `timm`.
+            pretrained (bool): Load pretrained weights (default: True).
            in_channels (int): Number of input channels (default: 3 for RGB).
-            depth (int): Number of feature extraction stages (default: 5).
+            depth (int): Number of feature stages to extract (default: 5).
            output_stride (int): Desired output stride (default: 32).
-            **kwargs: Additional keyword arguments for `timm.create_model`.
+            **kwargs: Additional arguments passed to `timm.create_model`.
"""
super().__init__()

+        # Default model configuration for feature extraction
common_kwargs = dict(
in_chans=in_channels,
features_only=True,
@@ -82,24 +75,37 @@ def __init__(
out_indices=tuple(range(depth)),
)

-        # not all models support output stride argument, drop it by default
+        # Not all models support output stride argument, drop it by default
if output_stride == 32:
common_kwargs.pop("output_stride")

-        # Load a preliminary model to determine its feature hierarchy structure.
+        # Load a temporary model to analyze its feature hierarchy
self.model = timm.create_model(name, features_only=True)
qubvel (Collaborator): Why do we need to load a temporary model? I would try to avoid it if possible.

brianhou0208 (Contributor Author, Dec 18, 2024): I believe that a temporary model is necessary because we need to determine feature_info.reduction() to classify the model as traditional, transformer, or VGG style. This affects the range of out_indices to be used:

    common_kwargs["out_indices"] = tuple(range(depth))

  • If depth == 5, out_indices is
    • traditional-style: (0, 1, 2, 3, 4)
    • transformer-style: (0, 1, 2, 3)
    • vgg-style: (0, 1, 2, 3, 4, 5)
  • If depth == 3, out_indices is
    • traditional-style: (0, 1, 2)
    • transformer-style: (0, 1)
    • vgg-style: (0, 1, 2, 3)

Is there any other way to determine feature_info.reduction() in advance?

qubvel (Collaborator): Can we slice features in forward instead of providing "out_indices"? Otherwise, I would recommend using pretrained=False for the tmp model and maybe initializing it on the meta device to avoid double memory consumption.

brianhou0208 (Contributor Author): In timm.create_model(), the default is pretrained=False. I think initializing the tmp model on torch.device("meta") is good:

    self.model = timm.create_model(name, pretrained=False, features_only=True).to("meta")

What do you think?

qubvel (Collaborator, Dec 18, 2024): Explicit pretrained=False would be nice; for meta it should be something like this:

    with torch.device("meta"):
        tmp_model = timm.create_model(name, pretrained=False, features_only=True)

+ without self.
+ let's name it with tmp_

qubvel (Collaborator): Let's leave it as is for now; it can be optimized later if needed.

brianhou0208 (Contributor Author): If we don't use additional variable names, it shouldn't take up extra memory? I renamed temp_model to self.model, although the variable name will be a little confusing.

qubvel (Collaborator): By "as is" I mean:

    # Load a temporary model to analyze its feature hierarchy
    try:
        with torch.device("meta"):
            tmp_model = timm.create_model(name, features_only=True)
    except Exception:
        tmp_model = timm.create_model(name, features_only=True)

Sorry for the confusion.

qubvel (Collaborator, Dec 18, 2024):

> If we don't use additional variable names, it shouldn't take up extra memory?

I don't think so, we still allocate twice:

  1. we have the tmp model initialized and linked to self.model
  2. we initialize the required model
  3. we unlink the tmp model from the self.model var name and link the required one

Two models exist at a time.

brianhou0208 (Contributor Author, Dec 18, 2024): I think you are right, thanks for your explanation.

-        # Check if the model's output is in channel-last format (B, H, W, C).
+        # Check if model output is in channel-last format (NHWC)
self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC"

-        # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
-        # rather than a traditional CNN hierarchy (starting at 1/2 scale).
-        if len(self.model.feature_info.channels()) == 5:
+        # Determine the model's downsampling pattern and set hierarchy flags
+        encoder_stage = len(self.model.feature_info.reduction())
+        reduction_scales = self.model.feature_info.reduction()
+
+        if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
+            # Transformer-style downsampling: scales (4, 8, 16, 32)
+            self._is_transformer_style = True
+            self._is_skip_first = False
+        elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
+            # Traditional-style downsampling: scales (2, 4, 8, 16, 32)
+            self._is_transformer_style = False
+            self._is_skip_first = False
+        elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
+            # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
+            self._is_transformer_style = False
+            self._is_skip_first = True
        else:
-            self._is_transformer_style = True
+            raise ValueError("Unsupported model downsampling pattern.")

if self._is_transformer_style:
# Transformer-like models (start at scale 4)
if "tresnet" in name:
# 'tresnet' models start feature extraction at stage 1,
# so out_indices=(1, 2, 3, 4) for depth=5.
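
To see the three reduction patterns this classification distinguishes, one can probe feature_info directly (a sketch; the model names are illustrative and the exact lists may vary across timm versions):

    import timm

    for name in ("resnet34", "swin_tiny_patch4_window7_224", "vgg16"):
        tmp = timm.create_model(name, pretrained=False, features_only=True)
        print(name, tmp.feature_info.reduction())

    # Typical output:
    # resnet34                     [2, 4, 8, 16, 32]     traditional-style
    # swin_tiny_patch4_window7_224 [4, 8, 16, 32]        transformer-style
    # vgg16                        [1, 2, 4, 8, 16, 32]  VGG-style (scale 1 included)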
@@ -119,65 +125,84 @@ def __init__(
if "dla" in name:
# For 'dla' models, out_indices starts at 0 and matches the input size.
common_kwargs["out_indices"] = tuple(range(1, depth + 1))
+        if self._is_skip_first:
+            common_kwargs["out_indices"] = tuple(range(depth + 1))

self.model = timm.create_model(
name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
)
-        self._out_channels = [in_channels] + self.model.feature_info.channels()

+        if self._is_skip_first:
+            self._out_channels = self.model.feature_info.channels()
+        else:
+            self._out_channels = [in_channels] + self.model.feature_info.channels()

self._in_channels = in_channels
self._depth = depth
self._output_stride = output_stride

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
-        Pass the input through the encoder and return extracted features.
+        Forward pass to extract multi-stage features.

Args:
x (torch.Tensor): Input tensor of shape (B, C, H, W).

Returns:
-            list[torch.Tensor]: A list of feature maps extracted at various scales.
+            list[torch.Tensor]: List of feature maps at different scales.
"""
features = self.model(x)

+        # Convert NHWC to NCHW if needed
if self._is_channel_last:
-            # Convert to channel-first (B, C, H, W).
features = [
feature.permute(0, 3, 1, 2).contiguous() for feature in features
]

+        # Add dummy feature for scale 1/2 if missing (transformer-style models)
if self._is_transformer_style:
-            # Models using a transformer-like hierarchy may not generate
-            # all expected feature maps. Insert a dummy feature map to ensure
-            # compatibility with decoders expecting a 5-level pyramid.
B, _, H, W = x.shape
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
-            features = [x] + [dummy] + features
-        else:
+            features = [dummy] + features

+        # Add input tensor as scale 1 feature if `self._is_skip_first` is False
+        if not self._is_skip_first:
            features = [x] + features

return features
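
The resulting alignment for a transformer-style backbone, sketched with an illustrative 224x224 input (shapes are indicative, based on Swin-T's 96/192/384/768 channels):

    import torch
    from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

    encoder = TimmUniversalEncoder("swin_tiny_patch4_window7_224", pretrained=False)
    feats = encoder(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # Indicative: input at scale 1, zero-channel dummy at scale 1/2, then 1/4 ... 1/32
    # [(1, 3, 224, 224), (1, 0, 112, 112), (1, 96, 56, 56),
    #  (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]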

@property
def out_channels(self) -> list[int]:
"""
Returns the number of output channels for each feature stage.

Returns:
-            list[int]: A list of output channels for each stage of the encoder,
-                including the input channels at the first stage.
+            list[int]: A list of channel dimensions at each scale.
"""
return self._out_channels

@property
def output_stride(self) -> int:
"""
Returns the effective output stride based on the model depth.

Returns:
-            int: The effective output stride of the encoder, considering the depth.
+            int: The effective output stride.
"""
return min(self._output_stride, 2**self._depth)
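
For instance, an encoder built with depth=3 only reaches the 1/8 scale, so the effective stride is min(32, 2**3) = 8 even though `output_stride` defaults to 32.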


def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
"""
Merge two dictionaries, ensuring no duplicate keys exist.

Args:
a (dict): Base dictionary.
b (dict): Additional parameters to merge.

Returns:
dict: A merged dictionary.
"""
duplicates = a.keys() & b.keys()
if duplicates:
raise ValueError(f"'{duplicates}' already specified internally")