Update timm universal (support transformer-style model) #1004

Merged
merged 14 commits on Dec 19, 2024
Update timm_universal.py
brianhou0208 committed Dec 7, 2024
commit 363a361cec50c703a009257d9da121cbbedf01a8
140 changes: 131 additions & 9 deletions segmentation_models_pytorch/encoders/timm_universal.py
@@ -1,10 +1,60 @@
"""
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting various backbone architectures, including traditional
CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
(e.g., Swin Transformer, ConvNeXt).

This encoder produces standardized multi-level feature maps, facilitating integration
with semantic segmentation tasks. It allows configuring the number of feature extraction
stages (`depth`) and adjusting `output_stride` when supported.

Key Features:
- Flexible model selection through `timm.create_model`.
- A unified interface that outputs consistent, multi-level features even if the
underlying model differs in its feature hierarchy.
- Automatic alignment: If a model lacks certain early-stage features (for example,
modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder
inserts dummy features to maintain consistency with traditional CNN structures.
- Easy access to channel information: Use the `out_channels` property to retrieve
the number of channels at each feature stage.

Feature Scale Differences:
- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
and 1/32 scales.
- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
feature. TimmUniversalEncoder compensates for this omission to ensure a unified
multi-stage output.

Notes:
- Not all models support modifying `output_stride` (especially transformer-based or
transformer-like models).
- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
feature indexing.
- Most `timm` models output features in (B, C, H, W) format. However, some
(e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
currently unsupported.
"""

from typing import Any

import timm
import torch
import torch.nn as nn


class TimmUniversalEncoder(nn.Module):
"""
A universal encoder built on the `timm` library, designed to adapt to a wide variety of
model architectures, including both traditional CNNs and those that follow a
transformer-like hierarchy.

Features:
- Supports flexible depth and output stride for feature extraction.
- Automatically adjusts to input/output channel structures based on the model type.
- Compatible with both convolutional and transformer-like encoders.
"""

    def __init__(
        self,
        name: str,
@@ -14,7 +64,19 @@ def __init__(
        output_stride: int = 32,
        **kwargs: dict[str, Any],
    ):
"""
Initialize the encoder.

Args:
name (str): Name of the model to be loaded from the `timm` library.
pretrained (bool): If True, loads pretrained weights.
in_channels (int): Number of input channels (default: 3 for RGB).
depth (int): Number of feature extraction stages (default: 5).
output_stride (int): Desired output stride (default: 32).
**kwargs: Additional keyword arguments for `timm.create_model`.
"""
super().__init__()

        common_kwargs = dict(
            in_chans=in_channels,
            features_only=True,
@@ -23,30 +85,90 @@ def __init__(
            out_indices=tuple(range(depth)),
        )

        # not all models support output stride argument, drop it by default
        if output_stride == 32:
            common_kwargs.pop("output_stride")

        # Load a preliminary model to determine its feature hierarchy structure.
        self.model = timm.create_model(name, features_only=True)
Collaborator (qubvel):

Why do we need to load a temporary model? I would try to avoid it if possible.

Contributor Author (brianhou0208), Dec 18, 2024:

I believe a temporary model is necessary because we need feature_info.reduction() to classify the model as traditional-, transformer-, or VGG-style. This affects the range of out_indices to be used:

    common_kwargs["out_indices"] = tuple(range(depth))

  • If depth == 5, out_indices is
    • traditional-style: (0, 1, 2, 3, 4)
    • transformer-style: (0, 1, 2, 3)
    • vgg-style: (0, 1, 2, 3, 4, 5)
  • If depth == 3, out_indices is
    • traditional-style: (0, 1, 2)
    • transformer-style: (0, 1)
    • vgg-style: (0, 1, 2, 3)

Is there any other way to determine feature_info.reduction() in advance?

Collaborator (qubvel):

Can we slice features in forward instead of providing out_indices? Otherwise, I would recommend using pretrained=False for the tmp model, and maybe initializing it on the meta device to avoid double memory consumption.

Contributor Author (brianhou0208):

In timm.create_model(), the default is already pretrained=False.
I think initializing the tmp model on torch.device("meta") is good:

    self.model = timm.create_model(name, pretrained=False, features_only=True).to("meta")

What do you think?

Collaborator (qubvel), Dec 18, 2024:

An explicit pretrained=False would be nice; for meta, it should be something like this:

    with torch.device("meta"):
        tmp_model = timm.create_model(name, pretrained=False, features_only=True)

+ without self.
+ let's name it with tmp_

Collaborator (qubvel):

Let's leave it as is for now; it can be optimized later if needed.

Contributor Author (brianhou0208):

If we don't use additional variable names, it shouldn't take up extra memory? I renamed temp_model to self.model, although the variable name will be a little confusing.

Collaborator (qubvel):

By "as is" I mean:

    # Load a temporary model to analyze its feature hierarchy
    try:
        with torch.device("meta"):
            tmp_model = timm.create_model(name, features_only=True)
    except Exception:
        tmp_model = timm.create_model(name, features_only=True)

Sorry for the confusion.

Collaborator (qubvel), Dec 18, 2024:

> If we don't use additional variable names, it shouldn't take up extra memory?

I don't think so; we still allocate twice:

  1. the tmp model is initialized and linked to self.model
  2. we initialize the required model
  3. we unlink the tmp model from the self.model variable name and link the required one

Two models exist at a time.

Contributor Author (brianhou0208), Dec 18, 2024:

I think you are right, thanks for your explanation.
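
Tying the thread together, here is a hedged sketch of the probe-then-configure pattern discussed above (model name and depth are illustrative; the meta-device context manager needs a recent PyTorch, and the fallback covers models that fail to build on it):

    import timm
    import torch

    name, depth = "convnext_tiny", 5  # illustrative values

    # Probe on the meta device when possible: module structure and
    # feature_info are built, but no parameter memory is allocated.
    try:
        with torch.device("meta"):
            tmp_model = timm.create_model(name, pretrained=False, features_only=True)
    except Exception:
        tmp_model = timm.create_model(name, pretrained=False, features_only=True)

    # Five stages means a traditional (1/2-first) hierarchy; anything else is
    # treated as transformer-style, mirroring the check used in this PR.
    is_transformer_style = len(tmp_model.feature_info.channels()) != 5
    out_indices = tuple(range(depth - 1)) if is_transformer_style else tuple(range(depth))
    print(tmp_model.feature_info.reduction(), out_indices)  # [4, 8, 16, 32] (0, 1, 2, 3)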


        # Determine if this model uses a transformer-like hierarchy (i.e., starting
        # at 1/4 scale) rather than a traditional CNN hierarchy (starting at 1/2 scale).
        if len(self.model.feature_info.channels()) == 5:
            # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
            self._is_transformer_style = False
        else:
            # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
            self._is_transformer_style = True

        if self._is_transformer_style:
            if "tresnet" in name:
                # 'tresnet' models start feature extraction at stage 1,
                # so out_indices=(1, 2, 3, 4) for depth=5.
                common_kwargs["out_indices"] = tuple(range(1, depth))
            else:
                # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
                common_kwargs["out_indices"] = tuple(range(depth - 1))

            self.model = timm.create_model(
                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
            )
            # Add a dummy output channel (0) to align with traditional encoder structures.
            self._out_channels = [in_channels] + [0] + self.model.feature_info.channels()
        else:
            if "dla" in name:
                # For 'dla' models, out_indices starts at 0 and matches the input size.
                kwargs["out_indices"] = tuple(range(1, depth + 1))

            self.model = timm.create_model(
                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
            )
            self._out_channels = [in_channels] + self.model.feature_info.channels()

        self._in_channels = in_channels
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Pass the input through the encoder and return extracted features.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            list[torch.Tensor]: A list of feature maps extracted at various scales.
        """
        features = self.model(x)

        if self._is_transformer_style:
            # Models using a transformer-like hierarchy may not generate
            # all expected feature maps. Insert a dummy feature map to ensure
            # compatibility with decoders expecting a 5-level pyramid.
            B, _, H, W = x.shape
            dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
            features = [x] + [dummy] + features
        else:
            features = [x] + features

        return features

    @property
    def out_channels(self) -> list[int]:
        """
        Returns:
            list[int]: A list of output channels for each stage of the encoder,
            including the input channels at the first stage.
        """
        return self._out_channels

    @property
    def output_stride(self) -> int:
        """
        Returns:
            int: The effective output stride of the encoder, capped by the
            requested depth, i.e. min(output_stride, 2**depth).
        """
        return min(self._output_stride, 2**self._depth)
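
For reference, a usage sketch of the encoder end to end, assuming convnext_tiny is available in the installed timm (expected values shown for a 224×224 input):

    import torch

    # Transformer-style backbone: features start at 1/4 scale, so a
    # zero-channel dummy stage is inserted at the 1/2 position.
    encoder = TimmUniversalEncoder("convnext_tiny", pretrained=False)

    x = torch.randn(1, 3, 224, 224)
    features = encoder(x)

    print(encoder.out_channels)  # [3, 0, 96, 192, 384, 768]
    print([tuple(f.shape) for f in features])
    # [(1, 3, 224, 224), (1, 0, 112, 112), (1, 96, 56, 56),
    #  (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
    print(encoder.output_stride)  # min(32, 2**5) == 32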

