forked from babbu3682/MTD-GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
82 lines (67 loc) · 2.37 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import torch.nn.functional as F
# Create Model
from arch.RED_CNN.networks import RED_CNN
from arch.EDCNN.networks import EDCNN
from arch.CTformer.networks import CTformer
from arch.Restormer.networks import Restormer
from arch.Diffusion.networks import DDPM, DDIM, PNDM, DPM
from arch.WGAN_VGG.networks import WGAN_VGG
from arch.MAP_NN.networks import MAP_NN
from arch.DUGAN.networks import DUGAN
from arch.Ours.networks import *
def get_model(name):
    """Build the network registered under ``name`` and report its size.

    Before returning, prints the number of trainable parameters (those with
    ``requires_grad`` set) of the constructed model.

    Args:
        name: Model identifier, e.g. ``"RED_CNN"``, ``"CTformer"``,
            ``"MTD_GAN_Method"`` or one of the ``"Ablation_*"`` variants.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``name`` is not a known model identifier.
    """
    # Each entry is a zero-argument lambda so a constructor name is only
    # looked up when that model is actually requested. This keeps the lazy
    # behaviour of the original if/elif chain; several classes come in via
    # `from arch.Ours.networks import *`.
    factories = {
        # CNN-based models
        "RED_CNN": lambda: RED_CNN(),
        "EDCNN": lambda: EDCNN(),
        # TR-based models
        "CTformer": lambda: CTformer(
            img_size=64,
            tokens_type='performer',
            embed_dim=64,
            depth=1,
            num_heads=8,
            kernel=4,
            stride=4,
            mlp_ratio=2.,
            token_dim=64,
        ),
        "Restormer": lambda: Restormer(LayerNorm_type='BiasFree'),
        # GAN-based models
        "WGAN_VGG": lambda: WGAN_VGG(),
        # The "_brain" names are aliases: they build the same network.
        "MAP_NN": lambda: MAP_NN(),
        "MAP_NN_brain": lambda: MAP_NN(),
        "DU_GAN": lambda: DUGAN(),
        "DU_GAN_brain": lambda: DUGAN(),
        # DN-based models
        "DDPM": lambda: DDPM(),
        "DDIM": lambda: DDIM(),
        "PNDM": lambda: PNDM(),
        "DPM": lambda: DPM(),
        # Ours
        "MTD_GAN_Method": lambda: MTD_GAN_Method(),
        # Ablation studies
        "Ablation_CLS": lambda: Ablation_CLS(),
        "Ablation_SEG": lambda: Ablation_SEG(),
        "Ablation_CLS_SEG": lambda: Ablation_CLS_SEG(),
        "Ablation_CLS_REC": lambda: Ablation_CLS_REC(),
        "Ablation_SEG_REC": lambda: Ablation_SEG_REC(),
        "Ablation_CLS_SEG_REC": lambda: Ablation_CLS_SEG_REC(),
        "Ablation_CLS_SEG_REC_NDS": lambda: Ablation_CLS_SEG_REC_NDS(),
        "Ablation_CLS_SEG_REC_RC": lambda: Ablation_CLS_SEG_REC_RC(),
        "Ablation_CLS_SEG_REC_NDS_RC": lambda: Ablation_CLS_SEG_REC_NDS_RC(),
        "Ablation_CLS_SEG_REC_NDS_RC_ResFFT": lambda: Ablation_CLS_SEG_REC_NDS_RC_ResFFT(),
    }

    try:
        factory = factories[name]
    except KeyError:
        # Previously an unknown name left `model` unbound and crashed later
        # with an UnboundLocalError; fail early with an actionable message.
        raise ValueError(
            f"Unknown model name {name!r}; expected one of: "
            + ", ".join(sorted(factories))
        ) from None

    model = factory()

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of Learnable Params:', n_parameters)
    return model