Introduction to Radiology in Machine Learning & Multi-Image Segmentation with TransUNet
Deep learning-based image segmentation is becoming essential in computer vision. Domains such as biomedical ML, autonomous vehicles & robotics, machine recognition, and face detection rely on image segmentation to identify regions of interest in an image and make decisions. In radiology, image segmentation is promising because it helps radiologists quickly identify affected areas in CT scans or MRIs. Radiologists can use such tools to diagnose patients more precisely and choose appropriate treatments.
Deep learning-based image segmentation tools are powerful because they can identify tiny anomalies that the human eye may miss. As deep learning algorithms develop, new state-of-the-art (SOTA) models keep being released, laying the foundation for next-generation radiology.
U-Net has become the primary and most widely used architecture for image segmentation in biomedicine. It is built mainly from convolutional neural networks (CNNs). With the introduction of transformers, however, these CNN components are increasingly being either hybridized with transformers or replaced by them to obtain more accurate results.
In this article, we will discuss how a UNet that combines CNNs and transformers can yield better results. Along the way, we will also explore the following:
- What is image segmentation?
- How is image segmentation used in image processing?
- How is image segmentation used in radiology machine learning?
- What is U-Net?
- What are Transformers?
Finally, we will learn how to build a multi-class segmentation U-Net by hybridizing transformers and convolutional neural networks.
What is Image Segmentation?
Image segmentation is a computer vision task that separates the different objects in an image at the pixel level. Essentially, the pixels belonging to each individual object are grouped into a set, separating one object from another. Each such set of pixels forms a mask, which is usually displayed in a distinct color, and is assigned a label for identification. This makes image analysis more straightforward and lets meaningful information be extracted efficiently.
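To make the idea of a mask concrete, here is a minimal NumPy sketch. It assumes a toy 4×4 label map whose values and class IDs are made up for illustration: the segmentation is stored as one class ID per pixel, and each class's binary mask is simply the set of pixels carrying that ID.

```python
import numpy as np

# Toy 4x4 label map (illustrative): 0 = background, 1 = class A, 2 = class B.
label_map = np.array([
    [0, 0, 1, 1],
    [0, 1, 1, 1],
    [2, 2, 0, 0],
    [2, 2, 0, 0],
])

def to_binary_masks(label_map, num_classes):
    """Return a boolean array of shape (num_classes, H, W): one mask per class."""
    return np.stack([label_map == c for c in range(num_classes)])

masks = to_binary_masks(label_map, num_classes=3)
print(masks.shape)     # (3, 4, 4)
print(masks[1].sum())  # 5 pixels belong to class 1
```

Every pixel ends up in exactly one of the masks, which is what a colored overlay of the segmentation visualizes.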
Source: Deep Lake
How many types of image segmentation exist?
There are three main types of image segmentation:
Semantic Segmentation: All objects of the same class are segmented together with a single mask. For instance, all pixels belonging to people receive the same "human" segmentation mask. Likewise, all pixels belonging to trees receive the same "tree" segmentation mask.
Instance Segmentation: Each individual object receives its own segmentation mask, even when several objects belong to the same class. As the name suggests, instance segmentation creates one mask per object instance.
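Panoptic Segmentation: A combination of semantic and instance segmentation. It uses the strengths of both approaches: every pixel receives a class label, and distinct countable objects additionally receive separate instance masks. This produces accurate masks for both well-defined objects and less distinct regions, improving overall performance.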
Source: Deep Lake
The image above shows the difference between the different types of segmentations.
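As a rough illustration of how the three outputs differ, the sketch below represents a toy scene as arrays. The scene, the class IDs, and the packing `class_id * 1000 + instance_id` are assumptions chosen for illustration; panoptic datasets use a variety of encodings.

```python
import numpy as np

# Toy 3x4 scene (illustrative): two "person" objects (class 1) on background (class 0).
semantic = np.array([
    [1, 1, 0, 1],
    [1, 1, 0, 1],
    [0, 0, 0, 0],
])
# Instance IDs tell the two people apart; 0 means "no instance".
instance = np.array([
    [1, 1, 0, 2],
    [1, 1, 0, 2],
    [0, 0, 0, 0],
])

# Assumed convention: pack class and instance into one panoptic ID per pixel.
OFFSET = 1000
panoptic = semantic * OFFSET + instance

print(np.unique(semantic).tolist())  # [0, 1]: semantic view has a single "person" mask
print(np.unique(panoptic).tolist())  # [0, 1001, 1002]: panoptic view keeps both people apart

# Both views can be recovered from the panoptic map.
assert (panoptic // OFFSET == semantic).all()
assert (panoptic % OFFSET == instance).all()
```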
How is Segmentation Used in Radiology?
In radiology, image segmentation is used to delineate different structures in a medical image (CT, X-ray, or MRI scans) by overlaying segmentation masks, i.e., marking the set of pixels that belongs to each structure, and assigning labels to them. This helps detect anomalies such as malignant tissue or defects such as bone fractures for diagnosis and screening.
U-Net: Convolutional Networks for Biomedical Image Segmentation
The U-Net architecture has long been one of the standard, state-of-the-art architectures for biomedical image segmentation. It was introduced in 2015 by Ronneberger et al. The paper describes a fully convolutional network with a U-shaped design, which gives the architecture its name.
Source: U-Net: Convolutional Networks for Biomedical Image Segmentation
The U-Net design is built to extract features and construct precise segmentation masks. It consists of two main components, a downsampling (contracting) path and an upsampling (expanding) path, whose blocks are connected at each level.
Each downsampling block uses convolutions to extract features from its input and produces an output that is used in two ways. One copy is fed into a max-pooling layer, which halves the spatial size, while the other is passed to the corresponding upsampling block, where it is used for localization. See the image below.
Source: U-Net: Convolutional Networks for Biomedical Image Segmentation
The red arrow represents max-pooling, i.e., the downsampling of the feature map, while the gray arrow represents copying the feature map to the corresponding upsampling block (the skip connection).
The upsampling blocks use transposed convolutions to increase the spatial size of the feature maps. Each upsampling block takes two inputs. The first comes from the downsampling block at the same level (the skip connection). The second comes from the upsampling block one level below; it has a lower resolution, so it is first upsampled with a transposed convolution to match the first input. The block then concatenates the two inputs and applies convolutions, gradually recovering the spatial resolution needed for a pixel-level output. See the image below.
Source: U-Net: Convolutional Networks for Biomedical Image Segmentation
The green arrow represents the up-convolution, which upsamples a feature map so that it can be concatenated with the larger feature map from the downsampling path. After concatenation, convolutions extract features and refine the localization.
Keep in mind that each downsampling block is connected to the upsampling block at the same level. This lets the downsampling blocks extract features while the upsampling blocks use those features, delivered through the skip connections, to localize them precisely at each level.
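To make the shape bookkeeping of one decoder step concrete, here is a small pure-Python sketch. It uses the output-size formula of PyTorch's `nn.ConvTranspose2d` and assumes padded convolutions, so the upsampled map matches the skip feature exactly; the original U-Net used unpadded convolutions and cropped the skip features instead. The function names are illustrative.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Spatial output size of a transposed convolution (PyTorch ConvTranspose2d formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def decoder_step(deep_shape, skip_shape, up_channels):
    """(C, H, W) after a 2x2/stride-2 upsampling followed by channel concatenation."""
    _, h, w = deep_shape
    up_h = conv_transpose_out(h, kernel_size=2, stride=2)
    up_w = conv_transpose_out(w, kernel_size=2, stride=2)
    skip_c, skip_h, skip_w = skip_shape
    assert (up_h, up_w) == (skip_h, skip_w), "spatial sizes must match before concatenation"
    return (up_channels + skip_c, up_h, up_w)


# A 512x8x8 bottleneck upsampled to 256 channels and joined with a 256x16x16 skip feature.
print(decoder_step((512, 8, 8), (256, 16, 16), up_channels=256))  # (512, 16, 16)
```

With a 2×2 kernel and stride 2, the formula reduces to doubling the spatial size, which is why only the channel counts add up during concatenation.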
Here is the code for a basic U-Net that uses a pretrained ResNet-18 as its encoder:
```python
import torch
import torch.nn as nn
import torchvision


class Decoder(nn.Module):
    """One U-Net decoder step: upsample, concatenate the skip connection, convolve."""
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # 2x2 transposed convolution with stride 2 doubles the spatial size
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # middle_channels = channels after concatenating the upsampled and skip features
        self.conv_relu = nn.Sequential(
            nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)                 # upsample the deeper feature map
        x1 = torch.cat((x1, x2), dim=1)  # concatenate with the skip feature along channels
        return self.conv_relu(x1)


class UNet(nn.Module):
    """U-Net with a ResNet-18 encoder for single-channel (grayscale) inputs."""
    def __init__(self, n_class):
        super().__init__()
        # ImageNet-pretrained ResNet-18 (torchvision >= 0.13 weights API)
        self.base_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.base_layers = list(self.base_model.children())
        # New first conv so the encoder accepts 1-channel images; reuse ResNet's bn1 and relu
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            self.base_layers[1],
            self.base_layers[2],
        )
        self.layer2 = nn.Sequential(*self.base_layers[3:5])  # maxpool + ResNet layer1
        self.layer3 = self.base_layers[5]                    # ResNet layer2
        self.layer4 = self.base_layers[6]                    # ResNet layer3
        self.layer5 = self.base_layers[7]                    # ResNet layer4 (bottleneck)
        self.decode4 = Decoder(512, 256 + 256, 256)
        self.decode3 = Decoder(256, 256 + 128, 256)
        self.decode2 = Decoder(256, 128 + 64, 128)
        self.decode1 = Decoder(128, 64 + 64, 64)
        self.decode0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
        )
        self.conv_last = nn.Conv2d(64, n_class, 1)  # 1x1 conv to per-class logits

    def forward(self, input):
        # Shapes (C, H, W) assume a 1x256x256 input
        e1 = self.layer1(input)    # 64,128,128
        e2 = self.layer2(e1)       # 64,64,64
        e3 = self.layer3(e2)       # 128,32,32
        e4 = self.layer4(e3)       # 256,16,16
        f = self.layer5(e4)        # 512,8,8
        d4 = self.decode4(f, e4)   # 256,16,16
        d3 = self.decode3(d4, e3)  # 256,32,32
        d2 = self.decode2(d3, e2)  # 128,64,64
        d1 = self.decode1(d2, e1)  # 64,128,128
        d0 = self.decode0(d1)      # 64,256,256
        return self.conv_last(d0)  # n_class,256,256
```
We can think of UNet as a symmetric encoder-decoder network that uses skip connections to carry high-resolution information from the encoder to the decoder and retain detail. However, because UNet extracts features with convolutions, it does not explicitly model long-range dependencies: each convolution looks only at a small local neighborhood, so relationships between distant pixels are captured only indirectly, after many layers. This is a consequence of the intrinsic locality of convolution operations. To tackle this issue, we use transformers.
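To see how local convolutions are, the sketch below computes the theoretical receptive field, i.e., the input window that can influence one output position, for a stack of layers. The layer configurations are illustrative assumptions, loosely following U-Net's pattern of two 3×3 convolutions followed by 2×2 max-pooling per level.

```python
def receptive_field(layers):
    """Theoretical receptive field (in input pixels) of a stack of conv/pool layers.

    `layers` is a list of (kernel_size, stride) pairs applied in order.
    """
    rf, jump = 1, 1  # jump = input-pixel distance between adjacent output positions
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf


print(receptive_field([(3, 1)] * 4))                    # 9: four 3x3 convs see a 9x9 window
print(receptive_field([(3, 1), (3, 1), (2, 2)] * 4))    # 76: four conv-conv-pool levels
```

After four such levels, one output position can see at most a 76×76 window, so a 256×256 image is never covered as a whole, and in practice the effective receptive field is typically smaller still. Self-attention, by contrast, relates every position to every other position within a single layer.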
What is a Transformer in ML?
The Transformer architecture was introduced by Vaswani et al. in 2017 for sequence modeling tasks such as machine translation. A key reason for its popularity is the self-attention mechanism, which extracts global features, in contrast to the local feature extraction performed by CNNs. This global view also makes it a useful tool for image segmentation.
Source: Attention Is All You Need
Self-Attention
Self-attention is an attention mechanism that relates different positions of a single sequence to compute a representation of that sequence. Each position is projected into a query, a key, and a value, and the mechanism maps the queries and the key-value pairs to an output using the Scaled Dot-Product Attention operation shown in the figure above.
Source: Attention Is All You Need
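In the formula above, $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, Q, K, and V are the query, key, and value matrices. The dot products are divided by $\sqrt{d_k}$, where $d_k$ is the dimension of the keys in each attention head, so that large dot products do not push the softmax into regions with very small gradients. With multi-head attention, $d_k$ is usually the hidden size divided by the number of attention heads $h$:

$$d_k = \frac{d_{\text{hidden}}}{h}$$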
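Scaled dot-product attention, softmax(QKᵀ / √d_k) V, translates almost directly into code. Here is a minimal NumPy sketch for a single head, with random Q, K, and V matrices standing in for learned projections (the sizes are arbitrary).

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n_q, n_k): similarity of every query with every key
    weights = softmax(scores, axis=-1)  # each row is a probability distribution over keys
    return weights @ V, weights


rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(6, 8))
V = rng.normal(size=(6, 16))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape, w.shape)  # (4, 16) (4, 6)
```

Each output row is a weighted average of the value vectors, with weights determined by how well the query matches each key.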
Source: Attention Is All You Need
Notably, the multi-head attention mechanism significantly increases the ability of self-attention to produce rich representations: the input is projected into several lower-dimensional subspaces (heads) that attend in parallel. According to the authors, it "allows the model to jointly attend to information from different representation subspaces at different positions".
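Multi-head attention can be sketched by splitting the hidden dimension into heads, attending within each head, and merging the results. The NumPy sketch below assumes ViT-Base-style sizes (hidden size 768 and 12 heads, so d_k = 64), uses random matrices in place of learned weights, and omits biases and dropout. Its split and merge reshapes mirror what `transpose_for_scores` does in the PyTorch code later in this section.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """Self-attention over X: (n, hidden) using num_heads heads of size hidden // num_heads."""
    n, hidden = X.shape
    head_dim = hidden // num_heads

    def split(t):
        # (n, hidden) -> (heads, n, head_dim)
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, n, n)
    weights = softmax(scores, axis=-1)
    context = weights @ V                                  # (heads, n, head_dim)
    # Merge heads: (heads, n, head_dim) -> (n, hidden), then apply the output projection.
    context = context.transpose(1, 0, 2).reshape(n, hidden)
    return context @ W_o, weights


hidden_size, num_heads = 768, 12  # d_k = 768 / 12 = 64
rng = np.random.default_rng(0)
X = rng.normal(size=(10, hidden_size))  # 10 tokens
W_q, W_k, W_v, W_o = (rng.normal(scale=hidden_size ** -0.5, size=(hidden_size, hidden_size)) for _ in range(4))
out, weights = multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads)
print(out.shape, weights.shape)  # (10, 768) (12, 10, 10)
```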
This operation gives the model a global representation of the input. In NLP, self-attention enables the model to learn long-range dependencies, i.e., relationships between tokens regardless of how far apart they are in the sequence, so that contextual information is preserved.
In computer vision, attention can help a model focus on the critical parts of an image, for example by separating foreground objects from the background, and base its output on those regions. It can also help the model retain essential features extracted in the early convolutional stages.
Source: Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
In this image captioning example, the model concentrates on the regions that receive high attention scores. The image above shows the model focusing on the two humans more than on the background and the frisbee. This illustrates how attention can pick out meaningful regions of the input.
Here is the code for the multi-head self-attention mechanism written in PyTorch:
```python
import math

import torch
import torch.nn as nn
from torch.nn import Dropout, Linear, Softmax


class Attention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, attention_dropout_rate):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)  # d_k per head
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections producing queries, keys, and values
        self.query = Linear(hidden_size, self.all_head_size)
        self.key = Linear(hidden_size, self.all_head_size)
        self.value = Linear(hidden_size, self.all_head_size)

        # Output projection maps the concatenated heads back to hidden_size
        self.out = Linear(self.all_head_size, hidden_size)
        self.attn_dropout = Dropout(attention_dropout_rate)
        self.proj_dropout = Dropout(attention_dropout_rate)

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # Multi-head view: (B, n, all_head_size) -> (B, heads, n, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # Project the input into queries, keys, and values
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # Split each projection into heads
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k))
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs  # kept for visualization
        attention_probs = self.attn_dropout(attention_probs)

        # Weighted sum of values, then merge heads: (B, heads, n, d_k) -> (B, n, all_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)

        return attention_output, weights
```
The multi-head self-attention module is combined with other modules, namely layer normalization (Norm), a feed-forward MLP, and residual connections. Together, they form an attention block, also called a transformer block. See the image below.
Source: Attention Is All You Need
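The block in the diagram above can be implemented as follows. Note that this code uses the pre-norm arrangement found in ViT, which applies LayerNorm before the attention and MLP sublayers rather than after the residual addition as in the original diagram.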
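```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Dropout, LayerNorm, Linear


class Mlp(nn.Module):
    # The original Mlp is not shown; this is a minimal sketch assuming two linear layers
    # with GELU and dropout, with std_norm as the std used to initialize the biases.
    def __init__(self, hidden_size, linear_dim, dropout_rate, std_norm):
        super().__init__()
        self.fc1 = Linear(hidden_size, linear_dim)
        self.fc2 = Linear(linear_dim, hidden_size)
        self.dropout = Dropout(dropout_rate)
        for fc in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(fc.weight)
            nn.init.normal_(fc.bias, std=std_norm)

    def forward(self, x):
        x = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(x))


class Block(nn.Module):
    def __init__(
        self,
        num_attention_heads,
        hidden_size,
        linear_dim,
        dropout_rate,
        attention_dropout_rate,
        eps,
        std_norm,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_norm = LayerNorm(hidden_size, eps=eps)
        self.ffn_norm = LayerNorm(hidden_size, eps=eps)
        self.ffn = Mlp(
            hidden_size=hidden_size,
            linear_dim=linear_dim,
            dropout_rate=dropout_rate,
            std_norm=std_norm,
        )
        self.attn = Attention(  # multi-head self-attention defined above
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_dropout_rate=attention_dropout_rate,
        )

    def forward(self, x):
        # Sublayer 1: x + Attention(LayerNorm(x))
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h
        # Sublayer 2: x + MLP(LayerNorm(x))
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights
```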
Encoder
In the Vision Transformer (ViT), the encoder consists of three major components:
- Embedding layer: It splits the image into fixed-size patches (for example, 16×16 pixels), flattens each patch, projects it linearly to the hidden size, and adds a position embedding. The resulting sequence of patch embeddings is fed into the multi-head self-attention blocks.
- Multi-head self-attention: This block extracts global representations from the given input.
- Multilayer perceptron: The MLP consists of two linear layers with a Gaussian Error Linear Unit (GELU) nonlinearity. Unlike self-attention, which is global, the MLP processes each patch token independently, so it is local and translationally equivariant.
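The patch-embedding step can be sketched in NumPy as follows, assuming a 224×224 RGB image and 16×16 patches (illustrative sizes). The image is cut into non-overlapping patches and each one is flattened into a vector; in ViT, a learned linear projection then maps each vector to the hidden size and a position embedding is added, steps omitted here.

```python
import numpy as np


def patchify(image, patch_size):
    """Split an (H, W, C) image into a sequence of flattened, non-overlapping patches.

    Returns an array of shape (N, patch_size * patch_size * C) with N = (H / P) * (W / P).
    """
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    patches = image.reshape(H // P, P, W // P, P, C)  # split rows and columns into blocks
    patches = patches.transpose(0, 2, 1, 3, 4)        # (H/P, W/P, P, P, C)
    return patches.reshape(-1, P * P * C)             # patches in row-major order


image = np.random.default_rng(0).random((224, 224, 3))
seq = patchify(image, patch_size=16)
print(seq.shape)  # (196, 768)
```

A 224×224 image with 16×16 patches yields 14 × 14 = 196 tokens of length 16 × 16 × 3 = 768. A Conv2d whose kernel_size and stride both equal the patch size, as used in the embedding code later in this article, performs the flatten-and-project step in a single operation.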
Source: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
The figure above represents a schematic diagram of the Vision Transformer. It is worth noting that the Vision Transformer uses only the encoder part of the entire Transformer architecture and not the decoder. A classification head extends the encoder block of the ViT.
To better understand the architecture, we will explore each module separately. We will explore the encoder module, the transformer module, and the entire ViT module:
Encoder module:
The core idea of the encoder module is to create a list of attention blocks, pass the input through each block in turn, and normalize the final output with LayerNorm.
```python
import copy

import torch.nn as nn
from torch.nn import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        num_layers,
        hidden_size,
        num_attention_heads,
        linear_dim,
        dropout_rate,
        attention_dropout_rate,
        eps,
        std_norm,
    ):
        super().__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(hidden_size, eps=eps)
        for _ in range(num_layers):
            layer = Block(  # attention block defined above
                num_attention_heads,
                hidden_size,
                linear_dim,
                dropout_rate,
                attention_dropout_rate,
                eps,
                std_norm,
            )
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:  # pass the input through each block in turn
            hidden_states, weights = layer_block(hidden_states)
            attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)  # final LayerNorm
        return encoded, attn_weights
```
Transformer module:
This module sequentially arranges the embedding module and the encoder module.
```python
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(
        self,
        img_size,
        hidden_size,
        in_channels,
        num_layers,
        num_attention_heads,
        linear_dim,
        dropout_rate,
        attention_dropout_rate,
        eps,
        std_norm,
    ):
        super().__init__()
        # ... (elided in the original: creates self.embeddings, the patch + position
        # embedding module, and self.encoder, the Encoder defined above)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)           # (B, n_patches, hidden)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights
```
Vision Transformer:
Like the transformer module, the ViT module combines just two modules: the transformer module and a linear layer that acts as a classification head. Because the head produces class logits, they are passed to a cross-entropy loss during training.
```python
import torch.nn as nn
from torch.nn import CrossEntropyLoss, Linear


class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size,
        num_classes,
        hidden_size,
        in_channels,
        num_layers,
        num_attention_heads,
        linear_dim,
        dropout_rate,
        attention_dropout_rate,
        eps,
        std_norm,
    ):
        super().__init__()
        self.classifier = "token"
        self.num_classes = num_classes

        self.transformer = Transformer(
            img_size,
            hidden_size,
            in_channels,
            num_layers,
            num_attention_heads,
            linear_dim,
            dropout_rate,
            attention_dropout_rate,
            eps,
            std_norm,
        )
        self.head = Linear(hidden_size, num_classes)  # classification head

    def forward(self, x, labels=None):
        x, attn_weights = self.transformer(x)
        # Classify from the first token (assumes the embeddings prepend a class token)
        logits = self.head(x[:, 0])

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        return logits, attn_weights
```
Hybridizing Encoder and UNet
So far, we have discussed CNN-based UNet and Attention-based Transformers. Further on, we will see how we can combine both to create a hybrid architecture.
We want to create a hybridized UNet because both CNN-based UNet and Transformers have advantages and disadvantages. Combining both can help us leverage both models’ advantages in a single model.
We have learned that CNN-based UNet has a significant limitation: its convolutional operations make it unable to explicitly model long-range dependencies. Transformers, on the other hand, can model long-range dependencies. So why not use transformers instead of CNNs?
Although transformers can model long-range dependencies, they are weaker at capturing fine, low-level local detail, which CNNs excel at. Combining both models therefore lets us leverage the power of both CNNs and the self-attention mechanism to get better results.
How can we combine them both?
In the paper "TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation," the authors proposed using a CNN to create the image patches, or image embeddings, that are fed into the transformer's encoder. The transformer encodes patches taken from a CNN feature map as its input sequence to extract global features. A decoder, which the paper calls a cascaded upsampler (CUP), then upsamples the encoded features using upsampling and convolution blocks. As in the original UNet, at each level the upsampled features are combined with the high-resolution CNN feature maps to achieve high-quality, precise localization. For a better understanding, look at the image below.
Source: TransUNet | Transformers Make Strong Encoders for Medical Image Segmentation
As you can see, the CNN layers extract local features and the transformer extracts global features. The local features are concatenated with the output of the upsampling block at each level to produce precise segmentation masks. To understand the whole architecture, we will break it into modules and explore what each one does.
Encoder:
We will explore the code that the authors released alongside their paper.
To start with, let us explore the CNN module, which is based on ResNet. The paper uses ResNet-50, implemented as the ResNetV2 module below. This network is very efficient at extracting features from the given input, and although it is a deep network, its skip connections mitigate the vanishing-gradient problem.
Here is the pseudo-code for the ResNet:
```python
import torch
import torch.nn as nn


class ResNetV2(nn.Module):
    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width
        self.root = ...  # conv layer, group norm, and activation (see the root block below)
        self.body = ...  # stages of pre-activation bottleneck units (see the body block below)

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)  # 1/2 of the input resolution
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i + 1))  # expected size: 1/4, then 1/8
            if x.size()[2] != right_size:
                # Zero-pad the map up to the expected size (pooling can lose a pixel or two)
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)       # last stage
        return x, features[::-1]   # deepest skip feature first
```
The ResNetV2 defined above consists of two essential modules:
Root block:
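The root block is the network's stem: a weight-standardized 7×7 convolution (StdConv2d) with stride 2, followed by group normalization and a ReLU. It turns the input image into the first feature map, at half the input resolution.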
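```python
from collections import OrderedDict

import torch.nn as nn

# Inside ResNetV2.__init__, where width = int(64 * width_factor).
# StdConv2d is a Conv2d whose weights are standardized before each convolution.
self.root = nn.Sequential(OrderedDict([
    ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
    ('gn', nn.GroupNorm(32, width, eps=1e-6)),
    ('relu', nn.ReLU(inplace=True)),
    # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))  # applied in forward() instead
]))
```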
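StdConv2d applies weight standardization: before each convolution, every output filter's weights are normalized to zero mean and unit variance. A minimal NumPy sketch of that normalization follows; the epsilon of 1e-5 and the exact axes are assumptions and may differ from the TransUNet implementation.

```python
import numpy as np


def standardize_weights(w, eps=1e-5):
    """Weight standardization for a conv kernel of shape (out_ch, in_ch, kH, kW).

    Each output filter is shifted to zero mean and scaled to unit variance over its
    (in_ch, kH, kW) entries before it would be used in the convolution.
    """
    mean = w.mean(axis=(1, 2, 3), keepdims=True)
    var = w.var(axis=(1, 2, 3), keepdims=True)
    return (w - mean) / np.sqrt(var + eps)


# A 7x7 stem kernel with 3 input and 64 output channels, deliberately off-center.
w = np.random.default_rng(0).normal(loc=0.3, scale=2.0, size=(64, 3, 7, 7))
ws = standardize_weights(w)
print(np.allclose(ws.mean(axis=(1, 2, 3)), 0.0), np.allclose(ws.std(axis=(1, 2, 3)), 1.0, atol=1e-3))  # True True
```

Standardizing the weights pairs well with group normalization, which, unlike batch normalization, does not depend on the batch size.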
Body block:
The body block is a sequence of stages built from pre-activation bottleneck units (PreActBottleneck), which extract the key information from the image. Each pre-activation bottleneck is itself a ResNet residual block.
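```python
# Inside ResNetV2.__init__: the first stage (block1) of the body
self.body = nn.Sequential(OrderedDict([
    ('block1', nn.Sequential(OrderedDict(
        [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
        [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
    ))),
    # block2 and block3 follow the same pattern, widening the channels up to width*16 (not shown here)
]))
```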
ResNetV2 is then integrated into the Embedding module, which turns its output feature map into patch embeddings and feeds them into the transformer. See the figure below.
Source: TransUNet | Transformers Make Strong Encoders for Medical Image Segmentation
Using ResNetV2 as the front end gives the transformer richer input features than raw image patches. This hybrid embedding module returns two outputs: the embedded features, which are passed into the transformer, and the intermediate ResNet features, which are passed to the upsampling blocks.
In the forward function of the ResNetV2 module, the output of the root block and of each body stage except the last is stored in a "features" list. Once all these outputs are collected, the list is reversed so that it runs from the deepest (lowest-resolution) feature map to the shallowest. This matches the order in which the decoder upsamples and concatenates these local CNN features as skip connections.
Source: TransUNet | Transformers Make Strong Encoders for Medical Image Segmentation
```python
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout
from torch.nn.modules.utils import _pair


class Embeddings(nn.Module):
    """Construct the embeddings from the patch, position embeddings."""
    def __init__(self, config, img_size, in_channels=3):
        super().__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:  # hybrid: patches come from the ResNet feature map
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:  # pure ViT: patches are cut directly from the image
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # A conv with kernel_size == stride == patch_size is a per-patch linear projection
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # Learned position embeddings, one per patch
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)  # features: CNN skip connections for the decoder
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)       # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
```
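The patch arithmetic in the hybrid branch is easier to follow with numbers. The sketch below reproduces those three lines in plain Python; the function name is hypothetical. The hard-coded 16 reflects that ResNetV2 reduces the resolution by a factor of 16 (the stride-2 root convolution, the stride-2 max-pooling, and two further halvings in the body), so each feature-map cell corresponds to a 16×16 region of the input. The 14×14 grid in the example is an assumed configuration value.

```python
def hybrid_patch_grid(img_size, grid):
    """Mirror the patch arithmetic of the hybrid (ResNet) branch in Embeddings.__init__."""
    patch_size = (img_size[0] // 16 // grid[0], img_size[1] // 16 // grid[1])
    patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
    n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
    return patch_size, patch_size_real, n_patches


# A 224x224 input with an assumed 14x14 grid:
print(hybrid_patch_grid((224, 224), (14, 14)))  # ((1, 1), (16, 16), 196)
```

With a 14×14 grid, the patch size on the feature map is 1×1, so every position of the 14×14 ResNet feature map becomes one of the 196 tokens.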
Now we will see how the two outputs of the embedding module, the embeddings and the features, are processed to obtain the segmentation masks. This is done in the Vision Transformer, but it has to be built in the form of a UNet: the existing transformer architecture (the encoder module) must be extended with two additional components, a decoder and a segmentation head.
Let’s build each component step by step for better understanding:
Transformer or encoder module: This module passes the input image through the embedding module and feeds only the embedded output to the encoder, i.e., the attention blocks followed by the multilayer perceptrons. The CNN features are returned unchanged.
```python
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super().__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        # TransUNet's config-driven Encoder; `vis` controls whether attention weights are kept
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features
```
Because the embedding now contains ResNetV2, the transformer yields three outputs instead of the two seen earlier: the encoded features, the attention weights, and the convolutional features.
TransUNet
Now comes the main component, the ViT, i.e., TransUNet. This module consists of three critical parts: the Transformer, the DecoderCup, and the segmentation head. We have already seen the transformer; here we will explore the DecoderCup and the segmentation head.
a) DecoderCup
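The DecoderCup (cascaded upsampler) upsamples the encoded features. It first reshapes the sequence of patch tokens back into a 2D feature map, then passes it through a series of decoder blocks, each of which upsamples the map and, where available, concatenates the corresponding CNN skip feature.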
```python
import numpy as np
import torch.nn as nn


class DecoderCup(nn.Module):
    # Conv2dReLU and DecoderBlock come from the TransUNet code and are not shown here.
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Copy so that zeroing entries below does not modify the shared config
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)  # upsample, concatenate skip (if any), convolve
        return x
```
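The first lines of DecoderCup.forward turn the token sequence back into an image-like tensor. Here is a NumPy sketch of the same permute-and-reshape with illustrative sizes (batch 2, 196 tokens, hidden size 768); like the original, it assumes the number of patches is a perfect square.

```python
import numpy as np

B, n_patch, hidden = 2, 196, 768  # illustrative sizes
hidden_states = np.random.default_rng(0).normal(size=(B, n_patch, hidden))

h = w = int(np.sqrt(n_patch))         # assumes a square patch grid (14 x 14 here)
x = hidden_states.transpose(0, 2, 1)  # (B, hidden, n_patch)
x = x.reshape(B, hidden, h, w)        # (B, hidden, h, w): a 2D feature map again

# Token k lands at row k // w, column k % w of the grid.
k = 20
print(np.array_equal(x[:, :, k // w, k % w], hidden_states[:, k, :]))  # True
```

Because tokens were produced in row-major order by flattening the patch grid, this reshape puts every token back at its original spatial position, which is what lets the convolutional decoder work on it.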
b) Segmentation head
This module subclasses nn.Sequential, so it needs no explicit forward() function: its layers simply run in order. It applies a convolution that maps the channels of the last DecoderCup layer to one output channel per class, producing the segmentation logits, optionally followed by bilinear upsampling.
```python
import torch.nn as nn


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        # Conv maps decoder channels to one logit map per class
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        # Optional bilinear upsampling; with the default upsampling=1 this is a no-op
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)
```
c) Vision Transformer
As usual, this module chains all the modules (transformer, decoder, and segmentation head) into a single pipeline.
```python
import torch.nn as nn


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super().__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)  # embeddings (with ResNetV2) + encoder
        self.decoder = DecoderCup(config)                      # cascaded upsampler
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],                  # one logit map per class
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)  # grayscale input -> 3 channels for the RGB ResNet stem
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)  # (B, n_classes, H, W)
        return logits
```
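The segmentation head outputs logits of shape (B, n_classes, H, W). One common way to turn them into a label mask, assuming each pixel belongs to exactly one class, is to take the argmax over the class dimension, as in this NumPy sketch with made-up logits. If the classes can overlap, a per-class sigmoid with a threshold is used instead.

```python
import numpy as np

# Illustrative logits for a batch of 1, 4 classes (e.g. background + 3 organs), 2x3 pixels.
logits = np.array([[
    [[2.0, 0.1, 0.0], [0.0, 0.0, 3.0]],  # class 0
    [[0.5, 2.5, 0.0], [0.0, 0.0, 0.0]],  # class 1
    [[0.0, 0.0, 1.5], [0.0, 0.0, 0.0]],  # class 2
    [[0.0, 0.0, 0.0], [4.0, 1.0, 0.0]],  # class 3
]])  # shape (B, n_classes, H, W) = (1, 4, 2, 3)

pred_mask = logits.argmax(axis=1)  # (B, H, W): index of the highest-scoring class per pixel
print(pred_mask.tolist())  # [[[0, 1, 2], [3, 3, 0]]]
```

Applying a softmax first would not change the result, since softmax preserves the ordering of the logits.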
What is UWMGI Dataset?
The UWMGI dataset comes from the UW-Madison GI Tract Image Segmentation competition on Kaggle. It contains gastrointestinal tract MRI scans from roughly 50 cases. The images in each case are 16-bit grayscale PNG files. Segmentation masks for three classes, the stomach, the small bowel, and the large bowel, are provided as run-length encoded (RLE) masks.
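RLE stores a mask as pairs of (start, run length) instead of individual pixels. The decoder below is a minimal NumPy sketch that assumes the common Kaggle convention of space-separated pairs with 1-indexed starts and row-major pixel order. The pixel order in particular is an assumption, so check the competition's data description before using it on the real masks. The function name is illustrative.

```python
import numpy as np


def rle_decode(rle, shape):
    """Decode a run-length-encoded string into a binary mask of the given (H, W) shape.

    Assumed convention: "start length start length ...", 1-indexed starts,
    pixels numbered in row-major order.
    """
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    if rle:  # an empty string means "no mask on this slice"
        values = np.array([int(v) for v in rle.split()])
        starts, lengths = values[0::2] - 1, values[1::2]
        for start, length in zip(starts, lengths):
            mask[start:start + length] = 1
    return mask.reshape(shape)


m = rle_decode("2 3 9 2", shape=(3, 4))
print(m.tolist())  # [[0, 1, 1, 1], [0, 0, 0, 0], [1, 1, 0, 0]]
```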
You can check out this notebook, which explains how to download the data from Kaggle, preprocess it into the required NumPy format, and upload it to Deep Lake.
Source: Deep Lake
How to download UWMGI Dataset?
To download the data, you can follow the steps given below:
- Install the Python package: pip3 install deeplake
- Visit https://app.activeloop.ai/
- Select the data: Search for the dataset you want to work with in the search bar. In our case, it is "UWMGI".
Once you find the dataset, copy the link provided.
Source: Deep Lake
- Access the data:
Now, let’s download the data from the Activeloop server by following the code below:
```python
import deeplake
import torch
from torchvision import transforms, models

ds = deeplake.load('hub://perceptronai/UWMGI| training')
```

```
hub://perceptronai/UWMGI-1 loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/perceptronai/UWMGI-1
```
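```python
# Image transform: convert to a tensor, then normalize each of the 3 channels (mean 0.5, std 0.5)
tform_i = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Mask transform: convert to a tensor only
tform_m = transforms.Compose([
    transforms.ToTensor()])

# A separate transform is applied to each tensor
deeplake_loader = ds.pytorch(batch_size=1, transform={'images': tform_i, 'masks': tform_m})
```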
ds.pytorch() returns a data loader that works like the PyTorch DataLoader, so you can iterate over it the same way to get the images and masks. For instance:
```python
imgs, msks = next(iter(deeplake_loader))
```
Deep Lake can also visualize the dataset directly: run ds.visualize() in a Jupyter notebook.
Once you have downloaded the data, the following practices can help you understand and explore it.
How to visualize the UWMGI Dataset?
Once you have downloaded the data, you can explore it by running this script: print(ds.tensors.keys())
dict_keys(['images', 'labels', 'masks'])
You can use the above keys to get the required output according to the task. Since we are working with segmentation, we want to explore ‘images’ and ‘masks’. We can then use these keys to visualize the image.
```python
import matplotlib.pyplot as plt

plt.imshow(ds.tensors["images"][1].numpy())
```
Similarly, we can visualize the masks using the same command:
```python
plt.imshow(ds.tensors["masks"][0].numpy()[:, :, 0])
```
The final image after overlaying the segmentation mask on its image.
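To produce such an overlay yourself, you can alpha-blend a colored mask onto the grayscale slice. The NumPy-only sketch below is one way to do it; the function name `overlay_mask`, the red color, and `alpha=0.4` are illustrative choices, and a 16-bit slice should first be scaled to [0, 1], for example by dividing by its maximum.

```python
import numpy as np


def overlay_mask(gray, mask, color=(1.0, 0.0, 0.0), alpha=0.4):
    """Blend a binary mask onto a grayscale image.

    gray: (H, W) floats in [0, 1]; mask: (H, W) bool or 0/1; returns (H, W, 3) RGB in [0, 1].
    """
    rgb = np.repeat(gray[..., None], 3, axis=-1).astype(float)  # grayscale -> RGB copy
    m = mask.astype(bool)
    # Only masked pixels are tinted; all others keep their gray value.
    rgb[m] = (1 - alpha) * rgb[m] + alpha * np.asarray(color)
    return rgb


gray = np.full((2, 2), 0.5)
mask = np.array([[1, 0], [0, 0]])
out = overlay_mask(gray, mask)
print(out[0, 0], out[1, 1])  # masked pixel tinted red; unmasked pixel unchanged
```

The resulting array can be displayed with plt.imshow(out); calling the function once per class with a different color gives a multi-organ overlay.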
Querying the Data
Data querying is an essential tool for ML practitioners because it lets them select only the data they need to train a model. For instance, the COCO dataset defines 91 classes, of which only 80 have segmentation masks. If we are building a DL model that only needs to recognize humans, we can keep just the images that contain classes pertaining to humans.
Deep Lake lets users filter data in two ways: through the UI and through the Python API. Let's explore both in detail.
How to query datasets using Deep Lake UI?
To filter data using Activeloop’s UI, you only need to click on the “Run query” button, as shown in the image below.
Source: Deep Lake
Clicking the button opens an SQL query editor just below it. You can browse sample queries by clicking the “Example Queries” button in the editor.
We will use the second example query to keep only the images that contain humans. Copy the query into the editor and press Shift+Enter to run it.
Source: Deep Lake
You can then use the “Save query result” button to save the query.
Source: Deep Lake
After saving the query result, you can click on “Query history” to get the saved query.
Source: Deep Lake
You can then use the saved query's ID from Python to load the filtered dataset and train your DL model on it.
How to query datasets using Python?
Querying data from Python is just as simple. Deep Lake supports filtering with user-defined functions, so you can filter data by whatever criterion you need. The steps are as follows.
a) Load the data.
```python
import deeplake

ds = deeplake.load('hub://perceptronai/UWMGI| training')
```
b) Create a list of the labels you want to keep.
```python
labels_list = ['stomach', 'large bowel']
```
c) Create a function for filtering data using the deeplake.compute decorator.
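```python
# For this walkthrough we keep only the 'stomach' class, overriding the list from step b).
labels_list = ['stomach']

@deeplake.compute
def filter_labels(sample_in, labels_list):
    # Keep the sample if its first label is one of the requested labels
    return sample_in.labels.data()['text'][0] in labels_list
```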
d) Filter the data by passing the filter_labels function to ds.filter:
```python
ds_view = ds.filter(filter_labels(labels_list))
```
Another way to filter data is the ds.query function, which takes an SQL-like query string:
```python
print(ds.query("SELECT * WHERE CONTAINS(labels, 'stomach')"))
```
Dataset(path='hub://perceptronai/UWMGI| training', index=Index([(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99)]), tensors=['images', 'labels', 'masks'])
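Note that filter_labels looks only at the first label of each sample (data()['text'][0]), so a sample whose first label is something else is dropped even if 'stomach' appears later among its labels. Whether this matters depends on how labels are stored in your dataset. The plain-Python sketch below contrasts first-label matching with any-label matching; keep_sample is a hypothetical helper and the label lists are made up, so it operates on ordinary lists rather than Deep Lake samples. The same any(...) idea could presumably be applied inside filter_labels over the full list of label texts.

```python
def keep_sample(sample_labels, labels_list, match_any=False):
    """Hypothetical stand-in for filter_labels that works on a plain list of label names.

    match_any=False mirrors filter_labels: only the first label is checked.
    match_any=True keeps the sample if any of its labels is requested.
    """
    if not sample_labels:
        return False
    if match_any:
        return any(label in labels_list for label in sample_labels)
    return sample_labels[0] in labels_list


# Made-up label lists for three samples
samples = [["stomach"], ["large bowel", "stomach"], ["small bowel"]]
first_only = [s for s in samples if keep_sample(s, ["stomach"])]
any_match = [s for s in samples if keep_sample(s, ["stomach"], match_any=True)]
print(len(first_only), len(any_match))  # 1 2
```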
Once the data is filtered, you can check how many samples it contains and view one of its images.
```python
print(len(ds_view))
```
100
```python
from PIL import Image

Image.fromarray(ds_view.images[1].numpy())
```
The filtered view can then be used to train your DL model.
Training and Testing Loop
Once the dataset is downloaded, we can define the model. Rather than training from scratch, we will fine-tune TransUNet starting from pre-trained weights. The authors' repository provides the model code along with instructions for downloading the pre-trained ViT checkpoints (such as R50+ViT-B_16.npz), so we start by cloning it. You can find the complete training code here.
```
!git clone https://github.com/Beckschen/TransUNet
```
Once we have cloned the repository, we can then import the model.
```python
from TransUNet.networks.vit_seg_modeling import VisionTransformer as ViT_seg
from TransUNet.networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
```
We then configure the model according to our requirements. The wrapper class below starts from the configuration preset for the chosen model, overrides the settings we need, builds TransUNet, and loads the pre-trained weights from a local .npz checkpoint, which must be downloaded beforehand as described in the repository's README. The following code shows how to configure and initialize the model:
```python
import numpy as np
import torch.nn as nn

# CFG is the configuration object (model name, image size, learning rate, ...) defined earlier.
class TransUnet(nn.Module):
    def __init__(self,
                 model_name=CFG.MODEL_NAME,
                 pretrain_path='./R50+ViT-B_16.npz',
                 n_classes=3,
                 n_skip=3,
                 dropout_rate=0.2,
                 mlp_dim=3072,
                 num_heads=12,
                 num_layers=8,
                 img_size=CFG.img_size[0]):
        super().__init__()

        # Configuration: start from the named preset and override the fields we need
        config_vit = CONFIGS_ViT_seg[model_name]
        config_vit.pretrained_path = pretrain_path
        config_vit.n_classes = n_classes
        config_vit.n_skip = n_skip
        config_vit.transformer.dropout_rate = dropout_rate
        config_vit.transformer.mlp_dim = mlp_dim
        config_vit.transformer.num_heads = num_heads
        config_vit.transformer.num_layers = num_layers
        # Hybrid R50 models: the patch grid must match the input size (the ResNet
        # backbone downsamples by 16), as done in the repository's train.py
        if 'R50' in model_name:
            config_vit.patches.grid = (img_size // 16, img_size // 16)

        # Model: build TransUNet and load the pre-trained weights from the .npz file
        self.model = ViT_seg(config_vit, img_size=img_size, num_classes=n_classes)
        self.model.load_from(weights=np.load(config_vit.pretrained_path))

    def forward(self, x):
        # Per-pixel class logits of shape (N, n_classes, H, W)
        return self.model(x)

model = TransUnet()
```
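The img_size argument matters because, in the hybrid R50+ViT-B_16 encoder, the ResNet backbone reduces the input resolution by a factor of 16 before the transformer, so the transformer sees an (img_size / 16) × (img_size / 16) grid of tokens. The hypothetical helper below only illustrates that arithmetic; it is not part of the TransUNet code.

```python
def transunet_token_grid(img_size, patch_size=16):
    """Token grid and token count for a hybrid R50+ViT-B_16 encoder (hypothetical helper)."""
    if img_size % patch_size:
        raise ValueError("img_size must be divisible by the patch size")
    side = img_size // patch_size
    return (side, side), side * side


print(transunet_token_grid(224))  # ((14, 14), 196)
print(transunet_token_grid(512))  # ((32, 32), 1024)
```

Because self-attention compares every token with every other token, doubling the image side length quadruples the number of tokens, which is worth keeping in mind when choosing CFG.img_size.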
Let us define the other essential components: the loss function and the optimizer. If your task requires predicting binary segmentation masks, you can use losses such as LovaszSoftmax, Hausdorff, FocalLoss, DiceLoss, or DiceBCELoss. Implementations of many of these are available in the following Git repository:
```
!git clone https://github.com/JunMa11/SegLoss.git
```
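For reference, the soft Dice loss mentioned above is 1 - 2·Σ(p·t) / (Σp + Σt), where p are predicted foreground probabilities and t is the binary ground-truth mask. The minimal NumPy sketch below implements that formula for a single binary mask; it is an illustration with a small smoothing term eps added as an assumption, not the SegLoss implementation.

```python
import numpy as np


def soft_dice_loss(probs, target, eps=1e-6):
    """Soft Dice loss for one binary mask: 1 - 2*sum(p*t) / (sum(p) + sum(t)).

    probs: predicted foreground probabilities in [0, 1]; target: 0/1 ground-truth mask.
    eps avoids division by zero when both masks are empty.
    """
    probs = probs.astype(np.float64).ravel()
    target = target.astype(np.float64).ravel()
    intersection = (probs * target).sum()
    dice = (2.0 * intersection + eps) / (probs.sum() + target.sum() + eps)
    return 1.0 - dice


target = np.array([[0, 1], [1, 1]])
print(soft_dice_loss(target, target))                      # perfect prediction: 0.0
print(soft_dice_loss(np.array([[0, 1], [0, 0]]), target))  # 1 of 3 pixels found: about 0.5
```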
But since our task is multi-class segmentation, where each pixel is assigned exactly one class, we will use CrossEntropyLoss.
```python
import torch.nn as nn

def Loss():
    # Only CrossEntropyLoss ships with PyTorch; the other criteria must be
    # imported or defined separately (for example, from the SegLoss repository).
    if CFG.criterion == 'Multiclass':
        criterion = nn.CrossEntropyLoss()
    elif CFG.criterion == 'DiceBCELoss':
        criterion = DiceBCELoss()
    elif CFG.criterion == 'DiceLoss':
        criterion = DiceLoss()
    elif CFG.criterion == 'FocalLoss':
        criterion = FocalLoss()
    elif CFG.criterion == 'Hausdorff':
        criterion = Hausdorff_loss()
    elif CFG.criterion == 'Lovasz':
        criterion = Lovasz_loss()
    else:
        raise ValueError(f"Unknown criterion: {CFG.criterion}")
    return criterion
```
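With its default settings (reduction='mean'), nn.CrossEntropyLoss applied to logits of shape (N, C, H, W) and integer class targets of shape (N, H, W) averages, over all pixels, the negative log-softmax probability of each pixel's true class. The NumPy sketch below spells that out for a single image; pixel_cross_entropy is a hypothetical helper written for illustration, and it also shows why the target must hold one class index per pixel.

```python
import numpy as np


def pixel_cross_entropy(logits, target):
    """Mean per-pixel cross-entropy for logits (C, H, W) and a class-index target (H, W)."""
    # Log-softmax over the class axis, shifted by the max for numerical stability
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    h, w = target.shape
    rows, cols = np.indices((h, w))
    # Pick the log-probability of the true class at every pixel
    picked = log_probs[target, rows, cols]
    return -picked.mean()


logits = np.zeros((3, 2, 2))           # uniform scores over 3 classes
target = np.array([[0, 1], [2, 0]])
print(pixel_cross_entropy(logits, target))  # log(3), about 1.0986
```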
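We will use the Adam optimizer with a learning rate of 2e-3 (set through CFG.lr). Because training runs in mixed precision, we also create a GradScaler, which scales the loss before backpropagation so that small float16 gradients do not underflow to zero.
```python
import torch
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=CFG.lr)
scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed-precision training
```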
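Why loss scaling helps: the smallest positive float16 value is about 6e-8 (2^-24), so a gradient such as 1e-8 rounds to zero. GradScaler multiplies the loss, and therefore every gradient, by a large factor (2^16 by default) before backward() and divides it back out before the optimizer step. The NumPy sketch below reproduces that arithmetic on a single made-up gradient value; it illustrates the idea only, whereas GradScaler itself also skips steps with inf/NaN gradients and adapts the scale over time.

```python
import numpy as np

grad = 1e-8          # a small gradient value (made up for illustration)
scale = 2.0 ** 16    # GradScaler's default initial scale factor

unscaled_fp16 = np.float16(grad)              # underflows to 0 in float16
scaled_fp16 = np.float16(grad * scale)        # representable once scaled
recovered = np.float32(scaled_fp16) / scale   # unscale in float32 before the update

print(unscaled_fp16, recovered)
```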
Now we can write our training loop and start the finetuning process.
```python
import torch
from tqdm import tqdm

# Uses the global optimizer, scaler, Loss() and CFG defined above;
# train_loader is the training data loader built earlier.
def train_engine(model, train_loader, device=CFG.device, num_epochs=20):
    epoch_loss = []  # mean loss of each epoch
    loss_list = []   # loss of every batch
    acc = []         # running pixel accuracy (%) after every batch
    num_correct = 0
    num_pixels = 0
    criterion = Loss()  # build the loss once instead of on every batch

    model = model.to(device)
    model.train()
    for epoch in range(num_epochs):
        batch_losses = []
        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc="training")
        for batch_idx, (X, y) in loop:
            X = X.to(device)
            y = y.to(device)

            # forward pass in mixed precision
            with torch.cuda.amp.autocast():
                logits = model(X)
                loss = criterion(logits, y)

            # backward pass with gradient scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # running pixel accuracy (softmax is unnecessary before argmax)
            with torch.no_grad():
                preds = torch.argmax(logits, dim=1)
                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
            accuracy = num_correct / num_pixels * 100

            loss_list.append(loss.item())
            batch_losses.append(loss.item())
            acc.append(accuracy)
            # update tqdm loop
            loop.set_postfix(loss=loss.item(), accuracy=accuracy)
        epoch_loss.append(sum(batch_losses) / len(batch_losses))
    return loss_list, epoch_loss, acc

loss, epoch_loss, accuracy = train_engine(model, train_loader)
```
training: 79%|███████▊ | 23606/30000 [58:13<15:46, 6.76it/s, accuracy=67.8, loss=0.23]
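The accuracy reported by train_engine is pixel accuracy: the percentage of pixels whose argmax class equals the target class. The NumPy sketch below computes the same quantity on made-up data with a hypothetical pixel_accuracy helper, and shows why the number should be read with care: when most pixels are background, a model that predicts only background already scores highly.

```python
import numpy as np


def pixel_accuracy(logits, target):
    """Percentage of pixels whose argmax class matches the target, as in train_engine."""
    preds = logits.argmax(axis=1)  # (N, C, H, W) -> (N, H, W)
    return (preds == target).sum() / target.size * 100


target = np.zeros((1, 4, 4), dtype=int)  # mostly background ...
target[0, 0, 0] = 1                      # ... with a single organ pixel
logits = np.zeros((1, 2, 4, 4))
logits[:, 0] = 1.0                       # a model that always predicts background
print(pixel_accuracy(logits, target))    # 93.75
```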
The result after the first epoch:
One advantage of starting from a pre-trained model is that it typically learns faster than a model trained from scratch. As the image above shows, after the first epoch the model is already starting to pick out the target regions.
With longer training, the predicted segmentation masks should become more accurate. The final results are shown in the next section.
Key Takeaways
Here are some of the key takeaways from this article:
- Pre-trained TransUNet can be used in radiology machine learning projects.
- When it was introduced, TransUNet reported state-of-the-art results on medical image segmentation benchmarks such as Synapse multi-organ CT segmentation, and it is well suited to picking out small, hard-to-find structures.
- Because the network combines a CNN, which extracts local features, with self-attention, which models long-range dependencies across the whole image, it can capture context that a purely convolutional encoder would struggle to reach.
- As seen below, the network predicts accurate segmentation masks even for images that aren't clear, suggesting that TransUNet is robust.
- Deep Lake provides an efficient way to load, query, visualize, and stream the data for training and testing purposes.
- The queried dataset can be saved and later materialized, which enables reproducibility.
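The point about self-attention can be made concrete with a minimal single-head scaled dot-product self-attention in NumPy: every output token is a softmax-weighted mix of all input tokens, which is what lets the transformer relate distant regions of an image. This is a simplified sketch with random projection matrices and made-up token embeddings; TransUNet itself uses multi-head attention with learned weights inside a full ViT encoder.

```python
import numpy as np


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over tokens x of shape (T, D)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])         # (T, T): every token scores every other token
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v, weights


rng = np.random.default_rng(0)
tokens = rng.normal(size=(196, 64))  # e.g. a 14 x 14 grid of patch embeddings
w_q, w_k, w_v = (rng.normal(size=(64, 64)) / 8 for _ in range(3))
out, attn = self_attention(tokens, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (196, 64) (196, 196)
```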
Concluding remarks: UNet vs TransUNet
This article showed how to build and fine-tune a deep learning model for machine learning in radiology. UNet remains one of the primary architectures for image segmentation in every field, and in radiology its accuracy can directly affect patients, which makes accuracy an essential metric when developing UNet models for medical use. Lee et al. introduced a UNet variant that predicts affinities between nearest-neighbor pixels, that is, whether two adjacent pixels belong to the same object. A segmentation mask is then derived from these affinities by greedily agglomerating regions based on their mean affinity. On the SNEMI3D connectomics challenge, this approach surpassed human-level accuracy, showing how far UNet-based models can be pushed.
Likewise, methods such as error detection and error correction, both important in image segmentation, may further improve a model's accuracy.
TransUNet could leverage both of these techniques. We have already seen how TransUNet uses multi-head self-attention to capture long-range dependencies. Similarly, TransUNet could be extended to predict neighbor affinities, cluster similar pixels, and build segmentation masks from those clusters. Combined with error detection, this could further increase the model's accuracy.
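The affinity-based approach described above can be illustrated with a small 2D NumPy sketch. It derives nearest-neighbor affinities from a label map (the kind of target an affinity network is trained to predict) and then repeatedly merges the pair of adjacent segments with the highest mean affinity until no pair exceeds a threshold. This is a simplified illustration under stated assumptions, not Lee et al.'s implementation, which works on 3D volumes with a more elaborate pipeline; in practice the affinities would come from the network's predictions, and the threshold of 0.5 is an assumption.

```python
import numpy as np


def nearest_neighbor_affinities(labels):
    """Ground-truth affinities: 1.0 where a pixel and its lower/right neighbor share a label."""
    aff_y = (labels[1:, :] == labels[:-1, :]).astype(float)  # shape (H-1, W)
    aff_x = (labels[:, 1:] == labels[:, :-1]).astype(float)  # shape (H, W-1)
    return aff_y, aff_x


def mean_affinity_agglomeration(aff_y, aff_x, threshold=0.5):
    """Greedily merge adjacent segments with the highest mean affinity above threshold."""
    h, w = aff_y.shape[0] + 1, aff_x.shape[1] + 1
    # Edges between neighboring pixels (flattened indices) with their affinity
    edges = [(i * w + j, (i + 1) * w + j, aff_y[i, j]) for i in range(h - 1) for j in range(w)]
    edges += [(i * w + j, i * w + j + 1, aff_x[i, j]) for i in range(h) for j in range(w - 1)]
    seg = list(range(h * w))  # start with one segment per pixel
    while True:
        totals = {}  # (segment_a, segment_b) -> [sum of affinities, number of edges]
        for a, b, aff in edges:
            sa, sb = seg[a], seg[b]
            if sa != sb:
                entry = totals.setdefault((min(sa, sb), max(sa, sb)), [0.0, 0])
                entry[0] += aff
                entry[1] += 1
        if not totals:
            break
        (keep, drop), (total, count) = max(totals.items(), key=lambda kv: kv[1][0] / kv[1][1])
        if total / count <= threshold:
            break
        seg = [keep if s == drop else s for s in seg]
    # Relabel segments as 0, 1, 2, ... in order of their smallest pixel index
    _, relabeled = np.unique(seg, return_inverse=True)
    return relabeled.reshape(h, w)


labels = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [2, 2, 1, 1]])
aff_y, aff_x = nearest_neighbor_affinities(labels)
segments = mean_affinity_agglomeration(aff_y, aff_x)
print(np.array_equal(segments, labels))  # True
```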
References
- Deep Lake: a Lakehouse for Deep Learning
- U-Net: Convolutional Networks for Biomedical Image Segmentation
- Attention Is All You Need
- An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
- TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
- https://github.com/Beckschen/TransUNet
- https://github.com/JunMa11/SegLoss.git : Loss functions for image segmentation
- https://www.kaggle.com/stainsby/fast-tested-rle
- https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/data
- https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch
- Activeloop
- Superhuman Accuracy on the SNEMI3D Connectomics Challenge
- An Error Detection and Correction Framework for Connectomics
- Show, Attend and Tell: Neural Image Caption Generation with Visual Attention