[RFC] Faster load time for large models #2350

Draft · wants to merge 2 commits into base: main

Changes from 1 commit
override device kwargs of base nn classes
gau-nernst committed Mar 30, 2025
commit d70c4811791cbaaaed74530e05cb0479426f0e14
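For context on the mechanism (a minimal sketch, not part of the diff): constructing a layer on PyTorch's meta device records parameter shapes and dtypes without allocating storage, and the usual random-initialization kernels degenerate to no-ops, so construction cost becomes negligible. Assuming PyTorch >= 2.0, where `torch.device` can be used as a context manager:

```python
import torch
from torch import nn

# Parameters created on the "meta" device carry shape/dtype only:
# no memory is allocated and no random-initialization kernels run.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)  # returns instantly, even for huge layers

print(layer.weight.is_meta)  # True
print(layer.weight.shape)    # torch.Size([4096, 4096])

# Materialize uninitialized storage when pretrained weights will
# overwrite the values anyway:
layer = layer.to_empty(device="cpu")
```

The commit patches `__init__` of a handful of base classes rather than installing a global meta-device context, presumably so that only parameter-holding layers are redirected and any other tensors a model builds during `__init__` keep their requested device.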
50 changes: 45 additions & 5 deletions timm/models/_builder.py
@@ -1,10 +1,12 @@
+from contextlib import contextmanager, nullcontext
 import dataclasses
 import logging
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
+from torch import nn as nn
 from torch.hub import load_state_dict_from_url

@@ -360,6 +362,28 @@ def resolve_pretrained_cfg(
     return pretrained_cfg


+@contextmanager
+def make_meta_init(*classes):
+    def create_new_init(cls):
+        old_init = cls.__init__
+        def new_init(self, *args, **kwargs):
+            kwargs.update(device="meta")
+            old_init(self, *args, **kwargs)
+        return new_init
+
+    original_dict = dict()
+    for cls in classes:
+        original_dict[cls] = cls.__init__
+        cls.__init__ = create_new_init(cls)
+
+    try:
+        yield
+    finally:
+        # restore the original __init__() even if model construction raises
+        for cls, old_init in original_dict.items():
+            cls.__init__ = old_init
+
+
 def build_model_with_cfg(
     model_cls: Callable,
     variant: str,
@@ -419,11 +443,27 @@ def build_model_with_cfg(
     if 'feature_cls' in kwargs:
         feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

+    # use meta-device init to speed up loading pretrained weights.
+    # when num_classes is changed, we can't use meta device init since we need
+    # the original __init__() to initialize head from scratch.
+    num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
+    use_meta_init = (
+        pretrained
+        and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
+    )
+
     # Instantiate the model
-    if model_cfg is None:
-        model = model_cls(**kwargs)
-    else:
-        model = model_cls(cfg=model_cfg, **kwargs)
+    base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
+    with make_meta_init(*base_classes) if use_meta_init else nullcontext():
+        if model_cfg is None:
+            model = model_cls(**kwargs)
+        else:
+            model = model_cls(cfg=model_cfg, **kwargs)
+
+    # convert meta-device tensors to concrete tensors
+    device = kwargs.get("device", torch.get_default_device())
+    model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))

     model.pretrained_cfg = pretrained_cfg
     model.default_cfg = model.pretrained_cfg  # alias for backwards compat

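A usage sketch of the new pieces (assuming this branch, where `make_meta_init` lives in `timm.models._builder`; the toy module and reference state dict are illustrative, not from the PR):

```python
import torch
from torch import nn
from timm.models._builder import make_meta_init

# While the context manager is active, the patched classes build on "meta".
with make_meta_init(nn.Linear, nn.LayerNorm):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.LayerNorm(4096))
assert model[0].weight.is_meta

# Materialize empty storage on the target device, mirroring what
# build_model_with_cfg now does right after instantiation.
device = torch.get_default_device()  # requires a recent PyTorch (>= 2.3)
model._apply(lambda t: torch.empty_like(t, device=device) if t.is_meta else t)

# Loading a state dict then fills in real values; in the PR this is the
# existing load_pretrained() path. A fresh module stands in for a checkpoint here.
reference = nn.Sequential(nn.Linear(1024, 4096), nn.LayerNorm(4096))
model.load_state_dict(reference.state_dict())
```

Because the head is re-initialized from scratch whenever `num_classes` differs from the pretrained config, the builder only enables the context manager when every parameter will be overwritten by the checkpoint.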