-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
30 lines (22 loc) · 873 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvNet(nn.Module):
    """Small fully convolutional network that outputs a per-location probability map.

    Layer stack:
      1. two 3x3 same-padded convolutions (10 channels each) with ReLU,
      2. an 8x8 max-pool followed by dropout (p=0.25),
      3. an unpadded 8x8 convolution to 128 channels with ReLU, followed by
         dropout (p=0.5),
      4. a 1x1 convolution down to a single channel, squashed by a sigmoid.

    Spatial size: for an ``H x W`` input the output is
    ``(floor(H / 8) - 7) x (floor(W / 8) - 7)``, so each input side must be at
    least 64 pixels. A 64x64 input yields a single 1x1 score per image.
    NOTE(review): 64x64 patches look like the intended input size -- confirm.

    Args:
        in_channels: Number of channels in the input images. Defaults to 3,
            which matches the previously hard-coded value.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        # 3x3 kernels with padding=1 (and the default stride of 1) keep H and W unchanged.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=10, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, padding=1)
        # MaxPool2d's stride defaults to kernel_size, so this downsamples each side by 8.
        self.pool = nn.MaxPool2d(kernel_size=8)
        self.dropout1 = nn.Dropout(0.25)
        # Unpadded 8x8 conv: each output sees an 8x8 window of the pooled map,
        # and each spatial side shrinks by 7.
        self.conv3 = nn.Conv2d(in_channels=10, out_channels=128, kernel_size=8)
        self.dropout2 = nn.Dropout(0.5)
        # 1x1 conv collapses the 128 features to one score per spatial location.
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input tensor, expected as ``(N, in_channels, H, W)`` with
                ``H, W >= 64``.

        Returns:
            Tensor of shape ``(N, 1, floor(H / 8) - 7, floor(W / 8) - 7)``
            holding probabilities in ``[0, 1]``. The sigmoid is already
            applied, so these are not logits.

        Note:
            The dropout layers are active only in training mode; call
            ``model.eval()`` for deterministic inference.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        x = F.relu(self.conv3(x))
        x = self.dropout2(x)
        x = torch.sigmoid(self.conv4(x))
        return x