-
Notifications
You must be signed in to change notification settings - Fork 546
/
main.py
111 lines (96 loc) · 3.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse, os, torch
from GAN import GAN
from CGAN import CGAN
from LSGAN import LSGAN
from DRAGAN import DRAGAN
from ACGAN import ACGAN
from WGAN import WGAN
from WGAN_GP import WGAN_GP
from infoGAN import infoGAN
from EBGAN import EBGAN
from BEGAN import BEGAN
def str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including ``"False"``), so boolean options could never be
    switched off.  This converter accepts the usual spellings instead.

    Args:
        value: The raw command-line string (a bool is passed through unchanged,
            which is what argparse hands over for non-string defaults).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


def parse_args(argv=None):
    """Parse the command-line configuration and pass it through ``check_args``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    desc = "Pytorch implementation of GAN collections"
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument('--gan_type', type=str, default='GAN',
                        choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN'],
                        help='The type of GAN')
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed'],
                        help='The name of dataset')
    parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
    parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
    parser.add_argument('--input_size', type=int, default=28, help='The size of input image')
    parser.add_argument('--save_dir', type=str, default='models',
                        help='Directory name to save the model')
    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
    parser.add_argument('--lrG', type=float, default=0.0002)
    parser.add_argument('--lrD', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    # type=bool would treat "--gpu_mode False" as True; str2bool parses it properly.
    parser.add_argument('--gpu_mode', type=str2bool, default=True,
                        help='Enable GPU mode (true/false)')
    parser.add_argument('--benchmark_mode', type=str2bool, default=True,
                        help='Enable cuDNN benchmark mode (true/false)')

    return check_args(parser.parse_args(argv))
def check_args(args):
    """Validate the parsed arguments and create the output directories.

    Args:
        args: Namespace with ``save_dir``, ``result_dir``, ``log_dir``,
            ``epoch`` and ``batch_size`` attributes.

    Returns:
        ``args`` when every check passes. Otherwise ``None``, after printing
        each failure message, so the caller can abort. Directories are only
        created for valid arguments.
    """
    errors = []
    # --epoch (explicit checks: `assert` is stripped under `python -O`)
    if args.epoch < 1:
        errors.append('number of epochs must be larger than or equal to one')
    # --batch_size
    if args.batch_size < 1:
        errors.append('batch size must be larger than or equal to one')
    if errors:
        for message in errors:
            print(message)
        return None

    # --save_dir, --result_dir, --log_dir
    # exist_ok=True avoids the race between an exists() check and makedirs().
    for directory in (args.save_dir, args.result_dir, args.log_dir):
        os.makedirs(directory, exist_ok=True)

    return args
def main():
    """Entry point: parse arguments, build the requested GAN, train it, then
    visualize the learned generator.

    Raises:
        ValueError: If ``args.gan_type`` has no matching constructor (a
            subclass of the plain ``Exception`` raised previously, so existing
            ``except Exception`` handlers still catch it).
    """
    import sys

    # parse arguments
    args = parse_args()
    if args is None:
        # sys.exit always exists (the exit() builtin depends on the site module);
        # a non-zero status reports the failure to the shell.
        sys.exit(1)

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN: map each --gan_type choice to a constructor of `args`
    gan_factories = {
        'GAN': GAN,
        'CGAN': CGAN,
        'ACGAN': ACGAN,
        'infoGAN': lambda a: infoGAN(a, SUPERVISED=False),
        'EBGAN': EBGAN,
        'WGAN': WGAN,
        'WGAN_GP': WGAN_GP,
        'DRAGAN': DRAGAN,
        'LSGAN': LSGAN,
        'BEGAN': BEGAN,
    }
    factory = gan_factories.get(args.gan_type)
    if factory is None:
        raise ValueError("[!] There is no option for " + args.gan_type)
    gan = factory(args)

    # train the model
    gan.train()
    print(" [*] Training finished!")

    # visualize learned generator
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()