A Transfer Learning Approach to Classifying COVID-19 Chest X-ray Images Using ResNet50 Pretrained on CheXNet
M. Bolhassani

The coronavirus has adversely affected people worldwide. COVID-19 shares common symptoms with other respiratory diseases such as pneumonia and influenza, so diagnosing it quickly is crucial, not only to treat patients but also to prevent the disease from spreading. One of the most reliable diagnostic methods is X-ray imaging of the lungs. With deep learning, a model can be trained to recognize the condition of an affected lung and then classify a new sample as coming from a COVID-19-infected patient or not. In this project, we train deep models based on ResNet50 pretrained on the ImageNet dataset and on CheXNet. Using the imbalanced CoronaHack Chest X-Ray dataset released on Kaggle, we perform both binary and multi-class classification. We also compare the results obtained with focal loss and with cross-entropy loss.


Introduction
The World Health Organization (WHO) declared the COVID-19 outbreak a pandemic in March 2020, a crisis that wrought havoc across sectors such as the economy, politics, and education worldwide. Given the virus's high transmissibility, immediate measures to slow its spread were prioritized. A vital part of this response was a dependable, rapid, and widely accessible diagnostic tool, and medical image processing with deep learning emerged as one solution. This approach is particularly useful because one of the primary indicators of COVID-19 is lung infection, which is visible on X-ray images. Although COVID-19 presents symptoms similar to those of other respiratory conditions such as pneumonia and influenza, X-ray imaging offers a reliable means of differentiation. Consequently, throughout the pandemic, numerous organizations have compiled collections of X-ray images ranging from healthy to COVID-19-affected lungs.
In our research, we first addressed binary classification with a transfer learning approach; because the existing datasets are imbalanced, we began by distinguishing bacterial from viral infections. We then extended the technique to four-class classification. To mitigate the class imbalance, we employed focal loss, which counteracts the disparity in class representation. Building on these findings, we also attempted to improve the model's accuracy by generating synthetic input data with conditional Generative Adversarial Networks (GANs).

Materials and Methods
Automatic classification of X-ray scan images is challenging when the dataset is highly imbalanced. In this article, we therefore search for the approach that best improves classification accuracy.

Benchmark Datasets
In our research, we used the CoronaHack Chest X-Ray Dataset provided by the competition organizers to train a deep learning model that classifies X-ray images indicating COVID-19 infection [1]. The dataset is divided into four primary groups: Normal, Bacteria, Virus, and COVID-19. Its distribution is notably uneven: it contains 1,575 normal cases, 2,778 bacterial cases, 1,494 viral cases, and only 82 COVID-19 cases. To visualize this disparity, we created a class distribution chart that clearly depicts the imbalance. We also included a histogram of the training and testing samples as split by the competition organizers; both are shown in Figure 1.
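The imbalance can be quantified directly from the reported counts. The short sketch below (plain Python; the names `CLASS_COUNTS`, `class_shares`, and `imbalance_ratio` are illustrative) computes each class's share of the dataset and the ratio between the largest and smallest class.

```python
# Class counts as reported for the CoronaHack Chest X-Ray dataset.
CLASS_COUNTS = {"Normal": 1575, "Bacteria": 2778, "Virus": 1494, "COVID-19": 82}


def class_shares(counts):
    """Return each class's fraction of the total number of images."""
    total = sum(counts.values())
    return {name: n / total for name, n in counts.items()}


def imbalance_ratio(counts):
    """Ratio between the most and the least frequent class."""
    return max(counts.values()) / min(counts.values())


if __name__ == "__main__":
    for name, share in class_shares(CLASS_COUNTS).items():
        print(f"{name:>9}: {share:.2%}")
    print(f"imbalance ratio: {imbalance_ratio(CLASS_COUNTS):.1f}")
```

With these counts, COVID-19 makes up only about 1.4% of the 5,929 images, and the largest class (Bacteria) is roughly 34 times larger than it.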

Data augmentation
Deep learning models require a substantial amount of labeled data to make precise predictions on test sets. However, acquiring annotated medical images is difficult, costly, and time-consuming, because medical professionals must annotate or label them [7]. In our context, this means that specialists must categorize lung X-ray images across a spectrum from normal to COVID-19 infection. An effective strategy is therefore needed to cope with the limited number of training samples.
Data augmentation is one way to mitigate the shortage of data. By introducing variety into the training data, it enables the model to recognize samples under various alterations and thus improves its performance. In our study, we implemented several augmentation strategies, including random horizontal and vertical flips, each applied with 50% probability, and random rotations applied with 30% probability.
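As an illustration only, the sketch below applies flips and a rotation with the stated probabilities to an image stored as a list of pixel rows. The rotation angles are not specified, so this sketch assumes (hypothetically) rotation by a random multiple of 90°, which keeps it dependency-free. In a PyTorch pipeline one would more typically compose `torchvision.transforms.RandomHorizontalFlip(p=0.5)`, `RandomVerticalFlip(p=0.5)`, and `RandomApply([RandomRotation(degrees)], p=0.3)` with a suitable angle range.

```python
import random


def hflip(img):
    """Mirror each row (left-right flip)."""
    return [row[::-1] for row in img]


def vflip(img):
    """Reverse the row order (top-bottom flip)."""
    return img[::-1]


def rot90(img):
    """Rotate a 2-D grid 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]


def augment(img, rng, p_flip=0.5, p_rot=0.3):
    """Flip horizontally and vertically (each with p_flip), then rotate with p_rot.

    Rotating by a random multiple of 90 degrees is a simplifying assumption;
    the actual rotation range used in the study is not specified.
    """
    if rng.random() < p_flip:
        img = hflip(img)
    if rng.random() < p_flip:
        img = vflip(img)
    if rng.random() < p_rot:
        for _ in range(rng.choice([1, 2, 3])):
            img = rot90(img)
    return img


if __name__ == "__main__":
    print(augment([[1, 2], [3, 4]], random.Random(0)))
```

Each transform only rearranges pixels, so the augmented image keeps the same content and label while its orientation varies between epochs.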

Architecture
To train a model capable of classifying lung X-ray images, we selected two architectures that have demonstrated strong performance on medical imaging datasets: DenseNet and ResNet.

ResNet
ResNet was developed to address the vanishing gradient problem in deep convolutional networks. It preserves the gradient by adding skip connections between blocks, which give the gradient a shortcut path and maintain its strength throughout the depth of the network. A detailed diagram of the ResNet architecture [3] is provided in Figure ??.
ResNet50 belongs to the ResNet (Residual Network) family, which introduced residual learning to ease the training of very deep networks. It has 50 layers, built from convolutional, ReLU activation, batch normalization, and pooling layers arranged in residual blocks. Each block has a skip connection that lets its input bypass the block's stack of layers (three convolutions in ResNet50's bottleneck blocks) and be added back to the block's output, which counters the vanishing gradient problem. The network uses filters of different sizes (a 7x7 kernel in the first convolutional layer, then 3x3 and 1x1 kernels in later layers) together with strides that control the convolution step. ResNet50 is typically trained on a large dataset such as ImageNet with backpropagation, usually using stochastic gradient descent with momentum and weight decay, and it uses a softmax function in the output layer for classification. This design enables it to learn robust feature representations for a wide variety of images and to achieve high accuracy in image classification tasks.
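To make the skip connection concrete, the toy sketch below implements a residual block on plain Python vectors, y = ReLU(F(x) + x). The residual branch F is two small fully connected layers standing in for the convolutions of a real ResNet50 bottleneck (a deliberate simplification, not the actual layer stack). It also includes the standard convolution output-size formula, showing how the 7x7 stride-2 first layer, the stride-2 max pooling, and the stride-2 stages shrink a 224x224 input (the usual ImageNet input size, assumed here) down to 7x7.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(W, v):
    """Matrix-vector product: one output per row of W."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def residual_block(x, W1, W2):
    """y = ReLU(F(x) + x) with a toy residual branch F(x) = W2 @ ReLU(W1 @ x)."""
    f = linear(W2, relu(linear(W1, x)))
    return relu([fi + xi for fi, xi in zip(f, x)])  # skip connection adds x back


def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_spatial_sizes(size=224):
    """Feature-map side length after each downsampling step of ResNet50."""
    sizes = [size]
    size = conv_out(size, 7, 2, 3)  # 7x7 first convolution, stride 2
    sizes.append(size)
    size = conv_out(size, 3, 2, 1)  # 3x3 max pooling, stride 2
    sizes.append(size)
    for _ in range(3):  # stride-2 downsampling at the start of stages 3, 4, 5
        size = conv_out(size, 3, 2, 1)
        sizes.append(size)
    return sizes
```

When the residual branch's weights are zero, the block passes non-negative inputs through unchanged, so an added block only has to learn a correction to the identity mapping rather than a full transformation.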

DenseNet
Conventional convolutional neural networks run into trouble as they become much deeper: the path of information from input to output (and of the gradient back to the input) grows very long, so the network is likely to suffer from vanishing gradients. DenseNet [2] was introduced to solve this issue. Within each dense block of this architecture, the output of each layer is passed as input to all subsequent layers. This connectivity not only alleviates the vanishing gradient problem but also requires fewer parameters than a conventional CNN of comparable depth. More details of the DenseNet architecture are shown in the left image of Figure ??.
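A minimal sketch of dense connectivity follows, assuming each "layer" is an arbitrary function that maps the concatenation of all earlier feature vectors to `growth_rate` new features (real DenseNet layers are batch-norm, ReLU, and convolution units operating on feature maps). The helper `dense_block_channels` (an illustrative name) gives the channel count after a block; for DenseNet121's first dense block (64 input channels, growth rate 32, six layers) it yields 256.

```python
def dense_block(x, layers):
    """Each layer receives the concatenation of the input and all previous outputs."""
    features = [x]
    for layer in layers:
        concatenated = [v for f in features for v in f]
        features.append(layer(concatenated))
    # The block output is the concatenation of everything produced so far.
    return [v for f in features for v in f]


def dense_block_channels(in_channels, growth_rate, num_layers):
    """Channels after a dense block: every layer adds growth_rate channels."""
    return in_channels + growth_rate * num_layers
```

For example, with two toy layers that each output the sum of their inputs, `dense_block([1, 2], layers)` returns `[1, 2, 3, 6]`: the second layer sees `[1, 2, 3]`, including the first layer's output.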

Loss function
The next step is to define the loss function used for training. The right choice of loss function, itself a hyperparameter, depends on the problem at hand. For a multi-class classification problem, multi-class cross-entropy (the categorical cross-entropy loss) is a natural choice. Equation (1) gives the formula for this loss function.
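In its standard form, categorical cross-entropy for one sample with one-hot label $y$ over $C$ classes and softmax probabilities $p$ is $\mathcal{L}_{CE} = -\sum_{c=1}^{C} y_c \log p_c$, which reduces to $-\log p_t$ for the true class $t$. The plain-Python sketch below computes it from raw logits with a numerically stable softmax; the function names are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Categorical cross-entropy -log p_target for a single sample."""
    return -math.log(softmax(logits)[target])


def mean_cross_entropy(batch_logits, targets):
    """Average the per-sample loss over a mini-batch."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)
```

With uniform logits over the four classes, every class gets probability 0.25 and the loss equals log 4 (about 1.386), the value an untrained four-class classifier starts near.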
A closer look at our problem calls for more caution, however, because the input samples are imbalanced. Focal loss [4] is another option whose properties we can leverage to improve the model's performance. It adds a modulating factor that down-weights well-classified examples so that training focuses on hard, misclassified ones, and it can also include a per-class weighting term to counter class imbalance. Equation (2) shows the details of this function.
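In its commonly used form, focal loss for the true class $t$ is $\mathcal{L}_{FL} = -\alpha_t (1 - p_t)^{\gamma} \log p_t$, where $(1 - p_t)^{\gamma}$ shrinks the loss of well-classified samples and the optional $\alpha_t$ weights classes. The sketch below assumes γ = 2, a widely used default, because the value used in this study is not stated; with γ = 0 and no α it reduces to cross-entropy.

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(logits, target, gamma=2.0, alpha=None):
    """Multi-class focal loss -alpha_t * (1 - p_t)**gamma * log(p_t) for one sample.

    gamma=2.0 is an assumed default; alpha, if given, is a list of per-class weights.
    """
    p_t = softmax(logits)[target]
    weight = 1.0 if alpha is None else alpha[target]
    return -weight * (1.0 - p_t) ** gamma * math.log(p_t)
```

A confidently classified sample (p_t close to 1) contributes almost nothing, so the gradient is dominated by hard examples, which in an imbalanced dataset tend to include minority-class images.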

Metric
One of the most important steps in the deep learning process is defining the metrics used to evaluate the model's performance. We chose accuracy, the simplest metric, defined as the number of correct predictions divided by the total number of predictions in each epoch.
The average accuracy that we report in this project is obtained by averaging the accuracy of each epoch.
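The accuracy definition and the epoch averaging above can be written as the short sketch below. The function names are illustrative, and predictions are assumed to be integer class indices (for example, the argmax of the model's logits).

```python
def argmax(values):
    """Index of the largest value, e.g. the predicted class from a logit vector."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(predictions, labels):
    """Fraction of predictions that match the ground-truth labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


def average_accuracy(epoch_accuracies):
    """Mean of the per-epoch accuracies."""
    return sum(epoch_accuracies) / len(epoch_accuracies)
```

For instance, predicting `[0, 1, 1, 3]` for labels `[0, 1, 2, 3]` gives an accuracy of 0.75.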

Weighted class
Our dataset is imbalanced, so there is no guarantee that every mini-batch contains samples from all classes. One of our solutions to this problem is PyTorch's WeightedRandomSampler. We first counted the number of samples in each class and took the reciprocal of each count as that class's weight. We then assigned each sample the weight of its class label and passed these per-sample weights to the sampler.
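The weighting procedure can be sketched without PyTorch as follows. The per-sample weights from `sample_weights` are what one would pass as the `weights` argument of `torch.utils.data.WeightedRandomSampler` (typically with `replacement=True`); here Python's `random.choices` stands in for the sampler so the example is self-contained. The helper names are illustrative.

```python
import random
from collections import Counter


def sample_weights(labels):
    """Weight every sample by the reciprocal of its class frequency."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def weighted_sample(labels, num_samples, rng):
    """Draw indices with replacement, proportionally to the per-sample weights."""
    weights = sample_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


if __name__ == "__main__":
    labels = (["Normal"] * 1575 + ["Bacteria"] * 2778
              + ["Virus"] * 1494 + ["COVID-19"] * 82)
    drawn = weighted_sample(labels, 10_000, random.Random(0))
    print(Counter(labels[i] for i in drawn))  # roughly 2,500 draws per class
```

Because the weights of each class sum to 1, every class is drawn with equal probability, at the cost of repeating the 82 COVID-19 images many times per epoch.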

Implementation details
Transfer learning is our main approach to classifying COVID-19-infected cases. Our dataset is small and imbalanced, with the COVID-19 class considerably smaller than the others, which can degrade predictions, so we need a method that alleviates the skewed class distribution. We started with the ResNet50 architecture pretrained on the ImageNet dataset. We then applied a weighted sampling approach to draw an equal number of samples from each class. We also examined both categorical cross-entropy and focal loss to measure the effect of focal loss on our imbalanced dataset. Although these attempts paid off and both training and validation accuracy increased, we still sought to improve the model further. Medical images differ from natural images; for instance, they are mostly grayscale. This raises the question of whether a model pretrained on ImageNet, a collection of natural images, is a wise choice. Investigating this question led us to [5], whose authors trained their model using CheXNet. CheXNet [6] is a model based on the DenseNet architecture that was trained on a large dataset of chest X-ray images. We adopted the same approach to check whether it improved our model's performance. Finally, we added extra COVID-19 images to the dataset, since it contains very few COVID-19 samples.
For training, we set the number of epochs to 20 because of time and GPU constraints. We used the Adam optimizer with an initial learning rate of 0.001. Every 10 epochs, the learning rate is multiplied by a factor of 0.5 to help training converge.
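The schedule described above (initial rate 0.001, halved every 10 epochs, 20 epochs in total) is a step decay. A dependency-free sketch of the rate at each epoch follows; the function name `step_lr` is illustrative. In PyTorch the same behaviour is usually obtained with `torch.optim.Adam(model.parameters(), lr=1e-3)` and `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)`, calling the scheduler's `step()` once per epoch.

```python
def step_lr(epoch, base_lr=1e-3, step_size=10, gamma=0.5):
    """Learning rate at a 0-indexed epoch under step decay."""
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    for epoch in range(20):
        print(epoch, step_lr(epoch))
```

Over 20 epochs the rate therefore takes only two values: 0.001 for epochs 0 to 9 and 0.0005 for epochs 10 to 19.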

Experimental results
To present our results, we devote a separate subsection to each choice of hyperparameters, which lets us explain the output and drawbacks of each in detail. In all experiments, we used the Adam optimizer with the learning rate scheduler described in Section 3. In all figures in this section, the left images show training and validation accuracy versus the number of epochs, and the right images show training and validation loss versus the number of epochs.
The experimental results of each subsection are gathered in Table 1.
In this subsection, we train our model based on ResNet50 pretrained on ImageNet, with categorical cross-entropy as the loss function.

Considering CE loss and weighted-class (PRCEW)
In this subsection, we train our model based on ResNet50 pretrained on ImageNet. The loss function is categorical cross-entropy, and we weight the classes to reduce the adverse effects of imbalanced data on the results.
Next, we train our model based on DenseNet121 pretrained on CheXNet, using CE loss.

Considering FL loss (PDCXFL)
In this subsection, we train our model based on DenseNet121 pretrained on CheXNet, using focal loss (FL).

Extra COVID-19 samples
The limited number of COVID-19 samples is a problem that prevents our model from performing at its best. We therefore first intended to develop a GAN model to generate COVID-19 images. However, because of limited time and GPU memory, and because GAN models are very sensitive to their hyperparameters, our GANs could not generate realistic images. We therefore added existing COVID-19 X-ray images available on the Internet to our dataset. The results are shown in

Discussion
Based on the results of our implemented models, the highest validation accuracy we attained is below 90 percent. This outcome underscores the constraint posed by a limited training dataset. We believe that enlarging the labeled dataset could improve the model's accuracy and its ability to generalize. However, acquiring more labeled data is time-consuming. Exploring semi-supervised or self-supervised learning techniques could therefore offer valuable avenues for improving the performance of our model.

Conclusion
In our study, we explored various strategies for classifying highly imbalanced Kaggle data, focusing on detecting COVID-19 in lung X-ray images, a critical area of research given the global impact of the pandemic. Our methodology comprised three principal approaches: a ResNet50 model pretrained on ImageNet, a DenseNet121 model pretrained on CheXNet, and a ResNet50 model trained from scratch. We assessed the effectiveness of each method to determine the most suitable one for analyzing lung X-ray scans. Initial findings indicated that models pretrained on CheXNet outperformed the other techniques on our dataset. We also observed that focal loss improved results on the test and validation sets by addressing the imbalanced input data, although categorical cross-entropy yielded higher training accuracy. A significant challenge was the limited number of COVID-19 samples, which kept us from achieving optimal outcomes. To mitigate this, we incorporated additional COVID-19 samples from another Kaggle challenge. Applying our models, as discussed in Section 4, to this augmented dataset led to marked improvements. This underscores the contribution of our study to COVID-19 detection and highlights the potential of machine learning models to improve diagnostic accuracy and, consequently, patient outcomes in this global health crisis.

Figure 1: Left: train-test distribution; right: class distribution of the dataset.

Figure 2: Histogram distribution of data samples.

4.1.3 Considering FL loss (PRFL)
In this subsection, we train our model based on ResNet50 pretrained on ImageNet. The loss function is focal loss, used to counter the imbalanced distribution of the sample data.

Figure 4: First row: ResNet pre-trained model with CE loss, Second row: ResNet pre-trained model with weighted class CE loss, Third row: ResNet pre-trained model with FL loss.

Table 1: Accuracy results for each proposed model.
According to the results in Table 1, pretraining DenseNet on CheXNet outperforms the other models.

Table 2: Accuracy results for each proposed model.