import argparse
import inspect

import torch
from torchvision import utils

from model import Generator
# 512 is used both as the generator's second positional argument and as the
# width of the sampled noise; 8 is the generator's third positional argument.
# NOTE(review): presumably the style dimension and mapping-network depth of
# ``model.Generator`` -- confirm against model.py.
LATENT_DIM = 512
N_MLP = 8


def parse_args(argv=None):
    """Parse the command-line options of the factor-application script.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Apply closed form factorization")

    parser.add_argument(
        "-i", "--index", type=int, default=0, help="index of eigenvector"
    )
    parser.add_argument(
        "-d",
        "--degree",
        type=float,
        default=5,
        help="scalar factors for moving latent vectors along eigenvector",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help='channel multiplier factor. config-f = 2, else = 1',
    )
    parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
    parser.add_argument(
        "--size", type=int, default=256, help="output image size of the generator"
    )
    parser.add_argument(
        "-n", "--n_sample", type=int, default=7, help="number of samples created"
    )
    parser.add_argument(
        "--truncation", type=float, default=0.7, help="truncation factor"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="device to run the model"
    )
    parser.add_argument(
        "--out_prefix",
        type=str,
        default="factor",
        help="filename prefix to result samples",
    )
    parser.add_argument(
        "factor",
        type=str,
        help="name of the closed form factorization result factor file",
    )

    return parser.parse_args(argv)


def _save_grid(images, path, nrow):
    """Write ``images`` to ``path`` as one grid, normalized from (-1, 1).

    Args:
        images: Batch of images to tile.
        path: Destination file name.
        nrow: Number of images per grid row.
    """
    # torchvision renamed make_grid's ``range`` keyword to ``value_range`` and
    # newer releases reject the old spelling, so pick whichever keyword the
    # installed version declares.
    grid_params = inspect.signature(utils.make_grid).parameters
    range_kwarg = "value_range" if "value_range" in grid_params else "range"
    utils.save_image(
        images, path, normalize=True, nrow=nrow, **{range_kwarg: (-1, 1)}
    )


@torch.no_grad()
def main(argv=None):
    """Render samples moved along one factor direction and save them as a grid.

    Loads the eigenvectors from the factor file and the ``g_ema`` weights from
    the checkpoint, samples ``n_sample`` latents, and generates each of them
    unchanged, shifted by ``+degree * eigvec[:, index]`` and shifted by
    ``-degree * eigvec[:, index]``. The three batches are saved to
    ``{out_prefix}_index-{index}_degree-{degree}.png`` with ``n_sample`` images
    per row, in the order shifted-plus, unchanged, shifted-minus.

    Args:
        argv: Argument strings to parse. ``None`` uses ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``--index`` does not address a column of ``eigvec``.
    """
    args = parse_args(argv)

    # NOTE: torch.load unpickles the files it reads -- only pass trusted files.
    # Loading onto the CPU first keeps CUDA-saved files usable when the script
    # runs with --device cpu; the tensors are moved to the target device below.
    eigvec = torch.load(args.factor, map_location="cpu")["eigvec"].to(args.device)
    ckpt = torch.load(args.ckpt, map_location="cpu")

    # Fail with a readable message before the generator is built. Negative
    # indices keep their usual from-the-end meaning.
    n_directions = eigvec.shape[1]
    if not -n_directions <= args.index < n_directions:
        raise SystemExit(
            f"error: --index {args.index} is out of range for "
            f"{n_directions} eigenvectors"
        )

    g = Generator(
        args.size, LATENT_DIM, N_MLP, channel_multiplier=args.channel_multiplier
    ).to(args.device)
    g.load_state_dict(ckpt["g_ema"], strict=False)

    trunc = g.mean_latent(4096)

    latent = torch.randn(args.n_sample, LATENT_DIM, device=args.device)
    latent = g.get_latent(latent)

    direction = args.degree * eigvec[:, args.index].unsqueeze(0)

    def render(codes):
        """Run the generator on latent codes and return only the images."""
        images, _ = g(
            [codes],
            truncation=args.truncation,
            truncation_latent=trunc,
            input_is_latent=True,
        )
        return images

    # Same generation order as before: unchanged, then plus, then minus.
    base = render(latent)
    plus = render(latent + direction)
    minus = render(latent - direction)

    _save_grid(
        torch.cat([plus, base, minus], 0),
        f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
        nrow=args.n_sample,
    )


if __name__ == "__main__":
    main()