Update timm universal (support transformer-style model) #1004

Merged: 14 commits, Dec 19, 2024
Update timm_universal.py
1. Rename the temporary model
2. Create the temporary model on the meta device to speed up initialization
brianhou0208 committed Dec 18, 2024
commit d8ea35f5bf714941baedde3326fe6b0f67496d90
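The second point of the commit message relies on PyTorch's meta device: tensors created under it carry shape and dtype metadata but allocate no storage, so building a throwaway model just to inspect its structure avoids the cost of real weight initialization. A minimal sketch of the idea with a plain `nn.Linear` (no timm required; the `probe` name is only for illustration):

```python
import torch
import torch.nn as nn

# Modules built inside this context get "meta" tensors: shapes and
# dtypes exist, but no memory is allocated and no init kernels run.
with torch.device("meta"):
    probe = nn.Linear(1024, 1024)

# Structural attributes are fully inspectable...
print(probe.weight.shape)   # torch.Size([1024, 1024])
print(probe.weight.device)  # meta
# ...but the tensor has no data, so it cannot be computed with,
# which is why the diff below keeps a non-meta fallback path.
```

Wrapping the creation in `try/except` (as the diff does) matters because some models run real tensor ops during construction and fail on the meta device.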
12 changes: 8 additions & 4 deletions segmentation_models_pytorch/encoders/timm_universal.py
```diff
@@ -80,14 +80,18 @@ def __init__(
         common_kwargs.pop("output_stride")
 
         # Load a temporary model to analyze its feature hierarchy
-        self.model = timm.create_model(name, features_only=True)
+        try:
+            with torch.device("meta"):
+                tmp_model = timm.create_model(name, features_only=True)
+        except Exception:
+            tmp_model = timm.create_model(name, features_only=True)
 
         # Check if model output is in channel-last format (NHWC)
-        self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC"
+        self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC"
 
         # Determine the model's downsampling pattern and set hierarchy flags
-        encoder_stage = len(self.model.feature_info.reduction())
-        reduction_scales = self.model.feature_info.reduction()
+        encoder_stage = len(tmp_model.feature_info.reduction())
+        reduction_scales = tmp_model.feature_info.reduction()
 
         if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
             # Transformer-style downsampling: scales (4, 8, 16, 32)
```
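The hierarchy check at the end of the hunk distinguishes transformer-style encoders (e.g. Swin), which report feature stages at strides 4, 8, 16, 32, from traditional CNN encoders, which report strides starting at 2. The comparison can be sketched in isolation; `is_transformer_style` is a hypothetical helper name used here only for illustration:

```python
def is_transformer_style(reduction_scales):
    # Transformer-style: stage i is downsampled by 2**(i + 2),
    # i.e. [4, 8, 16, 32] for a 4-stage encoder. CNN-style lists
    # such as [2, 4, 8, 16, 32] fail this comparison.
    n = len(reduction_scales)
    return reduction_scales == [2 ** (i + 2) for i in range(n)]

print(is_transformer_style([4, 8, 16, 32]))     # True
print(is_transformer_style([2, 4, 8, 16, 32]))  # False
```

In the actual encoder, `reduction_scales` comes from timm's `feature_info.reduction()`, which lists the cumulative downsampling factor of each returned feature map.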