This blog walks through an implementation of the EfficientNet paper, adapted to 3D brain MRI.


EfficientNet was introduced not through the crucible of the ImageNet competition but in a paper titled “EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,” authored by Tan and Le in 2019. This groundbreaking work posed a fundamental question: “Can we create a model with significantly fewer parameters than well-established deep CNNs like ResNet and VGGNet, yet achieve comparable accuracy?” The answer, as it turned out, was a resounding “yes.” However, EfficientNet’s innovation wasn’t solely about reducing parameters; it sought an optimal balance between model depth, width, and input resolution. This was achieved through a compound scaling method that scales all three dimensions together with a single coefficient, producing models that are both efficient and highly accurate.
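For reference, the paper's compound scaling rule ties depth, width and resolution to one coefficient φ: depth multiplier α^φ, width multiplier β^φ, resolution multiplier γ^φ, with α·β²·γ² ≈ 2 so that FLOPs roughly double per unit of φ. The paper reports α = 1.2, β = 1.1, γ = 1.15 for its B0 baseline. The sketch below only applies those multipliers. The helper name `compound_scale`, the base values (16 layers, 32 channels, 224 px) and the plain `ceil` rounding are illustrative assumptions; they are not part of this blog's 3D model, and reference implementations use their own rounding rules.

```python
import math

# Coefficients reported in the EfficientNet paper for the B0 baseline.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scale(phi, base_depth, base_width, base_resolution,
                   alpha=ALPHA, beta=BETA, gamma=GAMMA):
    """Scale depth (layers), width (channels) and resolution with one coefficient phi."""
    depth = math.ceil(base_depth * alpha ** phi)
    width = math.ceil(base_width * beta ** phi)
    resolution = math.ceil(base_resolution * gamma ** phi)
    return depth, width, resolution


# FLOPs grow roughly by (alpha * beta**2 * gamma**2) ** phi, i.e. about 2 ** phi.
flops_factor = ALPHA * BETA ** 2 * GAMMA ** 2
print(f"FLOPs factor per unit phi: {flops_factor:.2f}")

for phi in range(4):
    print(phi, compound_scale(phi, base_depth=16, base_width=32, base_resolution=224))
```

Width enters the FLOPs constraint squared because a convolution's cost grows with both its input and output channel counts, and resolution enters squared because it grows with both spatial axes of a 2D image.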

In this blog, I use a modified 3D EfficientNet on registered, skull-stripped OASIS MRI images to classify subjects as Healthy or MCI. I then extend the same model to gender classification, brain age detection, and brain age deficit prediction using a double-headed model (one regression head and one classification head).

OASIS is a publicly available dataset that can be accessed online. Raw MRI files are not suitable for direct analysis: the same voxel position points to different parts of the brain in different subjects, and non-brain tissue, which carries no useful signal, can hinder the model’s ability to learn. With this in mind, I used a custom pipeline of brain extraction, bias-field correction, and registration to the MNI template using FSL. The preprocessed MRI images can be found at the link below:

GitHub - blackpearl006/OASIS_MNI_Registered

To learn more about the preprocessing steps in detail, read my blog:

MRI preprocessing using FSL
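As a rough illustration of such a pipeline (not necessarily the exact commands or options used for the released data), the three steps map onto standard FSL tools: `bet` for brain extraction, `fast -B` for bias-field correction (it writes a bias-corrected `<basename>_restore` image), and `flirt` for affine registration to an MNI152 template. The sketch below only builds the command lines. The file names, the `-R` option and the choice of the 2 mm template are assumptions, and actually running it requires an FSL installation with `FSLDIR` set.

```python
import os
import shlex
import subprocess


def fsl_pipeline_commands(in_file, out_prefix, template):
    """Build (not run) FSL commands for brain extraction, bias correction and MNI registration."""
    brain = f"{out_prefix}_brain.nii.gz"
    return [
        # 1. Skull stripping; -R requests robust brain-centre estimation.
        ["bet", in_file, brain, "-R"],
        # 2. Bias-field correction; -B writes the restored image <out_prefix>_restore.nii.gz.
        ["fast", "-B", "-o", out_prefix, brain],
        # 3. Affine registration of the bias-corrected brain to the MNI template.
        ["flirt", "-in", f"{out_prefix}_restore.nii.gz", "-ref", template,
         "-out", f"{out_prefix}_mni.nii.gz", "-omat", f"{out_prefix}_mni.mat"],
    ]


def run_pipeline(in_file, out_prefix):
    # Assumes FSL is installed and the FSLDIR environment variable points to it.
    template = os.path.join(os.environ["FSLDIR"], "data", "standard",
                            "MNI152_T1_2mm_brain.nii.gz")
    for cmd in fsl_pipeline_commands(in_file, out_prefix, template):
        print("Running:", shlex.join(cmd))
        subprocess.run(cmd, check=True)
```

Keeping command construction separate from execution makes it easy to print or log the exact calls before running them over a whole dataset.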

I’ve created a custom dataset class to handle MRI images stored in the NIfTI format using the NiBabel library. The class loads each image as a NumPy array for easy manipulation and performs a crucial preprocessing step: z-score normalization (subtracting the volume’s mean intensity and dividing by its standard deviation), so that intensities are on a consistent scale across subjects.
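A quick check of what this per-volume normalization does, on a synthetic volume. The random array and its 91 × 109 × 91 shape stand in for a real scan, and the small `eps` term is an added guard against division by zero on an all-constant volume.

```python
import numpy as np


def zscore(volume, eps=1e-8):
    """Per-volume z-score normalization: zero mean, unit standard deviation."""
    return (volume - volume.mean()) / (volume.std() + eps)


rng = np.random.default_rng(0)
fake_scan = rng.normal(loc=300.0, scale=50.0, size=(91, 109, 91))  # synthetic intensities
normalized = zscore(fake_scan)
print(round(float(normalized.mean()), 6), round(float(normalized.std()), 6))
```

Note that the statistics are computed over the whole volume, background voxels included, so the result depends on how much empty space surrounds the brain.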

The dataset class also labels each subject, based on a key clinical metric: the Clinical Dementia Rating (CDR). Subjects with a CDR of 0, or with no CDR recorded, are labeled 0 (Non-Demented); all other subjects are labeled 1 (Probable Alzheimer’s Disease (AD)).
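The same CDR-to-label rule can also be written as one vectorized pandas expression. Here it is applied to a tiny made-up table; the IDs and CDR values are invented for illustration, and only the `ID` and `CDR` column names come from the metadata CSV read by the dataset class.

```python
import pandas as pd

# Hypothetical rows mimicking the OASIS metadata columns used by the dataset class.
df = pd.DataFrame({
    "ID": ["OAS1_0001_MR1", "OAS1_0002_MR1", "OAS1_0003_MR1", "OAS1_0004_MR1"],
    "CDR": [0.0, 0.5, None, 1.0],  # None becomes NaN (missing rating)
})

# Missing or zero CDR -> 0 (Non-Demented); any positive CDR -> 1.
df["label"] = (df["CDR"].fillna(0.0) > 0).astype(int)
class_labels = dict(zip(df["ID"], df["label"].tolist()))
print(class_labels)
```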

```python
import os

import nibabel as nib
import pandas as pd
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data_dir, csv_path, num_samples=200, transform=None):
        # num_samples is kept for compatibility but is currently unused
        self.filelabels = self.load_class_labels(csv_path)
        self.file_paths = self.get_file_paths(data_dir)
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        data = nib.load(file_path).get_fdata()  # NIfTI volume as a float64 NumPy array
        preprocessed_data = self.preprocess_data(data)

        # Add a channel dimension: (D, H, W) -> (1, D, H, W)
        preprocessed_tensor = torch.tensor(preprocessed_data, dtype=torch.float32).unsqueeze(0)
        if self.transform is not None:
            preprocessed_tensor = self.transform(preprocessed_tensor)

        # The subject/session ID (e.g. "OAS1_0001_MR1") is the first 13 characters of the file name
        file_id = os.path.basename(file_path)[:13]
        label = torch.tensor(self.filelabels[file_id], dtype=torch.long)
        return preprocessed_tensor, label

    def get_file_paths(self, path):
        file_paths = []
        scans = [scan for scan in os.listdir(path) if scan.startswith('OAS1')]
        for folder_name in scans:
            folder_path = os.path.join(path, folder_name)
            if os.path.isdir(folder_path):
                for file_name in os.listdir(folder_path):
                    # Keep only files that have a label in the metadata CSV
                    if file_name[:13] in self.filelabels:
                        file_paths.append(os.path.join(folder_path, file_name))
        return file_paths

    def load_class_labels(self, csv_path):
        df = pd.read_csv(csv_path)
        class_labels = {}
        for _, row in df.iterrows():
            cdr_value = row['CDR']
            # Missing or zero CDR -> 0 (Non-Demented), otherwise 1
            class_labels[row['ID']] = 0 if pd.isna(cdr_value) or float(cdr_value) == 0.0 else 1
        return class_labels

    def preprocess_data(self, data):
        # Per-volume z-score normalization (epsilon avoids division by zero)
        return (data - data.mean()) / (data.std() + 1e-8)


csv_path = 'path-to-metadata-file/kaggle.csv'
data_dir = 'path/to/downloaded/OASIS_MNI_Registered'
custom_dataset = CustomDataset(data_dir, csv_path)
```

I’ve divided the dataset into training, validation, and test sets, allocating 70% for training, 20% for validation, and 10% for testing. During training, I use a batch size of 32, which strikes a balance between memory use and computational efficiency.
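One way to produce such a split is `torch.utils.data.random_split` with integer lengths, followed by a `DataLoader` per subset. In this sketch a small random `TensorDataset` stands in for `custom_dataset`, the seed is arbitrary, and the rounding scheme (the test set absorbs whatever remains) is one choice among several.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_lengths(n, train_frac=0.7, val_frac=0.2):
    """Integer 70/20/10 split; the test set absorbs the rounding remainder."""
    n_train = round(n * train_frac)
    n_val = round(n * val_frac)
    return [n_train, n_val, n - n_train - n_val]


# Stand-in for custom_dataset: 100 tiny "volumes" with binary labels.
dataset = TensorDataset(torch.randn(100, 1, 8, 8, 8), torch.randint(0, 2, (100,)))

generator = torch.Generator().manual_seed(42)  # reproducible split
train_set, val_set, test_set = random_split(dataset, split_lengths(len(dataset)),
                                            generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)

print(len(train_set), len(val_set), len(test_set))
```

Fixing the generator seed keeps the same subjects in the test set across runs, which matters when comparing models trained for different tasks.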

The EfficientNet Architecture

Fundamental Building Block

At the core of this architecture lies a fundamental building block: a plain 3D convolution followed by batch normalization and the SiLU (Sigmoid-weighted Linear Unit) activation, also known as Swish. This building block forms the basis for the entire implementation.

```python
import torch
import torch.nn as nn


class CNNBlock3d(nn.Module):
    """Conv3d -> BatchNorm3d -> SiLU, with optional batch norm and activation."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 groups=1, act=True, bn=True, bias=False):
        super().__init__()
        # bias=False because BatchNorm adds its own learnable shift.
        # groups=in_channels -> depthwise convolution: a separate kernel for each channel;
        #   little change in loss but a large decrease in the number of parameters.
        # groups=1 -> ordinary convolution whose kernel_size**3 kernels mix all channels.
        self.cnn = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding,
                             groups=groups, bias=bias)
        self.bn = nn.BatchNorm3d(out_channels) if bn else nn.Identity()
        # SiLU and Swish are the same function; the MBConv output projection uses act=False
        self.silu = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.silu(self.bn(self.cnn(x)))
```

The SiLU (Swish) activation, SiLU(x) = x·σ(x), has a useful property: for negative inputs it produces small negative outputs with a non-zero gradient, whereas the widely used ReLU outputs zero and passes no gradient for any negative input. We also apply batch normalization after every convolution, a detail the paper’s architecture table does not show; it stabilizes training and usually helps the model converge faster.

Difference between ReLU and SiLU (Swish)
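A small numeric comparison makes the difference concrete. It uses the closed-form derivative SiLU′(x) = σ(x)·(1 + x·(1 − σ(x))); the helper names exist only for this illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def relu(x):
    return max(0.0, x)


def relu_grad(x):
    return 1.0 if x > 0 else 0.0  # zero gradient for every negative input


def silu(x):
    return x * sigmoid(x)


def silu_grad(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f} relu'={relu_grad(x):.4f}  "
          f"silu={silu(x):+.4f} silu'={silu_grad(x):.4f}")
```

For negative inputs ReLU’s gradient is exactly zero, while SiLU’s is small but non-zero, so those units can still receive updates.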

The Ranking master

The squeeze-and-excitation (SE) layer computes an attention score for each channel, effectively assigning each channel a weight between 0 and 1 based on its importance. The original feature map is then rescaled channel by channel with these scores, so the network can emphasize informative channels and suppress less useful ones.

```python
import torch.nn as nn


# Same structure as in the paper; reduction ratio 0.25 (i.e. 1/4)
class SqueezeExcitation(nn.Module):
    def __init__(self, in_channels, reduced_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),                             # squeeze: C x D x H x W -> C x 1 x 1 x 1
            nn.Conv3d(in_channels, reduced_dim, kernel_size=1),  # reduce channels
            nn.SiLU(),                                           # activation
            nn.Conv3d(reduced_dim, in_channels, kernel_size=1),  # restore channels
            nn.Sigmoid(),                                        # per-channel weights in (0, 1)
        )

    def forward(self, x):
        # Broadcast the per-channel weights over D, H and W
        return x * self.se(x)
```
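Because the pooled tensor is 1 × 1 × 1 spatially, the two 1×1×1 convolutions act as plain fully connected layers on a vector of channel means. The NumPy sketch below spells out that arithmetic for a single volume, with random weights standing in for learned ones; the shapes and the weight names `w1`, `b1`, `w2`, `b2` are illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def squeeze_excite(x, w1, b1, w2, b2):
    """x: (C, D, H, W); w1: (C_r, C); w2: (C, C_r). Returns rescaled x and the scores."""
    squeezed = x.mean(axis=(1, 2, 3))        # squeeze: global average pool -> (C,)
    hidden = w1 @ squeezed + b1              # reduce to C_r values
    hidden = hidden * sigmoid(hidden)        # SiLU
    scores = sigmoid(w2 @ hidden + b2)       # excite: per-channel weights in (0, 1)
    return x * scores[:, None, None, None], scores


rng = np.random.default_rng(0)
C, C_r = 8, 2                                # reduction ratio 1/4
x = rng.normal(size=(C, 4, 5, 4))
out, scores = squeeze_excite(x,
                             rng.normal(size=(C_r, C)), np.zeros(C_r),
                             rng.normal(size=(C, C_r)), np.zeros(C))
print(out.shape, np.round(scores, 3))
```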

The Tape

Stochastic Depth works like dropout applied to whole blocks: during training, for each sample in the batch, it randomly drops a block’s residual branch so that the block reduces to its skip connection, and it scales the surviving outputs by 1/survival_prob so their expected value is unchanged. At inference time every block is kept. Training many randomly shortened versions of the network in this way acts as a regularizer and makes the deeper model more robust.

```python
import torch
import torch.nn as nn


class StochasticDepth(nn.Module):
    def __init__(self, survival_prob=0.8):
        super().__init__()
        self.survival_prob = survival_prob

    def forward(self, x):
        # Identity at evaluation time; only active during training
        if not self.training:
            return x
        # One keep/drop decision per sample, broadcast over (C, D, H, W) of a 5D tensor
        binary_tensor = torch.rand(x.shape[0], 1, 1, 1, 1, device=x.device) < self.survival_prob
        # Rescale kept samples so the expected output equals x
        return torch.div(x, self.survival_prob) * binary_tensor
```
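The 1/survival_prob rescaling is what keeps training and inference consistent. This NumPy sketch is a stand-alone re-implementation for illustration, not the PyTorch module itself: it drops samples the same way and checks that the average output stays close to the input.

```python
import numpy as np


def drop_path(x, survival_prob, rng):
    """Per-sample stochastic depth on a batch shaped (N, C, D, H, W)."""
    keep = rng.random((x.shape[0], 1, 1, 1, 1)) < survival_prob  # one decision per sample
    return x / survival_prob * keep


rng = np.random.default_rng(0)
x = np.ones((100_000, 1, 1, 1, 1))
out = drop_path(x, survival_prob=0.8, rng=rng)

print("fraction kept:", (out > 0).mean())   # close to 0.8
print("mean output:  ", out.mean())         # close to 1.0, the input value
```

Without the division, the expected activation during training would be 0.8 times its inference value, and the layers after the block would see a shift in scale between the two modes.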

The MobileConv3d

The MBConv3d class (a 3D mobile inverted bottleneck block) is the workhorse of the model. It chains four phases: expansion, depthwise convolution, squeeze-excitation, and a pointwise output projection. A residual connection is added only when the block keeps both the channel count and the spatial size (stride 1); when the block downsamples, the skip connection is omitted. On the residual path, Stochastic Depth randomly drops the block during training, which further regularizes the model.

```python
import torch.nn as nn


class MBConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 expand_ratio=6,
                 reduction=4,        # squeeze-excitation ratio 1/4 = 0.25
                 survival_prob=0.8,  # for stochastic depth
                 ):
        super().__init__()

        self.survival_prob = survival_prob
        # A skip connection only works if the output shape matches the input:
        # no downsampling and no change in channel count
        self.use_residual = in_channels == out_channels and stride == 1
        hidden_dim = int(in_channels * expand_ratio)
        reduced_dim = max(1, int(in_channels / reduction))
        self.padding = padding

        # Expansion phase: 1x1x1 conv to hidden_dim channels (skipped when expand_ratio == 1)
        self.expand = (nn.Identity() if expand_ratio == 1
                       else CNNBlock3d(in_channels, hidden_dim, kernel_size=1))

        # Depthwise convolution phase: groups=hidden_dim gives one kernel per channel
        self.depthwise_conv = CNNBlock3d(hidden_dim, hidden_dim,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, groups=hidden_dim)

        # Squeeze-excitation phase
        self.se = SqueezeExcitation(hidden_dim, reduced_dim=reduced_dim)

        # Output phase: linear 1x1x1 projection (no activation, hence act=False)
        self.pointwise_conv = CNNBlock3d(hidden_dim, out_channels, kernel_size=1, stride=1,
                                         act=False, padding=0)

        # Stochastic depth ("drop connect") on the residual branch
        self.drop_layers = StochasticDepth(survival_prob=survival_prob)

    def forward(self, x):
        residual = x
        x = self.expand(x)
        x = self.depthwise_conv(x)
        x = self.se(x)
        x = self.pointwise_conv(x)

        if self.use_residual:
            x = self.drop_layers(x)
            x = x + residual
        return x
```
Expansion Phase:
- The network increases the channel dimension of the input features.
- Expansion creates a richer representation by projecting low-dimensional feature maps into a higher-dimensional space.
- It uses 1×1×1 convolutions to increase the number of channels by expand_ratio.

Depthwise Convolution Phase:
- This phase operates on the spatial dimensions of the features to capture local patterns.
- Depthwise convolutions apply a separate filter to each input channel, which greatly reduces the computational cost compared with a standard convolution.
- They are particularly useful for detecting spatial features within the data.

Squeeze Excitation Phase:
- This phase introduces an attention mechanism that assigns an importance to each channel.
- It involves two steps: squeezing and exciting.
- The “squeezing” step computes a global statistic for each channel, usually through global average pooling.
- The “exciting” step learns how to re-weight these channels, emphasizing or de-emphasizing each one based on its importance.
- This helps the network focus on relevant features.

Output Phase:
- After feature extraction, a 1×1×1 convolution reduces the number of channels to out_channels, preparing the features for the next block and, ultimately, the classification or regression head.
- This projection is kept linear (no activation, hence act=False), because a nonlinearity on a low-dimensional output tends to discard information. Squashing functions such as Sigmoid belong at the network’s final output, where they map predictions to a range such as [0, 1] for binary classification, not inside the block.
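To put a number on the depthwise saving: a standard k×k×k 3D convolution from C_in to C_out channels has C_in·C_out·k³ weights, a depthwise one has C·k³, and a 1×1×1 pointwise one has C_in·C_out (biases ignored, matching bias=False in CNNBlock3d). The channel count and kernel size below are arbitrary examples.

```python
def standard_conv3d_params(c_in, c_out, k):
    return c_in * c_out * k ** 3   # every output channel has a k^3 kernel per input channel


def depthwise_conv3d_params(c, k):
    return c * k ** 3              # one k x k x k kernel per channel (groups = c)


def pointwise_conv3d_params(c_in, c_out):
    return c_in * c_out            # 1 x 1 x 1 kernels that mix channels


c, k = 64, 3
standard = standard_conv3d_params(c, c, k)
separable = depthwise_conv3d_params(c, k) + pointwise_conv3d_params(c, c)
print(standard, separable, round(standard / separable, 1))  # 110592 5824 19.0
```

The saving grows with kernel size, which is why the 5×5×5 stages benefit even more than the 3×3×3 ones.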

Putting All the Pieces Together

```python
from math import ceil

import torch
import torch.nn as nn


class EfficeientNet3d(nn.Module):
    def __init__(self, width_mult=1, depth_mult=1, dropout_rate=0.1, num_classes=2):
        super().__init__()
        last_channels = ceil(512 * width_mult)

        self.first_layer = CNNBlock3d(1, 64, kernel_size=7, stride=2, padding=3)
        self.pool = nn.MaxPool3d(1, stride=2)  # kernel size 1, stride 2: keeps every second voxel
        self.features = self._feature_extractor(width_mult, depth_mult, last_channels)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            # 3 x 4 x 3 is the final spatial size for e.g. a 91 x 109 x 91 input;
            # other input sizes need a different value here
            nn.Linear(last_channels * 3 * 4 * 3, 400),
            # Note: with no activations in between, these Linear layers compose into one linear map
            nn.Linear(400, 64),
            nn.Linear(64, num_classes),  # one output per class
        )

    def _feature_extractor(self, width_mult, depth_mult, last_channel):
        # Note: depth_mult is unused and width_mult only affects last_channel;
        # the stage configuration below is fixed
        layers = []
        final_in_channel = 0

        # (kernel_size, stride, in_channels, out_channels, repeats) for each MBConv stage
        mbconv_configurations = [
            (3, 1, 64, 64, 1),
            (5, 2, 64, 96, 1),
            (5, 2, 96, 128, 2),
            (5, 2, 128, 192, 3),
            (3, 1, 192, 256, 1),
        ]

        for kernel_size, stride, in_channels, out_channels, repeats in mbconv_configurations:
            layers += [
                MBConv3d(in_channels if repeat == 0 else out_channels,
                         out_channels,
                         kernel_size=kernel_size,
                         stride=stride if repeat == 0 else 1,  # only a stage's first block downsamples
                         expand_ratio=1,                       # no expansion in these blocks
                         padding=kernel_size // 2)             # "same" padding for odd kernels
                for repeat in range(repeats)
            ]
            final_in_channel = out_channels

        # Final 1x1x1 MBConv maps to last_channel features (default expand_ratio=6)
        layers.append(MBConv3d(final_in_channel, last_channel, kernel_size=1, stride=1, padding=0))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.first_layer(inputs)
        out = self.pool(out)
        x = self.features(out)
        x = torch.flatten(x, 1)  # (N, C, D, H, W) -> (N, C*D*H*W)
        return self.classifier(x)
```

The “EfficeientNet3d” class comprises several crucial components. It starts with “first_layer,” a 7×7×7 convolution with stride 2 that extracts low-level features from the input 3D image while halving its resolution. A pooling layer then subsamples the feature map by a further factor of two.

The heart of this architecture lies in the “features” section, a sequence of MBConv3d blocks that handle channel expansion, depthwise convolution, and squeeze-excitation, allowing the network to capture fine detail in the medical images. The hyperparameters follow the paper’s design choices (kernel sizes of 3 and 5, a squeeze-excitation ratio of 0.25, and a stochastic-depth survival probability of 0.8), although the channel widths, repeats, and expansion ratio differ from the paper’s EfficientNet-B0.
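Because the classifier’s first Linear layer is hard-coded to `last_channels * 3 * 4 * 3` inputs, the model only accepts volumes whose feature map ends up 3 × 4 × 3. The illustrative helper below traces the standard convolution/pooling output-size formula, floor((n + 2p − k) / s) + 1, through the layers that change the spatial size. A 91 × 109 × 91 volume, the size of the MNI152 2 mm template grid, lands exactly on 3 × 4 × 3.

```python
def conv_out(n, kernel_size, stride, padding):
    """Output length along one axis for Conv3d / MaxPool3d (dilation 1)."""
    return (n + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) of every layer that changes the spatial size;
# stride-1 "same"-padded MBConv blocks leave it unchanged and are omitted.
DOWNSAMPLING_LAYERS = [
    (7, 2, 3),  # first_layer
    (1, 2, 0),  # MaxPool3d(1, stride=2)
    (5, 2, 2),  # first block of the 64 -> 96 stage
    (5, 2, 2),  # first block of the 96 -> 128 stage
    (5, 2, 2),  # first block of the 128 -> 192 stage
]


def feature_map_size(shape):
    sizes = []
    for n in shape:
        for k, s, p in DOWNSAMPLING_LAYERS:
            n = conv_out(n, k, s, p)
        sizes.append(n)
    return tuple(sizes)


print(feature_map_size((91, 109, 91)))    # (3, 4, 3): matches last_channels * 3 * 4 * 3
print(feature_map_size((182, 218, 182)))  # (6, 7, 6): a 1 mm MNI152 grid needs a different size
```

Replacing the hard-coded size with a global pooling layer would remove this constraint, but that would be a change to the model described here.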

The “EfficeientNet3d” model also includes a “classifier” section that maps the extracted features to predictions. This head can be configured for the specific task, whether that is gender classification, brain age detection, or brain age deficit prediction.

With this EfficeientNet3d model, you have the flexibility to tackle a range of tasks that fit your use case. For gender classification, brain age detection, or brain age deficit prediction, the feature extractor can stay the same while the head and loss function are adapted to the task, which also makes the architecture a reasonable starting point for other medical imaging problems.
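The introduction mentions a double-headed model with one regression and one classification head. The exact heads used for those experiments are not shown here, so the following is only a minimal sketch of the idea: a hypothetical `DualHead` module that takes the feature-extractor output (the same `last_channels * 3 * 4 * 3` values the classifier receives) and returns class logits alongside a single regression value, trained with a summed loss. The shared hidden size, the MSE regression loss and the `reg_weight` of 1.0 are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualHead(nn.Module):
    """Hypothetical two-headed output: a shared layer, then classification + regression heads."""

    def __init__(self, in_features, num_classes=2, hidden=64, dropout_rate=0.1):
        super().__init__()
        self.shared = nn.Sequential(nn.Dropout(dropout_rate),
                                    nn.Linear(in_features, hidden), nn.SiLU())
        self.cls_head = nn.Linear(hidden, num_classes)  # e.g. gender logits
        self.reg_head = nn.Linear(hidden, 1)            # e.g. brain age

    def forward(self, features):
        h = self.shared(torch.flatten(features, 1))
        return self.cls_head(h), self.reg_head(h).squeeze(1)


def dual_loss(logits, age_pred, labels, ages, reg_weight=1.0):
    # Cross-entropy for the class label plus weighted MSE for the continuous target
    return F.cross_entropy(logits, labels) + reg_weight * F.mse_loss(age_pred, ages)


# Usage with random stand-in features shaped like the extractor output (N, 512, 3, 4, 3)
head = DualHead(in_features=512 * 3 * 4 * 3)
feats = torch.randn(4, 512, 3, 4, 3)
logits, age = head(feats)
loss = dual_loss(logits, age, torch.tensor([0, 1, 1, 0]),
                 torch.tensor([70.0, 65.0, 80.0, 74.0]))
loss.backward()
print(logits.shape, age.shape, loss.item())
```

In practice the regression target is often standardized (or `reg_weight` lowered) so that the age error, measured in years, does not dominate the cross-entropy term.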

You can access my presentation of the implementation through the link provided here, and the complete code is available in my Kaggle notebook. Gender classification, brain age detection, and brain age deficit prediction have already been implemented and run successfully with this model. If you need any assistance or have questions, feel free to reach out to me at reachninadaithal@gmail.com; I’m happy to help.