Update timm universal (support transformer-style model) #1004


Merged
merged 14 commits into from Dec 19, 2024
Fix ruff style
brianhou0208 committed Dec 18, 2024
commit 330e6e5ae3dfd1708f22231e93e1eba5905e8cd5
46 changes: 23 additions & 23 deletions segmentation_models_pytorch/encoders/timm_universal.py
@@ -1,29 +1,29 @@
 """
-TimmUniversalEncoder provides a unified feature extraction interface built on the 
-`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style 
+TimmUniversalEncoder provides a unified feature extraction interface built on the
+`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
 models (e.g., Swin Transformer, ConvNeXt).
 
-This encoder produces consistent multi-level feature maps for semantic segmentation tasks. 
-It allows configuring the number of feature extraction stages (`depth`) and adjusting 
+This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
+It allows configuring the number of feature extraction stages (`depth`) and adjusting
 `output_stride` when supported.
 
 Key Features:
 - Flexible model selection using `timm.create_model`.
-- Unified multi-level output across different model hierarchies. 
+- Unified multi-level output across different model hierarchies.
 - Automatic alignment for inconsistent feature scales:
-  - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. 
-  - VGG-style models (include scale-1 features): Align outputs for compatibility. 
+  - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
+  - VGG-style models (include scale-1 features): Align outputs for compatibility.
 - Easy access to feature scale information via the `reduction` property.
 
 Feature Scale Differences:
-- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. 
-- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. 
+- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
+- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
 - VGG-style models: Include scale-1 features (input resolution).
 
 Notes:
-- `output_stride` is unsupported in some models, especially transformer-based architectures. 
-- Special handling for models like TResNet and DLA to ensure correct feature indexing. 
-- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs. 
+- `output_stride` is unsupported in some models, especially transformer-based architectures.
+- Special handling for models like TResNet and DLA to ensure correct feature indexing.
+- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
 """
 
 from typing import Any
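
Note: the three scale patterns described in the docstring can be checked directly against timm's `features_only` API. A minimal sketch, not part of the diff; the model names are illustrative and assume a timm version with `features_only` support for Swin:

import timm

for name in ["resnet18", "swin_tiny_patch4_window7_224", "vgg11"]:
    model = timm.create_model(name, features_only=True, pretrained=False)
    # feature_info.reduction() lists the downsampling factor of each stage
    print(name, model.feature_info.reduction())

# Expected patterns per the docstring:
#   resnet18                     -> [2, 4, 8, 16, 32]     traditional-style
#   swin_tiny_patch4_window7_224 -> [4, 8, 16, 32]        transformer-style (no 1/2 scale)
#   vgg11                        -> [1, 2, 4, 8, 16, 32]  VGG-style (includes scale 1)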
@@ -35,7 +35,7 @@
 
 class TimmUniversalEncoder(nn.Module):
     """
-    A universal encoder leveraging the `timm` library for feature extraction from 
+    A universal encoder leveraging the `timm` library for feature extraction from
     various model architectures, including traditional-style and transformer-style models.
 
     Features:
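
For context, the practical effect of this encoder: transformer-style backbones become usable through the usual `tu-` timm prefix. A hedged usage sketch, assuming a timm build that exposes `convnext_tiny` via `features_only`:

import torch
import segmentation_models_pytorch as smp

# ConvNeXt has a stride-4 stem, so it follows the transformer-style pattern
model = smp.Unet(encoder_name="tu-convnext_tiny", encoder_weights=None, classes=2)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 2, 256, 256])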
@@ -92,15 +92,15 @@ def __init__(
         if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
             # Transformer-style downsampling: scales (4, 8, 16, 32)
             self._is_transformer_style = True
-            self._is_skip_first = False
+            self._is_vgg_style = False
         elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
             # Traditional-style downsampling: scales (2, 4, 8, 16, 32)
             self._is_transformer_style = False
-            self._is_skip_first = False
-        elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
-            # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
+            self._is_vgg_style = False
+        elif reduction_scales == [2**i for i in range(encoder_stage)]:
+            # Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
             self._is_transformer_style = False
-            self._is_skip_first = True
+            self._is_vgg_style = True
         else:
             raise ValueError("Unsupported model downsampling pattern.")
 
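The same branch logic, pulled out as a standalone sketch for clarity (the helper name is hypothetical, not part of the PR):

def classify_downsampling(reduction_scales: list[int]) -> str:
    n = len(reduction_scales)
    if reduction_scales == [2 ** (i + 2) for i in range(n)]:
        return "transformer-style"  # (4, 8, 16, 32): no 1/2-scale feature
    if reduction_scales == [2 ** (i + 1) for i in range(n)]:
        return "traditional-style"  # (2, 4, 8, 16, 32)
    if reduction_scales == [2**i for i in range(n)]:
        return "vgg-style"  # (1, 2, 4, 8, 16, 32): includes scale 1
    raise ValueError("Unsupported model downsampling pattern.")

assert classify_downsampling([4, 8, 16, 32]) == "transformer-style"
assert classify_downsampling([2, 4, 8, 16, 32]) == "traditional-style"
assert classify_downsampling([1, 2, 4, 8, 16, 32]) == "vgg-style"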
@@ -125,14 +125,14 @@ def __init__(
         if "dla" in name:
             # For 'dla' models, out_indices starts at 0 and matches the input size.
             common_kwargs["out_indices"] = tuple(range(1, depth + 1))
-        if self._is_skip_first:
+        if self._is_vgg_style:
             common_kwargs["out_indices"] = tuple(range(depth + 1))
 
         self.model = timm.create_model(
             name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
         )
 
-        if self._is_skip_first:
+        if self._is_vgg_style:
             self._out_channels = self.model.feature_info.channels()
         else:
             self._out_channels = [in_channels] + self.model.feature_info.channels()
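
What the `out_indices` shift does for a VGG-style model, as a standalone sketch (assumes timm is installed; `vgg11` is only an example of a scale-1 model):

import timm

depth = 5
# VGG-style: request depth + 1 stages so the scale-1 feature is kept
model = timm.create_model(
    "vgg11", features_only=True, pretrained=False,
    out_indices=tuple(range(depth + 1)),
)
print(model.feature_info.reduction())  # expected: [1, 2, 4, 8, 16, 32]
print(model.feature_info.channels())   # one channel count per stage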
@@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
             B, _, H, W = x.shape
             dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
             features = [dummy] + features
-        # Add input tensor as scale 1 feature if `self._is_skip_first` is False
-        if not self._is_skip_first:
+
+        # Add input tensor as scale 1 feature if `self._is_vgg_style` is False
+        if not self._is_vgg_style:
             features = [x] + features
 
         return features
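
The zero-channel dummy is what keeps transformer-style encoders compatible with decoders that expect a 1/2-scale skip connection: it fills the slot without contributing any activations. A standalone sketch of the shape arithmetic:

import torch

x = torch.randn(2, 3, 224, 224)
B, _, H, W = x.shape
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
print(dummy.shape)  # torch.Size([2, 0, 112, 112])

# Concatenating a 0-channel tensor is a no-op on the channel axis:
skip = torch.cat([dummy, torch.randn(2, 8, 112, 112)], dim=1)
print(skip.shape)   # torch.Size([2, 8, 112, 112])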