
Update timm universal (support transformer-style model) #1004


Merged
merged 14 commits on Dec 19, 2024
Add tests/test_models & fix type
brianhou0208 committed Dec 19, 2024
commit e7bc6e048e7b40aad74d6f9132d1019f7b6bb92b
2 changes: 1 addition & 1 deletion segmentation_models_pytorch/encoders/timm_universal.py
@@ -91,7 +91,7 @@ def __init__(

      # Determine the model's downsampling pattern and set hierarchy flags
      encoder_stage = len(tmp_model.feature_info.reduction())
-     reduction_scales = tmp_model.feature_info.reduction()
+     reduction_scales = list(tmp_model.feature_info.reduction())

      if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
          # Transformer-style downsampling: scales (4, 8, 16, 32)
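For context, a minimal standalone sketch of the reduction-scale check this hunk touches; the `is_transformer_style` helper is hypothetical and assumes a recent timm:

```python
import timm

# Hypothetical helper mirroring the check above: transformer-style encoders
# (e.g. ConvNeXt, Swin) expose features at strides 4, 8, 16, 32, while
# traditional CNNs such as ResNet start at stride 2.
def is_transformer_style(model_name: str) -> bool:
    model = timm.create_model(model_name, features_only=True, pretrained=False)
    reduction_scales = list(model.feature_info.reduction())
    encoder_stage = len(reduction_scales)
    return reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]

print(is_transformer_style("convnext_atto"))  # expected: True
print(is_transformer_style("resnet34"))       # expected: False
```

The `list()` cast added in this commit matters because the comparison above is against a list literal, so a non-list return value from `feature_info.reduction()` would never compare equal.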
24 changes: 11 additions & 13 deletions tests/test_models.py
@@ -13,8 +13,11 @@ def get_encoders():
      ]
      encoders = smp.encoders.get_encoder_names()
      encoders = [e for e in encoders if e not in exclude_encoders]
-     encoders.append("tu-resnet34")  # for timm universal encoder
-     return encoders
+     encoders.append("tu-resnet34")  # for timm universal traditional-like encoder
+     encoders.append("tu-convnext_atto")  # for timm universal transformer-like encoder
+     encoders.append("tu-darknet17")  # for timm universal vgg-like encoder
+     encoders.append("mit_b0")
+     return encoders[-3:]


ENCODERS = get_encoders()
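As a usage note, a minimal sketch of how one of the newly tested encoders is exercised through the public API, assuming this branch is installed; weights are disabled to keep it offline:

```python
import torch
import segmentation_models_pytorch as smp

# The "tu-" prefix routes through TimmUniversalEncoder; convnext_atto is
# the transformer-like case added to the test matrix above.
model = smp.Unet("tu-convnext_atto", encoder_weights=None)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 1, 256, 256])
```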
@@ -80,16 +83,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
          or model_class is smp.MAnet
      ):
          kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
-     if model_class in [smp.UnetPlusPlus, smp.Linknet] and encoder_name.startswith(
-         "mit_b"
-     ):
-         return  # skip mit_b*
-     if (
-         model_class is smp.FPN
-         and encoder_name.startswith("mit_b")
-         and encoder_depth != 5
-     ):
-         return  # skip mit_b*
+     if model_class in [smp.UnetPlusPlus, smp.Linknet]:
+         if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
+             return  # skip transformer-like model*
+     if model_class is smp.FPN and encoder_depth != 5:
+         if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"):
+             return  # skip transformer-like model*
      model = model_class(
          encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
      )
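For reference, the generalized skip rule above can be read as a standalone predicate; this sketch uses a hypothetical helper name:

```python
import segmentation_models_pytorch as smp

# Hypothetical helper restating the skip logic in test_forward:
# transformer-like encoders (mit_b*, tu-convnext*) start at stride 4,
# so decoders that need the full stride-2 feature pyramid are skipped.
def should_skip(model_class, encoder_name: str, encoder_depth: int) -> bool:
    transformer_like = encoder_name.startswith(("mit_b", "tu-convnext"))
    if model_class in (smp.UnetPlusPlus, smp.Linknet) and transformer_like:
        return True
    if model_class is smp.FPN and encoder_depth != 5 and transformer_like:
        return True
    return False

print(should_skip(smp.Linknet, "tu-convnext_atto", 5))  # expected: True
print(should_skip(smp.Unet, "tu-resnet34", 5))          # expected: False
```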
@@ -180,7 +179,6 @@ def test_dilation(encoder_name):
or encoder_name.startswith("vgg")
or encoder_name.startswith("densenet")
or encoder_name.startswith("timm-res")
or encoder_name.startswith("mit_b")
):
return
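As background for this hunk, a minimal sketch of what test_dilation exercises: encoders in the guard above do not support a reduced output stride, while supported ones swap striding for dilation in their deepest stages. The encoder name and shapes here are illustrative:

```python
import torch
import segmentation_models_pytorch as smp

# With output_stride=16 the deepest feature map keeps stride 16 instead
# of 32, so it is twice as large in each spatial dimension.
encoder = smp.encoders.get_encoder("resnet34", weights=None, output_stride=16)
features = encoder(torch.randn(1, 3, 224, 224))
print(features[-1].shape)  # expected: torch.Size([1, 512, 14, 14])
```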
