|
- import argparse
- import os
- import sys
- import pickle
- import math
-
- import torch
- import numpy as np
- from torchvision import utils
-
- from model import Generator, Discriminator
-
-
- def convert_modconv(vars, source_name, target_name, flip=False):
- weight = vars[source_name + "/weight"].value().eval()
- mod_weight = vars[source_name + "/mod_weight"].value().eval()
- mod_bias = vars[source_name + "/mod_bias"].value().eval()
- noise = vars[source_name + "/noise_strength"].value().eval()
- bias = vars[source_name + "/bias"].value().eval()
-
- dic = {
- "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
- "conv.modulation.weight": mod_weight.transpose((1, 0)),
- "conv.modulation.bias": mod_bias + 1,
- "noise.weight": np.array([noise]),
- "activate.bias": bias,
- }
-
- dic_torch = {}
-
- for k, v in dic.items():
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
- if flip:
- dic_torch[target_name + ".conv.weight"] = torch.flip(
- dic_torch[target_name + ".conv.weight"], [3, 4]
- )
-
- return dic_torch
-
-
- def convert_conv(vars, source_name, target_name, bias=True, start=0):
- weight = vars[source_name + "/weight"].value().eval()
-
- dic = {"weight": weight.transpose((3, 2, 0, 1))}
-
- if bias:
- dic["bias"] = vars[source_name + "/bias"].value().eval()
-
- dic_torch = {}
-
- dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
-
- if bias:
- dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
-
- return dic_torch
-
-
- def convert_torgb(vars, source_name, target_name):
- weight = vars[source_name + "/weight"].value().eval()
- mod_weight = vars[source_name + "/mod_weight"].value().eval()
- mod_bias = vars[source_name + "/mod_bias"].value().eval()
- bias = vars[source_name + "/bias"].value().eval()
-
- dic = {
- "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
- "conv.modulation.weight": mod_weight.transpose((1, 0)),
- "conv.modulation.bias": mod_bias + 1,
- "bias": bias.reshape((1, 3, 1, 1)),
- }
-
- dic_torch = {}
-
- for k, v in dic.items():
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
- return dic_torch
-
-
- def convert_dense(vars, source_name, target_name):
- weight = vars[source_name + "/weight"].value().eval()
- bias = vars[source_name + "/bias"].value().eval()
-
- dic = {"weight": weight.transpose((1, 0)), "bias": bias}
-
- dic_torch = {}
-
- for k, v in dic.items():
- dic_torch[target_name + "." + k] = torch.from_numpy(v)
-
- return dic_torch
-
-
- def update(state_dict, new):
- for k, v in new.items():
- if k not in state_dict:
- raise KeyError(k + " is not found")
-
- if v.shape != state_dict[k].shape:
- raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
-
- state_dict[k] = v
-
-
- def discriminator_fill_statedict(statedict, vars, size):
- log_size = int(math.log(size, 2))
-
- update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
-
- conv_i = 1
-
- for i in range(log_size - 2, 0, -1):
- reso = 4 * 2 ** i
- update(
- statedict,
- convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
- )
- update(
- statedict,
- convert_conv(
- vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
- ),
- )
- update(
- statedict,
- convert_conv(
- vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
- ),
- )
- conv_i += 1
-
- update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
- update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
- update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
-
- return statedict
-
-
- def fill_statedict(state_dict, vars, size, n_mlp):
- log_size = int(math.log(size, 2))
-
- for i in range(n_mlp):
- update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
-
- update(
- state_dict,
- {
- "input.input": torch.from_numpy(
- vars["G_synthesis/4x4/Const/const"].value().eval()
- )
- },
- )
-
- update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
-
- for i in range(log_size - 2):
- reso = 4 * 2 ** (i + 1)
- update(
- state_dict,
- convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
- )
-
- update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
-
- conv_i = 0
-
- for i in range(log_size - 2):
- reso = 4 * 2 ** (i + 1)
- update(
- state_dict,
- convert_modconv(
- vars,
- f"G_synthesis/{reso}x{reso}/Conv0_up",
- f"convs.{conv_i}",
- flip=True,
- ),
- )
- update(
- state_dict,
- convert_modconv(
- vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
- ),
- )
- conv_i += 2
-
- for i in range(0, (log_size - 2) * 2 + 1):
- update(
- state_dict,
- {
- f"noises.noise_{i}": torch.from_numpy(
- vars[f"G_synthesis/noise{i}"].value().eval()
- )
- },
- )
-
- return state_dict
-
-
- if __name__ == "__main__":
- device = "cuda"
-
- parser = argparse.ArgumentParser(
- description="Tensorflow to pytorch model checkpoint converter"
- )
- parser.add_argument(
- "--repo",
- type=str,
- required=True,
- help="path to the offical StyleGAN2 repository with dnnlib/ folder",
- )
- parser.add_argument(
- "--gen", action="store_true", help="convert the generator weights"
- )
- parser.add_argument(
- "--disc", action="store_true", help="convert the discriminator weights"
- )
- parser.add_argument(
- "--channel_multiplier",
- type=int,
- default=2,
- help="channel multiplier factor. config-f = 2, else = 1",
- )
- parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
-
- args = parser.parse_args()
-
- sys.path.append(args.repo)
-
- import dnnlib
- from dnnlib import tflib
-
- tflib.init_tf()
-
- with open(args.path, "rb") as f:
- generator, discriminator, g_ema = pickle.load(f)
-
- size = g_ema.output_shape[2]
-
- n_mlp = 0
- mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
- for layer in mapping_layers_names:
- if layer[0].startswith('Dense'):
- n_mlp += 1
-
- g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
- state_dict = g.state_dict()
- state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
-
- g.load_state_dict(state_dict)
-
- latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
-
- ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
-
- if args.gen:
- g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
- g_train_state = g_train.state_dict()
- g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
- ckpt["g"] = g_train_state
-
- if args.disc:
- disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
- d_state = disc.state_dict()
- d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
- ckpt["d"] = d_state
-
- name = os.path.splitext(os.path.basename(args.path))[0]
- torch.save(ckpt, name + ".pt")
-
- batch_size = {256: 16, 512: 9, 1024: 4}
- n_sample = batch_size.get(size, 25)
-
- g = g.to(device)
-
- z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
-
- with torch.no_grad():
- img_pt, _ = g(
- [torch.from_numpy(z).to(device)],
- truncation=0.5,
- truncation_latent=latent_avg.to(device),
- randomize_noise=False,
- )
-
- Gs_kwargs = dnnlib.EasyDict()
- Gs_kwargs.randomize_noise = False
- img_tf = g_ema.run(z, None, **Gs_kwargs)
- img_tf = torch.from_numpy(img_tf).to(device)
-
- img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
- 0.0, 1.0
- )
-
- img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
-
- print(img_diff.abs().max())
-
- utils.save_image(
- img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
- )
|