Semi-supervised Deep Learning for Medical Image Segmentation

Hritam Basak
Published in Heartbeat · Feb 28, 2023 · 18 min read

The past few years have witnessed exponential growth in medical image analysis using deep learning. Whether it is stroke detection from brain MRIs, melanoma detection, or robotic surgery, medical image segmentation has always been the cornerstone of these applications [1].

In this article, we will look into medical image segmentation and see how deep learning can help in these cases. We will also examine the research gaps in this field, which could inspire potential future projects. Finally, we will look at some of the recent semi-supervised medical image segmentation algorithms. Let’s dive in!

Figure 1: Example of the original image and its segmented counterpart [Source1, Source2]

What is Medical Image Segmentation?

As the name suggests, medical image segmentation refers to the task of extracting the desired region of interest (ROI) from medical images (X-ray, CT, MRI, etc.). This can be done manually by expert clinicians or automatically by computer-aided systems. Manual segmentation requires significant expense, time, and expertise, and it is also prone to human error and inter-observer variability.

“The cost associated with the labeling process thus may render large, fully labeled training sets infeasible, whereas acquisition of unlabeled data is relatively inexpensive.” — [Source]

Automated segmentation, on the other hand, is quite accurate, fast, and requires minimal (or no) supervision. The recent developments in automatic medical image segmentation mostly rely on deep learning techniques, which involve training a neural network with a very large amount of data to learn from.

Deep learning approaches can be further classified into supervised and unsupervised learning. Supervised training requires pixel-level annotation (i.e. ground truth) for every input image [e.g. U-Net, U-Net++], whereas unsupervised learning eliminates this requirement [see this review paper]. Semi-supervised learning lies between supervised and unsupervised learning, and we will explore it in detail in the following sections.

What is Semi-supervised Learning (SSL)?

SSL is a machine learning paradigm that combines a small amount of labeled data with a large amount of unlabelled data during training. It has been observed that unlabelled data, in conjunction with a small amount of labeled data, can produce a significant improvement in learning accuracy [3, 4]. The workflow for SSL can be summarized in the following diagram.

Figure 2: Semi-supervised self-training method. Image by author.

Why do we need SSL in medical image segmentation?

Supervised medical image segmentation is one of the most widely explored fundamental computer vision problems, capable of producing state-of-the-art segmentation performance. These methods have achieved exponential progress in the last decade due to the rapid evolution of deep convolutional networks (ResNet, VGG, etc.), with Fully Convolutional Networks (FCN) [5] being the cornerstone.

However, these methods rely on the availability of pixel-level annotations for the entire dataset. Acquiring large-scale medical datasets labeled by expert clinicians is tedious and expensive, so methods that alleviate this requirement are highly desirable.

“A frequent problem when applying machine learning methods to medical images, is the lack of labeled data, even when larger sets of unlabeled data may be more widely available.” — [Cheplygina et al.]

SSL-based methods are promising directions to this end, requiring a minimal amount of annotations, and producing pseudo labels for a large portion of unlabelled data, which are further used to train the segmentation network. In recent years, SSL-based methods have been widely recognized for their superior performance in downstream tasks (like segmentation, object detection, etc.), not only in natural scene images but also in biomedical image analysis.

Now that we know what SSL is and why it is necessary for medical image segmentation, let’s look at some of the SSL-based methods developed in the last few years. An awesome, up-to-date, carefully curated list of recent SSL-based medical image segmentation methods, along with some of the available codes for ready reference, can be found here.

Different SSL methods in medical image segmentation

A consensus exists that “Deep learning is data-hungry.” So, if we provide pixel-level annotation (i.e. ground truth) for only a small portion of training data, the model will produce terrible segmentation masks as shown in the figure below.

Figure 3: Segmentation performance deteriorates with a decrease in the number of labels during training [6].

To alleviate this problem, researchers have developed several SSL techniques in the last few years. The three major SSL methods widely used in the medical domain are:

  1. Consistency Regularization
  2. Self-training and co-training
  3. Uncertainty Estimation

Let’s go through them one by one.

Consistency Regularization

Consistency regularization (CR) is an SSL technique grounded in the continuity assumption of machine learning [7]. The intuition behind the effectiveness of CR is that the segmentation masks of two similar input images should ideally be similar.

The continuity assumption states that data points that are close together are likely to have the same label. This is the fundamental idea behind Consistency Regularization.

Let’s assume we have a dataset D consisting of N1 labeled samples and N2 unlabelled samples (where N1 << N2). When we train the model in a semi-supervised setting, we handle the labeled and unlabelled sets differently: the traditional supervised training strategy is used for the N1 labeled samples, whereas for the N2 unlabelled samples, we enforce similarity between predictions from similar data points and penalize their dissimilarities.

First, we sample image A from the unlabelled set and generate its augmentation A’. This augmentation could be a simple geometric affine transformation (e.g. rotation, cropping, flipping) [8], a random intensity transformation (e.g. color, contrast, or brightness change), or an adversarial perturbation [9].

None of these transformations fundamentally alters the image content, so the segmentation map of the original image should be consistent with that of its augmented counterpart. Pictorially, we can represent consistency regularization as shown below:

Figure 4: Overall workflow of Consistency Regularization using data augmentation. Image by author.
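To make this concrete, below is a minimal PyTorch-style sketch of the unlabelled branch shown in Figure 4. The horizontal-flip augmentation, the MSE consistency term, and the `lambda_u` weight are illustrative assumptions rather than the recipe of any specific paper.

```python
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

def consistency_loss(model, unlabeled_images):
    """Augmentation-based consistency on an unlabelled batch (illustrative sketch)."""
    # Prediction on the original image acts as the target; gradients are stopped here
    with torch.no_grad():
        pred_orig = torch.softmax(model(unlabeled_images), dim=1)

    # A simple geometric perturbation: horizontal flip
    pred_aug = torch.softmax(model(TF.hflip(unlabeled_images)), dim=1)

    # Map the augmented prediction back to the original frame before comparing
    pred_aug = TF.hflip(pred_aug)

    # Penalize disagreement between the two segmentation maps
    return F.mse_loss(pred_aug, pred_orig)

# The unlabelled term is simply added to the usual supervised loss, e.g.:
# total_loss = supervised_loss + lambda_u * consistency_loss(model, unlabeled_images)
```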

Another similar approach to consistency regularization is network perturbation [10]. Here we use two different networks with identical architecture but different network parameters (randomly initialized) in two parallel branches and feed them with a single input image. Ideally, the segmentation output from these two branches should be identical.

However, due to different initialization of network parameters, the segmented maps from the two branches will differ from each other. This inconsistency between two outputs is penalized in the loss function and the network parameters are updated iteratively. The network perturbation approach for consistency regularization is shown in the following figure.

Figure 5: Overall workflow of Consistency Regularization using network perturbation. Image by author.
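The network-perturbation variant looks almost identical in code. Here is a hedged sketch in the same style, where `model_a` and `model_b` are two identically structured but independently initialized segmentation networks; the MSE penalty is again an assumption for illustration.

```python
import torch
import torch.nn.functional as F

def network_perturbation_loss(model_a, model_b, unlabeled_images):
    """Consistency between two identically structured, differently initialized networks (sketch)."""
    pred_a = torch.softmax(model_a(unlabeled_images), dim=1)
    pred_b = torch.softmax(model_b(unlabeled_images), dim=1)
    # Penalize the disagreement; both networks receive gradients and are updated jointly
    return F.mse_loss(pred_a, pred_b)
```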

Now that we understand the basics of consistency regularization, let’s dive into some of the recent works in this domain.


Recent trends in consistency regularization

Lately, there has been an evolution of numerous methods involving consistency regularization strategies for medical image segmentation. Some of them utilize CR for end-to-end model training [11, 12], whereas others use it to fine-tune a pre-trained model for a specific downstream task [13].

Bortsova et al. [14] proposed a transformation-consistent framework for the segmentation of lung X-ray images by exploring the equivariance of elastic perturbations. Luo et al. [15] proposed a dual-task consistency regularization method between the output of the regression layer and the segmentation layer (i.e. a level-set-based segmentation map and a pixel-level segmentation map).

This was further extended by Wang et al. [16], where the authors proposed enforcing three independent task-level consistency constraints between different auxiliary tasks (e.g. foreground and background reconstruction). The authors argue that these auxiliary tasks help the model learn high-level semantic information from the unlabelled images, leading to superior performance. However, these multi-task consistency regularization methods are heavyweight, complicated, and require substantial computational overhead during training.

Recently, an interpolation-based consistency regularization method for semi-supervised medical image segmentation was proposed in the paper “An Embarrassingly Simple Consistency Regularization Method for Semi-supervised Medical Image Segmentation,” published at IEEE International Symposium on Biomedical Imaging (ISBI), 2022. The authors have also provided the working code for this work here: GitHub Code. As the name suggests, the proposed method is exceedingly simple, yet quite effective. So, let’s have an in-depth discussion on this work.

Figure 6: Overall framework of our proposed architecture in [17]. Image source: original paper.

The work is based on a simple assumption that “segmentation of an interpolated image from two unlabelled data should be consistent with interpolation of segmentation maps from the same images.” The authors propose a new image perturbation strategy in the adversarial direction by generating interpolation of two unlabelled images. The steps can be summarized as follows:

Figure 7: Pseudo-code of the proposed method in [17]. Image source: original paper.
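A rough, hedged reading of the core idea is sketched below: mix two unlabelled images and ask the network’s prediction on the mixture to match the same mixture of its individual predictions. The Beta-sampled mixing coefficient and the MSE penalty are assumptions for illustration; the exact perturbation and loss are defined in [17].

```python
import torch
import torch.nn.functional as F

def interpolation_consistency_loss(model, u1, u2, alpha=0.3):
    """Interpolation-based consistency on two unlabelled batches (illustrative sketch)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()

    # Target: interpolation of the model's predictions on the two original images
    with torch.no_grad():
        target = lam * torch.softmax(model(u1), dim=1) + \
                 (1 - lam) * torch.softmax(model(u2), dim=1)

    # Prediction on the interpolated (mixed) image
    pred_mix = torch.softmax(model(lam * u1 + (1 - lam) * u2), dim=1)

    # The two should agree if the model is consistent under interpolation
    return F.mse_loss(pred_mix, target)
```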

The proposed method was quite effective for semi-supervised cardiac MRI segmentation on the ACDC [18] and MMWHS [19] datasets, achieving state-of-the-art segmentation accuracy. The reported DSC scores for ACDC were 73.56%, 79.05%, and 89.80% using 1.25%, 2.5%, and 10% labels, respectively. Similar results were also obtained for MMWHS, with DSC scores of 66.86%, 71.24%, and 79.83% using 10%, 20%, and 40% labeled volumes, respectively. The comparison of results is shown below.

Source: [17]

The authors also reported a comparison of the results using different percentages of labeled data, as shown below. The method achieved results very close to fully-supervised counterparts by using only a fraction of labels.

Source: [17]

The reported segmentation results in the paper for the two datasets are shown in the following diagram. The authors provided a visual comparison of the segmentation mask with several other existing state-of-the-art semi-supervised methods.

Source: [17]

Another recent work on cross-teaching between a CNN and a Transformer network was proposed in the paper Semi-Supervised Medical Image Segmentation via Cross Teaching between CNN and Transformer, published in Medical Imaging with Deep Learning (MIDL), 2022. In this work, the authors proposed a simple consistency regularization method between a CNN and a Transformer network for semi-supervised medical image segmentation. The overall workflow is shown in the following diagram.

“Specifically, we simplify the classical deep co-training from consistency regularization to cross teaching, where the prediction of a network is used as the pseudo label to supervise the other network directly end-to-end. Considering the difference in learning paradigm between CNN and Transformer, we introduce the Cross Teaching between CNN and Transformer rather than just using CNNs.”— [Source]

Source: [20]

This work was actually inspired by three different existing works from recent literature: Cross-pseudo Supervision [21], Deep Co-training [22], and Co-teaching [23]. This method proposes a network-level perturbation (similar to Figure 5). A CNN operates by using local convolution operations whereas a transformer performs long-range self-attention operations. Hence, the outputs from these two networks essentially have different properties. Now, based on the predicted pseudo labels from both networks, the authors propose a cross-teaching loss as follows:
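A hedged sketch of what such a bidirectional pseudo-label exchange can look like is given below. The paper’s own formulation uses a Dice-style term between predictions and pseudo labels, so the cross-entropy here is only a simpler stand-in for illustration.

```python
import torch
import torch.nn.functional as F

def cross_teaching_loss(logits_cnn, logits_transformer):
    """Bidirectional cross-teaching on unlabelled images (illustrative sketch).

    logits_cnn, logits_transformer: raw network outputs of shape (B, C, H, W).
    Each network's hard prediction serves as the pseudo label for the other;
    no gradient flows through the pseudo labels.
    """
    pseudo_from_cnn = torch.argmax(logits_cnn.detach(), dim=1)                  # teaches the Transformer
    pseudo_from_transformer = torch.argmax(logits_transformer.detach(), dim=1)  # teaches the CNN

    loss_cnn = F.cross_entropy(logits_cnn, pseudo_from_transformer)
    loss_transformer = F.cross_entropy(logits_transformer, pseudo_from_cnn)
    return loss_cnn + loss_transformer
```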

“Differently from consistency regularization loss, the cross teaching loss is a bidirectional loss function, one stream is from the CNN to the Transformer and the other is the Transformer to the CNN, there are no explicit constraints to enforce their predictions to become similar” — [Source]

The overall objective function is formulated by fusing the standard supervised CE loss along with the unsupervised loss term as follows:
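In its general form (the exact weighting schedule is given in the paper), such an objective can be written as

L_total = L_sup + λ(t) · L_ctl,

where L_sup is the supervised loss on the labeled images, L_ctl is the cross-teaching term on the unlabelled images, and λ(t) is a ramp-up weight that gradually increases the influence of the unsupervised term during training.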

Upon evaluation of the cardiac MRI segmentation task on the ACDC dataset, this method performs extremely well, compared to the existing SOTA. The authors have provided a detailed comparison of the evaluation metrics in terms of the Dice Similarity Coefficient (DSC) and Hausdorff Distance 95% (HD95), as shown in the following table:

Source: [20]

The authors have also provided a visual comparison of the results with several other state-of-the-art semi-supervised methods on 3-label and 7-label ACDC validation set images, as shown below.

Source: [20]

In a nutshell, consistency regularization-based methods have proven to be quite effective in the recent past. Although these methods may suffer from several drawbacks, with the scientific community’s growing inclination towards SSL-based methods, we can expect several other exciting papers in the near future.

Self-training and co-training

Self-training involves training a semi-supervised network in two steps. First, a network trained on a limited amount of annotations in a supervised fashion is used to generate an initial set of predictions for the unlabelled images. In the second step, this network is trained again iteratively using both the labeled images with their ground truths and the unlabelled images with their pseudo-labels (generated in the first step) as proxy or false ground truths. These pseudo labels are updated iteratively, and we expect their quality to improve over the epochs. However, it has been found that the initial estimate of the pseudo labels immensely affects the overall training and may lead to degradation in segmentation performance [24].
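A bare-bones sketch of one self-training round is shown below. The loader names, the single joint loss, and the refresh schedule are assumptions made for illustration only; practical implementations typically weight, filter, or confidence-threshold the pseudo-labeled term.

```python
import torch
import torch.nn.functional as F

def self_training_round(model, optimizer, labeled_loader, unlabeled_loader, num_epochs=1):
    """One round of pseudo-label self-training (illustrative sketch)."""
    # Step 1: generate pseudo labels for the unlabelled images with the current model
    model.eval()
    pseudo_set = []
    with torch.no_grad():
        for images in unlabeled_loader:
            pseudo_set.append((images, torch.argmax(model(images), dim=1)))

    # Step 2: retrain on labeled ground truths plus unlabelled pseudo labels
    model.train()
    for _ in range(num_epochs):
        for (x_l, y_l), (x_u, y_u) in zip(labeled_loader, pseudo_set):
            loss = F.cross_entropy(model(x_l), y_l) + F.cross_entropy(model(x_u), y_u)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # The pseudo labels are then regenerated with the improved model and the
    # process repeats, ideally improving pseudo-label quality over the rounds.
```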

Recently, Chaitanya et al. [25] proposed a self-training strategy along with local contrastive learning for accurate medical image segmentation. The main contribution of the work is summarized as follows:

“In this paper, we propose a local contrastive loss to learn good pixel level features useful for segmentation by exploiting semantic label information obtained from pseudo-labels of unlabeled images alongside limited annotated images. In particular, we define the proposed loss to encourage similar representations for the pixels that have the same pseudo-label/ label while being dissimilar to the representation of pixels with different pseudo-label/label in the dataset. We perform pseudo-label based self-training and train the network by jointly optimizing the proposed contrastive loss on both labeled and unlabeled sets and segmentation loss on only the limited labeled set.” — [Source]

The overall workflow of this work is shown below:

Source: [25]

The overall objective of the method is to learn intra-class compactness and inter-class separability by leveraging class-wise semantic and structural information from the images. This is done in two steps: 1) the authors first minimize the supervised loss between the predictions and the available ground truths for labeled images, and 2) the second optimization step combines the supervised loss with the proposed pixel-level contrastive loss, as sketched in the equation below. For a detailed understanding of the mathematical symbols and expressions, please refer to the original paper [25].
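Schematically, and without reproducing the paper’s exact notation, the joint objective of the second step can be read as

L_total = L_seg(D_L) + λ · L_contr(D_L ∪ D_U),

where L_seg is the usual supervised segmentation loss on the labeled set D_L, and L_contr is a pixel-level contrastive term that, for each pixel embedding, pulls together pixels sharing the same label or pseudo-label and pushes apart pixels with different ones; λ balances the two terms. The precise definitions are given in [25].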

The method performed extremely well in the semi-supervised segmentation task on three medical datasets and outperformed several other SSL-based methods. Moreover, the results are quite close to the fully-supervised benchmark while utilizing only a small percentage of labels, as shown in the table below.

Source: [25]

These results clearly show the efficacy of self-training strategies in yielding highly accurate segmentation performance by leveraging large amounts of unlabelled data. With the recent shift of research interest towards self-training strategies, we can expect rapid growth and many more novel applications of self-training in medical image segmentation in the near future.

Co-training, on the other hand, works on the intuition that semi-supervised training scenarios can be described by two complementary sets of information (often called views). These views are conditionally independent in general, given the corresponding class labels. The general idea is to train classifiers or segmentation networks simultaneously for each set of views for the labeled set so that their predictions are consistent for the unlabelled set. Thus, the search space is reduced by enforcing agreement between the classifiers, which eventually helps the model generalize well on unseen data.


Co-training has already been successfully used for natural language processing tasks; however, its application to vision tasks has been limited. The prime reason for this limitation is that these methods require complementary networks to learn from unlabelled data, which adds computational overhead. Peng et al. [26] proposed an efficient deep co-training strategy for medical image segmentation for the first time and set a benchmark. Let’s look at the work.

Source: [26]

“As in standard multi-view learning approaches, we train multiple models in a collaborative manner and, once trained, combine their outputs to predict the labels of new images.” — [Source]

This work was motivated by the success of CNNs in image segmentation tasks, and the authors extended the idea to formulate a deep co-training algorithm. Specifically, the authors train an ensemble of multiple segmentation networks in parallel. They propose three independent loss functions, whose linear combination forms the overall training loss.

The first one is the supervised loss computed for labeled samples. Note that, as the overall workflow comprises two models, the authors compute two separate supervised loss terms as shown in the following equations. The supervised loss is a simple pixel-wise cross-entropy loss, computed between the prediction and ground truth.
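In the common dual-view notation (a simplification of the paper’s equations), these two terms can be written as

L_sup(1) = CE(f1(x_l), y_l) and L_sup(2) = CE(f2(x_l), y_l),

i.e., a pixel-wise cross-entropy between each model’s prediction on a labeled image x_l and its ground truth y_l.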

The second loss proposed in the paper is the ensemble agreement loss. The authors enforce multiple segmentation networks to output similar predictions for the same unlabeled images. Enforcing this mutual agreement may help to improve the generalizability and robustness of individual models by restricting their parameter search space to cross-view consistent solutions. The paper proposes to use a Jensen-Shannon divergence (JSD) loss term to minimize the distance between multiple class distributions obtained from different models.
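A hedged PyTorch-style sketch of such a JSD agreement term over an arbitrary number of views is given below; the per-pixel averaging and the small epsilon are implementation assumptions.

```python
import torch

def jsd_agreement_loss(logits_list, eps=1e-8):
    """Jensen-Shannon divergence between the softmax outputs of several models (sketch).

    logits_list: list of raw logit tensors of shape (B, C, H, W), one per model.
    """
    probs = [torch.softmax(logits, dim=1) for logits in logits_list]
    mean_prob = torch.stack(probs, dim=0).mean(dim=0)

    # JSD = entropy of the mean distribution minus the mean of the individual entropies
    entropy_of_mean = -(mean_prob * torch.log(mean_prob + eps)).sum(dim=1)
    mean_of_entropies = torch.stack(
        [-(p * torch.log(p + eps)).sum(dim=1) for p in probs], dim=0
    ).mean(dim=0)
    return (entropy_of_mean - mean_of_entropies).mean()
```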

The final loss term is diversity loss. The most standard way to obtain diversity is to have diverse features from different views or to generate them by splitting the existing features into complementary subsets.

“A key principle of ensemble learning is having diversity between models in the ensemble. If all models learn the same class distribution, then combining their output will not be superior to individual model predictions. In co-training, diversity is essential so that models can learn from one another during training.” — [Source]

The prime condition for the ensemble model to work is agreement between the agents on unlabelled images. Besides, their predictions must be constrained by the available ground truth for labeled images. Hence, the authors utilize a technique to generate adversarial augmentations from both labeled and unlabelled samples, which in turn are used to teach the other agent models. The diversity loss proposed in the work can be summarized below for the dual-view co-training strategy.
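One plausible way to write this for two views, consistent with the description that follows (the paper’s exact equation may differ), is

L_div(x) = H(f1(x), f2(g1(x))) + H(f2(x), f1(g2(x))),

where g1(x) and g2(x) denote adversarial examples crafted against models f1 and f2, respectively, so that each model is taught, on the adversarial samples of its partner, to match the partner’s prediction on the clean input.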

Here, H() refers to the cross-entropy and g() indicates the adversarial samples. According to the paper, this loss helps each model become robust to the adversarial samples generated from the other models, thereby avoiding an unwanted collapse of the decision boundaries onto each other. If both networks were identical, this adversarial loss would reach its maximum value.

Source [26]

The final loss is computed as the summation of these three loss terms, as shown below.
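Keeping the same notation, the combined objective then takes the form

L = λ_sup · L_sup + λ_agree · L_agree + λ_div · L_div,

with the three weights balancing supervision, ensemble agreement, and diversity; the weighting itself requires careful tuning, as discussed below.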

This method shows really promising results despite being trained with few annotations. Additionally, the authors performed experiments to observe how performance improves with an increasing number of views and a larger percentage of labeled data. All the results are shown below.

Source [26]
Source [26]

To summarize, co-training is another popular and widely used SSL method for medical image segmentation, although it suffers from a significant drawback: it requires multiple networks to be trained in parallel, which increases computational complexity and training time. Computing and combining multiple segmentation predictions from parallel models at test time also demands considerable resources, making it difficult to deploy these methods on edge devices. Besides, finding an optimally weighted combination of the multiple loss terms is of utmost importance to ensure loss convergence, which requires extensive experimentation.

Uncertainty Estimation

Uncertainty estimation for medical image segmentation is a crucial aspect of ensuring the reliability and accuracy of automated segmentation methods. One major source of uncertainty in medical image segmentation is the presence of noise in the images. Noise can be caused by a variety of factors, such as imaging equipment, the subject’s motion during the imaging process, and the imaging modality itself. Noise can introduce false edges and structures in the image, making it difficult for the segmentation algorithm to accurately identify the true boundaries of the structures of interest.

Another source of uncertainty in medical image segmentation is the variability in the appearance of structures within the same class. For example, different tumors may have different shapes, sizes, and intensities, making it difficult for the segmentation algorithm to accurately segment them. Additionally, the presence of partial volume effects, where a structure is only partially visible in the image, can also introduce uncertainty in the segmentation process.

To estimate the uncertainty in medical image segmentation, various methods have been proposed. One popular method is the use of Bayesian deep learning, where a probabilistic model is used to estimate the uncertainty in the segmentation. This approach uses a deep neural network to predict a probability map of the segmentation, rather than a discrete segmentation mask. The probability map represents the likelihood of each pixel belonging to a particular structure and can be used to estimate the uncertainty in the segmentation.

Another method for uncertainty estimation in medical image segmentation is the use of ensemble methods. Ensemble methods combine multiple segmentation models to produce a more robust and accurate segmentation. For example, one can train multiple models on different data subsets, or with different architectures, and then combine their predictions to produce a final segmentation. With ensemble methods, uncertainty can be estimated by looking at the variability of the predictions of the different models.

Uncertainty can also be estimated using the bootstrap method, where the dataset is resampled multiple times with replacement. Segmentation is then performed for each resampled dataset, and the variability of the resulting segmentations is used to estimate the uncertainty.

Additionally, Monte Carlo dropout is another method that can be used to estimate uncertainty in medical image segmentation. Monte Carlo dropout involves randomly dropping out neurons during the forward pass of the segmentation algorithm, and averaging the predictions over multiple forward passes. This allows for the estimation of uncertainty by looking at the variability of the predictions.
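As a concrete illustration, here is a hedged PyTorch-style sketch of Monte Carlo dropout at inference time. The helper that re-enables only the dropout layers, the number of passes, and the variance-based uncertainty map are all assumptions of this sketch rather than a fixed recipe.

```python
import torch
import torch.nn as nn

def enable_mc_dropout(model):
    """Keep only the dropout layers stochastic at test time."""
    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()

def mc_dropout_predict(model, image, num_passes=20):
    """Monte Carlo dropout for segmentation uncertainty (illustrative sketch)."""
    enable_mc_dropout(model)
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(image), dim=1) for _ in range(num_passes)], dim=0
        )
    mean_prediction = probs.mean(dim=0)        # averaged class probabilities per pixel
    uncertainty = probs.var(dim=0).sum(dim=1)  # per-pixel predictive variance as an uncertainty map
    return mean_prediction, uncertainty
```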

Several recent papers discuss uncertainty estimation for medical image segmentation, and this GitHub repository can be a good starting point for understanding the current state of the art. It is worth noting that these papers focus on different types of medical images, different sources of uncertainty, and different methods to estimate it. It is therefore important to consider the specific task and dataset you are working with when selecting a method for uncertainty estimation in medical image segmentation.

Conclusion

In conclusion, it is increasingly challenging to obtain pixel-level annotations for large-scale datasets, especially in the medical domain. Generating these annotated datasets requires domain expertise, plenty of time, and money. Hence, a more practical approach is to learn from limited annotations and then generalize to unlabeled images. That is why semi-supervised medical image segmentation has drawn massive attention from the scientific community in the recent past and is set to remain a cornerstone of medical image analysis in the years to come.

Uncertainty estimation remains a crucial aspect of ensuring the reliability and accuracy of automated segmentation methods. Various sources of uncertainty, such as noise and variability in the appearance of structures, can make it difficult for segmentation algorithms to accurately identify the boundaries of structures of interest. To estimate this uncertainty, methods such as Bayesian deep learning, ensembles, the bootstrap, and Monte Carlo dropout have been proposed, and they can be used to produce more robust and accurate segmentations.

Editor’s Note: Heartbeat is a contributor-driven online publication and community dedicated to providing premier educational resources for data science, machine learning, and deep learning practitioners. We’re committed to supporting and inspiring developers and engineers from all walks of life.

Editorially independent, Heartbeat is sponsored and published by Comet, an MLOps platform that enables data scientists & ML teams to track, compare, explain, & optimize their experiments. We pay our contributors, and we don’t sell ads.

If you’d like to contribute, head on over to our call for contributors. You can also sign up to receive our weekly newsletter (Deep Learning Weekly), check out the Comet blog, join us on Slack, and follow Comet on Twitter and LinkedIn for resources, events, and much more that will help you build better ML models, faster.
