Skin-Lesion Segmentation using a Boundary-Aware Segmentation Network and Classification based on a Mixture of Convolutional and Transformer Neural Networks

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.manifold import TSNE
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split

# Hyperparameters
positional_emb = True
conv_layers = 2
projection_dim = 128
num_heads = 2
transformer_units = [projection_dim, projection_dim]
transformer_layers = 2
stochastic_depth_rate = 0.1
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 100
image_size = 32
num_classes = 3
input_shape = (image_size, image_size, 3)

# Load and preprocess the skin-lesion dataset: one sub-folder per class, named by its integer label
dataset_folder_path = 'input/'

def load_and_process_images(folder_path):
    images, labels = [], []
    for class_label in os.listdir(folder_path):
        class_path = os.path.join(folder_path, class_label)
        if os.path.isdir(class_path):
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                img = load_img(image_path, target_size=input_shape[:2])
                img_array = img_to_array(img)
                images.append(img_array)
                labels.append(int(class_label))  # the folder name doubles as the integer class label
    return np.array(images), np.array(labels)

x_data, y_data = load_and_process_images(dataset_folder_path)
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.4, random_state=42)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

# Define CCT Tokenizer class
class CCTTokenizer(layers.Layer):
    def __init__(self, kernel_size=3, stride=1, padding=1, pooling_kernel_size=3, pooling_stride=2,
                 num_conv_layers=conv_layers, num_output_channels=[64, 128], positional_emb=positional_emb, **kwargs):
        super().__init__(**kwargs)
        self.conv_model = keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv2D(num_output_channels[i], kernel_size, stride, padding="valid", use_bias=False,
                              activation="relu", kernel_initializer="he_normal")
            )
            self.conv_model.add(layers.ZeroPadding2D(padding))
            self.conv_model.add(layers.MaxPooling2D(pooling_kernel_size, pooling_stride, "same"))

    def call(self, images):
        outputs = self.conv_model(images)
        # Flatten the spatial grid into a token sequence: (batch, height * width, channels)
        reshaped = tf.reshape(outputs, (-1, tf.shape(outputs)[1] * tf.shape(outputs)[2], tf.shape(outputs)[-1]))
        return reshaped
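
# Optional sanity check: two conv blocks, each followed by stride-2 pooling, reduce a
# 32x32 input to an 8x8 grid, so the tokenizer should emit 64 tokens of dimension 128.
print(CCTTokenizer()(tf.zeros((1, *input_shape))).shape)  # expected: (1, 64, 128)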

# Define PositionEmbedding class
class PositionEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, initializer="glorot_uniform", **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings", shape=[self.sequence_length, feature_size],
            initializer=self.initializer, trainable=True
        )

    def call(self, inputs, start_index=0):
        position_embeddings = tf.convert_to_tensor(self.position_embeddings)
        return tf.math.add(inputs, position_embeddings)
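
# Optional shape check: the layer adds a learned (sequence_length, dim) table to its
# input, so the output shape matches the input shape exactly.
print(PositionEmbedding(sequence_length=8)(tf.zeros((2, 8, 16))).shape)  # (2, 8, 16)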

# Define SequencePooling class
class SequencePooling(layers.Layer):
    def __init__(self):
        super().__init__()
        self.attention = layers.Dense(1)

    def call(self, x):
        # Score each token, softmax over the sequence, then take the weighted average
        attention_weights = tf.nn.softmax(self.attention(x), axis=1)
        attention_weights = tf.transpose(attention_weights, perm=(0, 2, 1))
        weighted_representation = tf.matmul(attention_weights, x)
        return tf.squeeze(weighted_representation, axis=-2)
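
# Optional shape check: pooling a batch of 2 sequences (5 tokens, dim 4) yields one
# attention-weighted average vector per sequence.
print(SequencePooling()(tf.random.normal((2, 5, 4))).shape)  # expected: (2, 4)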

# Define StochasticDepth class
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            # Drop the whole residual branch per sample (stochastic depth), not per element
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
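
# Optional behavior check: in training mode each sample's residual branch is either
# dropped entirely (all zeros) or kept and rescaled by 1 / (1 - drop_prob).
demo_out = StochasticDepth(drop_prob=0.5)(tf.ones((4, 2, 2)), training=True)
print(demo_out[:, 0, 0])  # each entry is 0.0 or 2.0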

# Define MLP function
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

# Create CCT Model
def create_cct_model():
    inputs = layers.Input(input_shape)
    rescaled = layers.Rescaling(scale=1.0 / 255)(inputs)  # pixel scaling only; no augmentation here
    cct_tokenizer = CCTTokenizer()
    encoded_patches = cct_tokenizer(rescaled)

    if positional_emb:
        sequence_length = encoded_patches.shape[1]
        # PositionEmbedding returns inputs + embeddings, so assign rather than add again
        encoded_patches = PositionEmbedding(sequence_length=sequence_length)(encoded_patches)

    dpr = list(np.linspace(0, stochastic_depth_rate, transformer_layers))  # per-block drop rates
    for i in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)
        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    weighted_representation = SequencePooling()(representation)
    logits = layers.Dense(num_classes)(weighted_representation)
    return keras.Model(inputs=inputs, outputs=logits)

# Run Experiment
def run_experiment(model):
    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
        metrics=[keras.metrics.CategoricalAccuracy(name="accuracy"),
                 # top-5 accuracy is trivially 100% with only 3 classes, so report top-2
                 keras.metrics.TopKCategoricalAccuracy(2, name="top-2-accuracy")]
    )

    checkpoint_callback = keras.callbacks.ModelCheckpoint("/tmp/checkpoint.weights.h5", monitor="val_accuracy",
                                                          save_best_only=True, save_weights_only=True)

    history = model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs,
                        validation_split=0.1, callbacks=[checkpoint_callback])

    model.load_weights("/tmp/checkpoint.weights.h5")
    _, accuracy, top_2_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top-2 accuracy: {round(top_2_accuracy * 100, 2)}%")

    return history

cct_model = create_cct_model()
history = run_experiment(cct_model)
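
# Visualize training and the learned representation: plot the loss curves recorded by
# model.fit, then embed the test-set features with t-SNE (this is what the matplotlib
# and TSNE imports above are for). The feature extractor below assumes SequencePooling
# is the penultimate layer, as defined in create_cct_model.
plt.plot(history.history["loss"], label="train loss")
plt.plot(history.history["val_loss"], label="val loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

feature_extractor = keras.Model(cct_model.input, cct_model.layers[-2].output)
features = feature_extractor.predict(x_test)
embedded = TSNE(n_components=2, random_state=42).fit_transform(features)
plt.scatter(embedded[:, 0], embedded[:, 1], c=np.argmax(y_test, axis=1), s=5)
plt.title("t-SNE of pooled CCT features (test set)")
plt.show()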
