
Update timm universal (support transformer-style model) #1004


Merged
14 commits merged on Dec 19, 2024
Fix ruff style and typing
brianhou0208 committed Dec 7, 2024
commit f07e10782cc6746a551bd9e171eab5bba0f44ae3
16 changes: 7 additions & 9 deletions segmentation_models_pytorch/encoders/timm_universal.py
@@ -1,7 +1,7 @@
 """
 TimmUniversalEncoder provides a unified feature extraction interface built on the
 `timm` library, supporting various backbone architectures, including traditional
-CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
+CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
 (e.g., Swin Transformer, ConvNeXt).

 This encoder produces standardized multi-level feature maps, facilitating integration
@@ -22,16 +22,16 @@
 - Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
 and 1/32 scales.
 - Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
-start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
+start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
 feature. TimmUniversalEncoder compensates for this omission to ensure a unified
 multi-stage output.

 Notes:
-- Not all models support modifying `output_stride` (especially transformer-based or
+- Not all models support modifying `output_stride` (especially transformer-based or
 transformer-like models).
 - Certain models (e.g., TResNet, DLA) require special handling to ensure correct
 feature indexing.
-- Most `timm` models output features in (B, C, H, W) format. However, some
+- Most `timm` models output features in (B, C, H, W) format. However, some
 (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
 currently unsupported.
 """
@@ -46,7 +46,7 @@
 class TimmUniversalEncoder(nn.Module):
 """
 A universal encoder built on the `timm` library, designed to adapt to a wide variety of
-model architectures, including both traditional CNNs and those that follow a
+model architectures, including both traditional CNNs and those that follow a
 transformer-like hierarchy.

 Features:
@@ -94,10 +94,8 @@ def __init__(
 # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
 # rather than a traditional CNN hierarchy (starting at 1/2 scale).
 if len(self.model.feature_info.channels()) == 5:
-# This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
 self._is_transformer_style = False
 else:
-# This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
 self._is_transformer_style = True

 if self._is_transformer_style:
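The detection in the hunk above reduces to a stage count check. A standalone sketch of that logic (the helper name is hypothetical; the real encoder inspects `self.model.feature_info.channels()`):

```python
# Sketch of the stage-count heuristic used by the encoder:
# five reported stages -> traditional CNN hierarchy (1/2 ... 1/32),
# fewer (typically four) -> transformer-like hierarchy (1/4 ... 1/32).
def is_transformer_style(stage_channels):
    return len(stage_channels) != 5


print(is_transformer_style([64, 256, 512, 1024, 2048]))  # False (ResNet-50-like)
print(is_transformer_style([96, 192, 384, 768]))         # True (Swin-T-like)
```

The channel lists here are illustrative; the heuristic depends only on how many stages the backbone reports, not on the channel values themselves.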
@@ -138,7 +136,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
 x (torch.Tensor): Input tensor of shape (B, C, H, W).

 Returns:
-List[torch.Tensor]: A list of feature maps extracted at various scales.
+list[torch.Tensor]: A list of feature maps extracted at various scales.
 """
 features = self.model(x)

@@ -158,7 +156,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
 def out_channels(self) -> list[int]:
 """
 Returns:
-List[int]: A list of output channels for each stage of the encoder,
+list[int]: A list of output channels for each stage of the encoder,
 including the input channels at the first stage.
 """
 return self._out_channels
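The `out_channels` docstring above says the list includes the input channels at the first stage. A hedged sketch of that contract (the helper and values are hypothetical, for illustration only):

```python
# Illustration of the documented out_channels contract: the input channel
# count is prepended to the per-stage feature channels.
def build_out_channels(in_channels, stage_channels):
    return [in_channels] + list(stage_channels)


print(build_out_channels(3, [64, 256, 512, 1024, 2048]))
# [3, 64, 256, 512, 1024, 2048]
```

Decoders can therefore index the list uniformly: entry 0 is the raw input, entries 1..N are the encoder stages.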