|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Define Configuration for StarGAN"""
- import argparse
-
-
def get_config(args=None):
    """Build the StarGAN command-line parser and return the parsed options.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv[1:]``,
            which is what happens when the function is called with no
            arguments.

    Returns:
        argparse.Namespace: one attribute per option registered below.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description='StarGAN')

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
    parser.add_argument('--c2_dim', type=int, default=7, help='dimension of domain labels (2nd dataset)')
    parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
    parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
    parser.add_argument('--image_size', type=int, default=128, help='image resolution')
    parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
    parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
    parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
    parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
    parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
    parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
    parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')

    # Training configuration.
    parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
    parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size')
    parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
    parser.add_argument('--epochs', type=int, default=59, help='number of epoch')
    parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
    parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--resume_iters', type=int, default=200000, help='resume training from this step')
    # '--list' is an alias; the value is stored under 'selected_attrs'.
    parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                        default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
    parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"),
                        help='network initialization, default is normal.')
    parser.add_argument('--init_gain', type=float, default=0.02,
                        help='scaling factor for normal, xavier and orthogonal, default is 0.02.')

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

    # Train Device.
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--device_target', type=str, default='Ascend')
    parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
    parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
    # Help text previously claimed "default: 0" although the default is 1.
    parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 1.")
    parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")

    # Directories.
    parser.add_argument('--celeba_image_dir', type=str, default=r'./dataset/img_celeba')
    parser.add_argument('--attr_path', type=str, default=r'./dataset/list_attr_celeba.txt')
    parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
    parser.add_argument('--log_dir', type=str, default='stargan/logs')
    parser.add_argument('--model_save_dir', type=str, default='./models/')
    parser.add_argument('--result_dir', type=str, default='./results')

    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=5000)
    parser.add_argument('--model_save_step', type=int, default=5000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    # export
    parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR',
                        help='file format')

    # args=None keeps the original behaviour of parsing sys.argv[1:].
    return parser.parse_args(args)
|