Skip to content
Permalink
f3b3f10520
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
80 lines (71 sloc) 2.02 KB
from torch import nn
from torchsummary import summary
class CNNNetwork(nn.Module):
    """Four-stage convolutional classifier that outputs class probabilities.

    Every stage is ``Conv2d(3x3, stride 1, padding 2) -> ReLU ->
    MaxPool2d(2) -> BatchNorm2d``; stages 2-4 additionally end with
    ``Dropout(dropoutProb)``. The last feature map is flattened, projected
    to ``num_classes`` scores by a linear layer, and passed through a
    softmax over dim 1.

    Args:
        dropoutProb: Dropout probability applied after conv stages 2, 3
            and 4. Defaults to 0.0 (dropout disabled) so the network can be
            built without arguments.
        input_shape: ``(height, width)`` of the single-channel input. It is
            used only to size the final linear layer; the default
            ``(64, 44)`` yields the original ``128 * 5 * 4`` features.
        num_classes: Number of output classes. Defaults to 4.
    """

    # Hyper-parameters shared by every conv stage.
    _KERNEL_SIZE = 3
    _STRIDE = 1
    _PADDING = 2
    _POOL_SIZE = 2
    _NUM_STAGES = 4
    _FINAL_CHANNELS = 128

    def __init__(self, dropoutProb=0.0, input_shape=(64, 44), num_classes=4):
        super().__init__()
        # Attribute names and the layer order inside each Sequential match
        # the original hand-written stages, so state_dict keys such as
        # "conv1.0.weight" or "conv2.3.running_mean" are unchanged.
        self.conv1 = self._conv_block(1, 16)  # first stage has no dropout
        self.conv2 = self._conv_block(16, 32, dropoutProb)
        self.conv3 = self._conv_block(32, 64, dropoutProb)
        self.conv4 = self._conv_block(64, self._FINAL_CHANNELS, dropoutProb)
        self.flatten = nn.Flatten()

        # Derive the flattened feature size from the input size instead of
        # hard-coding it; for (64, 44) this gives 5 x 4.
        height, width = input_shape
        for _ in range(self._NUM_STAGES):
            height = self._stage_out_len(height)
            width = self._stage_out_len(width)
        self.linear = nn.Linear(self._FINAL_CHANNELS * height * width, num_classes)
        self.softmax = nn.Softmax(dim=1)

    @classmethod
    def _conv_block(cls, in_channels, out_channels, dropout_prob=None):
        """Build one Conv -> ReLU -> MaxPool -> BatchNorm [-> Dropout] stage.

        Dropout is appended only when ``dropout_prob`` is not None.
        """
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=cls._KERNEL_SIZE,
                stride=cls._STRIDE,
                padding=cls._PADDING,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=cls._POOL_SIZE),
            nn.BatchNorm2d(out_channels),
        ]
        if dropout_prob is not None:
            layers.append(nn.Dropout(dropout_prob))
        return nn.Sequential(*layers)

    @classmethod
    def _stage_out_len(cls, length):
        """Return one spatial dimension's length after a conv stage.

        Applies the Conv2d output-size formula, then the MaxPool2d one
        (pool stride defaults to its kernel size, with floor rounding).
        """
        conv_len = (length + 2 * cls._PADDING - cls._KERNEL_SIZE) // cls._STRIDE + 1
        return (conv_len - cls._POOL_SIZE) // cls._POOL_SIZE + 1

    def forward(self, input_data):
        """Run the network and return per-class probabilities.

        Args:
            input_data: Batch of single-channel images; the first conv layer
                takes ``in_channels=1``. NOTE(review): assumed to be
                (batch, 1, height, width) with height/width matching
                ``input_shape`` -- confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, num_classes) whose rows sum to 1.
            NOTE(review): ``nn.CrossEntropyLoss`` expects raw logits, so
            training on this softmax output would apply softmax twice --
            confirm which loss the training loop uses.
        """
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        logits = self.linear(x)
        predictions = self.softmax(logits)
        return predictions
if __name__ == "__main__":
    # Local import: only this smoke test needs the top-level torch module.
    import torch

    # Pass the dropout probability explicitly (0.0 disables dropout) so the
    # constructor call matches CNNNetwork's signature.
    cnn = CNNNetwork(dropoutProb=0.0)
    # torchsummary.summary defaults to device="cuda"; pick the device at
    # runtime so the summary can also be printed on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Input size is (channels, height, width) without the batch dimension.
    summary(cnn.to(device), (1, 64, 44), device=device)