Skip to content

Commit cc50f50

Browse files
authored
Added counting of Dense mapping layer
Counting Dense mapping layer and convert generator and g_ema accordingly
1 parent 0ce34eb commit cc50f50

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

‎convert_weight.py‎

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ def discriminator_fill_statedict(statedict, vars, size):
137137
return statedict
138138

139139

140-
def fill_statedict(state_dict, vars, size):
140+
def fill_statedict(state_dict, vars, size, n_mlp):
141141
log_size = int(math.log(size, 2))
142142

143-
for i in range(8):
143+
for i in range(n_mlp):
144144
update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
145145

146146
update(
@@ -237,9 +237,15 @@ def fill_statedict(state_dict, vars, size):
237237

238238
size = g_ema.output_shape[2]
239239

240-
g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
240+
n_mlp = 0
241+
mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
242+
for layer in mapping_layers_names:
243+
if layer[0].startswith('Dense'):
244+
n_mlp += 1
245+
246+
g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
241247
state_dict = g.state_dict()
242-
state_dict = fill_statedict(state_dict, g_ema.vars, size)
248+
state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
243249

244250
g.load_state_dict(state_dict)
245251

@@ -248,7 +254,7 @@ def fill_statedict(state_dict, vars, size):
248254
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
249255

250256
if args.gen:
251-
g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
257+
g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
252258
g_train_state = g_train.state_dict()
253259
g_train_state = fill_statedict(g_train_state, generator.vars, size)
254260
ckpt["g"] = g_train_state
@@ -292,5 +298,4 @@ def fill_statedict(state_dict, vars, size):
292298

293299
utils.save_image(
294300
img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
295-
)
296-
301+
)

0 commit comments

Comments
 (0)