forked from scotthlee/hamlet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inversion_detection.py
135 lines (118 loc) · 4.82 KB
/
inversion_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''Uses EfficientNet B0 to tell whether image colors are inverted.

Notes:
Training files should be in a subfolder named 'img/'. In our example, the
images for training the model were in 'inv/img'. The arg passed to
--train_img_dir should not include the 'img/' subfolder.
If --model_folder is not specified, the model will be trained from scratch;
otherwise its weights are loaded from that subfolder of the checkpoints
directory.
If training a model from scratch, filenames for images with inverted
colors should be prefixed with 'inv_'.
Images in --check_img_dir that the model classifies as inverted are then
passed to hamlet.tools.image.flip_image.
'''
import numpy as np
import os
import argparse
import tensorflow as tf
from multiprocessing import Pool
from hamlet import models
from hamlet.tools.image import flip_image
# Input resolution fed to the model and used by every data loader.
IMG_HEIGHT = 224
IMG_WIDTH = 224

# Output locations for TensorBoard logs and model checkpoints (relative to
# the current working directory).
LOG_DIR = 'output/inversion/logs/'
CHECK_DIR = 'output/inversion/checkpoints/'

# Extensions that tf.keras's image_dataset_from_directory loads (matched
# case-insensitively). The loader silently skips every other file, so the
# helpers below skip them too; otherwise file names would drift out of
# alignment with dataset elements.
IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: list of argument strings; when None, sys.argv[1:] is used.

    Returns:
        argparse.Namespace with base_dir, train_img_dir, check_img_dir,
        model_folder, and batch_size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir',
                        type=str,
                        default='D:/data/hamlet/',
                        help='directory holding the DICOM files')
    parser.add_argument('--train_img_dir',
                        type=str,
                        default='inv/',
                        help='subfolder in base_dir holding the folder '
                             'that holds the training data')
    parser.add_argument('--check_img_dir',
                        type=str,
                        default='abn_train/',
                        help='subfolder in base_dir holding the folder '
                             'that holds the images to be checked')
    parser.add_argument('--model_folder',
                        type=str,
                        default=None,
                        help='subfolder in checkpoints holding the model file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='minibatch size for the model')
    return parser.parse_args(argv)


def list_images(directory):
    """Return the names of image files directly inside `directory`, sorted.

    The extension filter and the plain sorted() ordering mirror how
    image_dataset_from_directory indexes a flat folder. Position i in the
    result therefore corresponds to element i of an unshuffled dataset built
    from the same folder.
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    )


def inversion_labels(file_names):
    """Return a 0/1 label per file name: 1 when the name contains 'inv_'."""
    # Substring match (not startswith) is kept from the original labelling.
    return [int('inv_' in name) for name in file_names]


def build_model():
    """Build the single-output EfficientNet classifier used for detection."""
    # NOTE(review): the 'EFficientNet' spelling is kept from the original
    # script; confirm it is the name hamlet.models actually exports.
    return models.EFficientNet(num_classes=1,
                               full_model=False,
                               img_height=IMG_HEIGHT,
                               img_width=IMG_WIDTH,
                               augmentation=False,
                               learning_rate=1e-2)


def train(mod, train_dir, batch_size):
    """Fine-tune `mod` on the images in `train_dir`/img/.

    Labels come from the file names (see inversion_labels). 30% of the images
    are held out for validation, using a fixed seed so that the two subsets
    are disjoint.
    """
    img_files = list_images(os.path.join(train_dir, 'img'))
    # image_dataset_from_directory expects explicit labels in the sorted
    # order of the image paths it finds; list_images yields exactly that.
    labels = inversion_labels(img_files)

    loader_kwargs = dict(labels=labels,
                         label_mode='int',
                         validation_split=0.3,
                         seed=2022,
                         image_size=(IMG_HEIGHT, IMG_WIDTH),
                         batch_size=batch_size)
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir, subset='training', **loader_kwargs)
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir, subset='validation', **loader_kwargs)

    # Setting up callbacks and metrics
    tr_callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=1,
                                         restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(filepath=CHECK_DIR + 'training/',
                                           save_weights_only=True,
                                           monitor='val_loss',
                                           save_best_only=True),
        tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR + 'training/'),
    ]

    # Fine-tuning the top layer of the model
    mod.fit(train_ds,
            validation_data=val_ds,
            callbacks=tr_callbacks,
            epochs=20)


def flip_inverted(mod, img_dir, batch_size):
    """Classify every image in `img_dir` and flip the ones predicted inverted.

    Raises:
        RuntimeError: if the loader found a different number of images than
            list_images did (e.g. images in subfolders), because predictions
            could then not be matched to file names.
    """
    files_to_check = list_images(img_dir)
    new_ds = tf.keras.preprocessing.image_dataset_from_directory(
        img_dir,
        labels=None,
        shuffle=False,  # keep dataset order aligned with files_to_check
        image_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=batch_size,
    )
    new_preds = mod.predict(new_ds, verbose=1).flatten().round()
    if len(new_preds) != len(files_to_check):
        raise RuntimeError(
            f'Got {len(new_preds)} predictions for {len(files_to_check)} '
            f'image files in {img_dir}; images in subfolders are not '
            f'supported.')

    to_flip = np.where(new_preds == 1)[0]
    if len(to_flip) == 0:
        return

    # Flipping images the model thinks are inverted
    jobs = [(files_to_check[i], img_dir) for i in to_flip]
    with Pool() as pool:
        # starmap blocks until every job finishes, so leaving the context
        # (which terminates the pool) cannot cut work short.
        pool.starmap(flip_image, jobs)


def main(argv=None):
    """Train or load the inversion detector, then flip inverted images."""
    args = parse_args(argv)

    # os.path.join tolerates base_dir with or without a trailing slash.
    train_dir = os.path.join(args.base_dir, args.train_img_dir)
    img_dir = os.path.join(args.base_dir, args.check_img_dir)

    mod = build_model()
    if not args.model_folder:
        train(mod, train_dir, args.batch_size)
    else:
        mod.load_weights(os.path.join(CHECK_DIR, args.model_folder))

    flip_inverted(mod, img_dir, args.batch_size)


if __name__ == '__main__':
    main()