@@ -137,10 +137,10 @@ def discriminator_fill_statedict(statedict, vars, size):
137137 return statedict
138138
139139
140- def fill_statedict (state_dict , vars , size ):
140+ def fill_statedict (state_dict , vars , size , n_mlp ):
141141 log_size = int (math .log (size , 2 ))
142142
143- for i in range (8 ):
143+ for i in range (n_mlp ):
144144 update (state_dict , convert_dense (vars , f"G_mapping/Dense{ i } " , f"style.{ i + 1 } " ))
145145
146146 update (
@@ -237,9 +237,15 @@ def fill_statedict(state_dict, vars, size):
237237
238238 size = g_ema .output_shape [2 ]
239239
240- g = Generator (size , 512 , 8 , channel_multiplier = args .channel_multiplier )
240+ n_mlp = 0
241+ mapping_layers_names = g_ema .__getstate__ ()['components' ]['mapping' ].list_layers ()
242+ for layer in mapping_layers_names :
243+ if layer [0 ].startswith ('Dense' ):
244+ n_mlp += 1
245+
246+ g = Generator (size , 512 , n_mlp , channel_multiplier = args .channel_multiplier )
241247 state_dict = g .state_dict ()
242- state_dict = fill_statedict (state_dict , g_ema .vars , size )
248+ state_dict = fill_statedict (state_dict , g_ema .vars , size , n_mlp )
243249
244250 g .load_state_dict (state_dict )
245251
@@ -248,7 +254,7 @@ def fill_statedict(state_dict, vars, size):
248254 ckpt = {"g_ema" : state_dict , "latent_avg" : latent_avg }
249255
250256 if args .gen :
251- g_train = Generator (size , 512 , 8 , channel_multiplier = args .channel_multiplier )
257+ g_train = Generator (size , 512 , n_mlp , channel_multiplier = args .channel_multiplier )
252258 g_train_state = g_train .state_dict ()
253259 g_train_state = fill_statedict (g_train_state , generator .vars , size )
254260 ckpt ["g" ] = g_train_state
@@ -292,5 +298,4 @@ def fill_statedict(state_dict, vars, size):
292298
293299 utils .save_image (
294300 img_concat , name + ".png" , nrow = n_sample , normalize = True , range = (- 1 , 1 )
295- )
296-
301+ )
0 commit comments