This repository has been archived by the owner on Sep 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
model.py
68 lines (60 loc) · 2.22 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from config import torch
import torch.nn as nn
import torchvision.models as models
class FireFinder(nn.Module):
    """
    Classify aerial/satellite images as showing fire or not.

    A torchvision backbone is loaded with ``pretrained=True`` and its final
    classification layer is replaced by a new head producing ``n_classes``
    logits. With the default ``n_classes=2`` the labels are:

        0 - no Fire
        1 - Fire

    The default backbone is resnet18; resnet34, resnet50, resnet101 and
    efficientnet_b0 are also accepted.

    Args:
        backbone: Name of the torchvision backbone to load.
        simple: If True, the new head is a single ``nn.Linear``. If False,
            ResNet backbones get a Linear(256) -> ReLU -> Dropout -> Linear
            head with Xavier-uniform weights. efficientnet_b0 always gets a
            single Linear head regardless of this flag.
        dropout: Dropout probability used by the non-simple head.
        n_classes: Number of output logits.
        feature_extractor: If True, freeze every pretrained backbone
            parameter so only the new head is trainable. If False
            (fine-tuning mode), only ``nn.BatchNorm2d`` parameters are frozen.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(
        self,
        backbone="resnet18",
        simple=True,
        dropout=0.4,
        n_classes=2,
        feature_extractor=False,
    ):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=True`` in
        # favour of ``weights=...``; kept as-is for whatever version this
        # project pins. Loading may download weights on first use.
        backbones = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
            "resnet101": models.resnet101,
            "efficientnet_b0": models.efficientnet_b0,
        }
        # Keep the try narrow: only the name lookup means "unknown backbone";
        # errors raised while building the model must propagate unchanged.
        try:
            build_backbone = backbones[backbone]
        except KeyError:
            raise ValueError(f"Backbone model '{backbone}' not found") from None
        self.network = build_backbone(pretrained=True)

        # Freeze before the new head exists, so the head is never caught by
        # the freeze and always stays trainable (in feature-extractor mode it
        # is the only trainable part).
        self._freeze_backbone(self.network, feature_extractor)

        if backbone == "efficientnet_b0":
            # classifier[1] is the final nn.Linear of the efficientnet head;
            # reuse its input width instead of hard-coding 1280.
            in_features = self.network.classifier[1].in_features
            self.network.classifier[1] = nn.Linear(in_features, n_classes)
        elif simple:
            self.network.fc = nn.Linear(self.network.fc.in_features, n_classes)
        else:
            self.network.fc = self._mlp_head(
                self.network.fc.in_features, n_classes, dropout
            )

    @staticmethod
    def _freeze_backbone(network, feature_extractor):
        """Set ``requires_grad=False`` on backbone parameters that must not train."""
        if feature_extractor:
            print("Running in feature extractor mode.")
            for param in network.parameters():
                param.requires_grad = False
        else:
            print("Running in Finetuning mode.")
            # Walk each BatchNorm2d's *own* parameters: modules() and
            # parameters() are not aligned sequences, so they cannot be
            # zipped together to find which tensors belong to which layer.
            # NOTE: this does not stop BatchNorm running statistics from
            # updating while the model is in train() mode.
            for module in network.modules():
                if isinstance(module, nn.BatchNorm2d):
                    for param in module.parameters():
                        param.requires_grad = False

    @staticmethod
    def _mlp_head(in_features, n_classes, dropout):
        """Return a Linear-ReLU-Dropout-Linear head with Xavier-uniform weights."""
        head = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes),
        )
        for layer in head.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        return head

    def forward(self, x_batch):
        """Run ``x_batch`` through the backbone and return its logits.

        ``x_batch`` is presumably an image batch of shape (N, 3, H, W) --
        TODO confirm against the data pipeline. The output holds
        ``n_classes`` logits per sample.
        """
        return self.network(x_batch)