Autoencoder with Squeeze-and-Excitation Blocks and SWIN Transformer for Image Classification
This blog post presents a hybrid deep learning architecture that combines a convolutional autoencoder equipped with Squeeze-and-Excitation (SE) blocks and a pre-trained SWIN Transformer feature extractor. The features learned by the two branches are concatenated and fed to a softmax classifier, with the goal of improving classification performance on medical or natural image datasets.
📦 Data Preprocessing
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 16
size = 224   # images are resized to 224x224 to match the SWIN input size
epoch = 50

# The autoencoder learns to reconstruct its input, so class_mode="input"
# makes the generators yield (image, image) pairs.
train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train,
    x_col='path',
    y_col='class_',
    target_size=(size, size),
    batch_size=batch_size,
    class_mode="input"
)

valid_datagen = ImageDataGenerator(rescale=1./255)
valid_generator = valid_datagen.flow_from_dataframe(
    dataframe=valid,
    x_col='path',
    y_col='class_',
    target_size=(size, size),
    batch_size=batch_size,
    class_mode="input"
)
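The `train` and `valid` DataFrames used above are assumed to hold one row per image, with the file path in `path` and the label in `class_`. A minimal sketch of how such frames might be built (the `data_dir` layout here is an assumption, not part of the original notebook):

import os
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical layout: data_dir/<class_name>/<image files>
data_dir = '/kaggle/input/dataset'
rows = [
    {'path': os.path.join(data_dir, cls, fname), 'class_': cls}
    for cls in os.listdir(data_dir)
    for fname in os.listdir(os.path.join(data_dir, cls))
]
df = pd.DataFrame(rows)

# 70/15/15 split into train, validation, and test
train, rest = train_test_split(df, test_size=0.3, stratify=df['class_'], random_state=42)
valid, test = train_test_split(rest, test_size=0.5, stratify=rest['class_'], random_state=42)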
🧠 Building the Autoencoder with SE Blocks
def se_block_enc(inputs, alpha):
    # Squeeze-and-Excitation block; alpha is the width of the bottleneck layer.
    input_channels = inputs.shape[-1]
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)   # squeeze: one value per channel
    x = tf.keras.layers.Dense(units=alpha, activation="relu")(x)
    x = tf.keras.layers.Dense(units=input_channels, activation="sigmoid")(x)  # per-channel weights in (0, 1)
    x = tf.reshape(x, [-1, 1, 1, input_channels])
    return inputs * x   # excitation: rescale each feature map by its learned weight
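A quick sanity check (added here, not in the original post) confirming that the block preserves the shape of its input:

probe = tf.keras.Input(shape=(56, 56, 48))
out = se_block_enc(probe, 20)
print(out.shape)   # expected: (None, 56, 56, 48), same as the input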
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D,
                                     Dense, BatchNormalization,
                                     GlobalAveragePooling2D, concatenate)
from tensorflow.keras.models import Model

# ENCODER: 224x224x3 -> 28x28x32, with an SE block after each conv stage
input_img = Input(shape=(size, size, 3))
x = Conv2D(48, (3, 3), activation='relu', padding='same')(input_img)
x = se_block_enc(x, 20)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
x = se_block_enc(x, 30)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = se_block_enc(x, 50)
x = MaxPooling2D((2, 2), padding='same')(x)
encoded = Conv2D(32, (1, 1), activation='relu', padding='same')(x)
# Bottleneck: three 2x2 poolings reduce 224 -> 112 -> 56 -> 28
latentSize = (28, 28, 32)
# DECODER: mirror the encoder, upsampling 28 -> 56 -> 112 -> 224
direct_input = Input(shape=latentSize)
x = Conv2D(192, (1, 1), activation='relu', padding='same')(direct_input)
x = UpSampling2D((2, 2))(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(48, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)  # sigmoid keeps pixels in [0, 1]
# MODELS
encoder = Model(input_img, encoded)
decoder = Model(direct_input, decoded)
autoencoder = Model(input_img, decoder(encoded))
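Before training, it is worth confirming that the encoder's output shape actually matches `latentSize` (a check added here, not in the original post):

print(encoder.output_shape)                    # expected: (None, 28, 28, 32)
assert encoder.output_shape[1:] == latentSize  # decoder input must match encoder output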
🧪 Training the Autoencoder
from tensorflow.keras.callbacks import ModelCheckpoint

autoencoder.compile(optimizer=tf.keras.optimizers.Adamax(), loss='binary_crossentropy')

# Keep only the best model (lowest validation loss) seen during training
model_checkpoint_callback = ModelCheckpoint(
    filepath='/kaggle/working/autoencoder_checkpoint.h5',
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1
)

history = autoencoder.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=epoch,
    verbose=2,
    callbacks=[model_checkpoint_callback]
)
📉 Visualizing Training Loss
import matplotlib.pyplot as plt

epochs = list(range(len(history.history['loss'])))
train_loss = history.history['loss']
val_loss = history.history['val_loss']

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(epochs, train_loss, 'r-o', linewidth=2, label='Training Loss')
ax.plot(epochs, val_loss, 'g-o', linewidth=2, label='Validation Loss')
ax.legend(fontsize=22)
ax.set_xlabel("Epochs", fontsize=22)
ax.set_ylabel("Loss", fontsize=22)
plt.savefig('/kaggle/working/training_loss_plot.png')
💾 Saving Models
autoencoder.save("auto_encoder.h5")
encoder.save('encoder.h5')
decoder.save('decoder.h5')
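Because the SE block injects raw TensorFlow ops into the graph, reloading these `.h5` files in a fresh session may be easiest with `compile=False` (a usage note added here, not from the original post):

encoder = tf.keras.models.load_model('encoder.h5', compile=False)
decoder = tf.keras.models.load_model('decoder.h5', compile=False)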
🖼️ Testing Encoder Output
import cv2

# Load one test image, convert BGR -> RGB, scale to [0, 1], add a batch axis
img = cv2.imread(test.iloc[0]['path'])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size)) / 255.0
img = tf.expand_dims(img, axis=0)

encoder_output = encoder.predict(img)

plt.title("Original")
plt.imshow(img[0])
plt.savefig('/kaggle/working/original_image.png')
plt.show()
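The snippet above only displays the original image; passing the encoder output through the decoder shows what the autoencoder actually reconstructs (an addition, not from the original post):

reconstructed = decoder.predict(encoder_output)
plt.title("Reconstructed")
plt.imshow(reconstructed[0])
plt.savefig('/kaggle/working/reconstructed_image.png')
plt.show()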
🔁 Data Augmentation for Classification
# Heavier augmentation for the classification stage
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)
train_images = train_datagen.flow_from_dataframe(
    dataframe=train,
    x_col='path',
    y_col='class_',
    batch_size=batch_size,
    target_size=(size, size),
    class_mode='categorical'   # one-hot labels for classification
)

# Validation reuses the rescale-only valid_datagen defined earlier
valid_images = valid_datagen.flow_from_dataframe(
    dataframe=valid,
    x_col='path',
    y_col='class_',
    batch_size=batch_size,
    target_size=(size, size),
    class_mode='categorical'
)
🧩 Loading and Using SWIN Transformer
import tensorflow_hub as hub

def get_from_hub(model_url):
    inputs = tf.keras.Input((224, 224, 3))
    hub_module = hub.KerasLayer(model_url, trainable=False)   # frozen feature extractor
    outputs = hub_module(inputs)
    return tf.keras.Model(inputs, outputs)

swin = get_from_hub("https://tfhub.dev/sayakpaul/swin_large_patch4_window7_224_in22k_fe/1")
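The `_fe` variant is a feature extractor, so `swin` maps each image to a single flat feature vector rather than class logits. A quick check (added here) of the output shape:

print(swin.output_shape)   # a flat feature vector per image, e.g. (None, 1536) for the large variant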
🔗 Merging SWIN and Encoder Outputs
# Classification head on top of the frozen SWIN features
x = swin.output
x = BatchNormalization()(x)
x = Dense(32, activation='relu')(x)
x = Dense(256, activation='relu')(x)
x_swin = Dense(128, activation='relu')(x)

# Classification head on top of the pre-trained encoder features
def get_model(base_model):
    x = GlobalAveragePooling2D()(se_block_enc(base_model.output, 20))
    x = BatchNormalization()(x)
    x = Dense(512, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    return Dense(128, activation='relu')(x)

x_encoder = get_model(encoder)

# Fuse the two 128-dimensional feature vectors and classify
concatenated = concatenate([x_encoder, x_swin])
output = Dense(len(train_images.class_indices), activation='softmax')(concatenated)
model = Model([encoder.input, swin.input], output)
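At this point `model` routes the same image through two branches. A quick way to inspect the fused architecture (added here; `plot_model` requires the pydot and graphviz packages):

model.summary()
tf.keras.utils.plot_model(model, to_file='/kaggle/working/model.png', show_shapes=True)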
🧮 Model Training
# The combined model has two inputs, but both branches consume the same image,
# so each batch is duplicated across the two input slots.
def generator_two_img(gen):
    while True:
        X1i = next(gen)
        yield [X1i[0], X1i[0]], X1i[1]

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adamax(),
    metrics=['accuracy']
)

history = model.fit(
    generator_two_img(train_images),
    validation_data=generator_two_img(valid_images),
    validation_steps=valid_images.n // batch_size,
    steps_per_epoch=train_images.n // batch_size,
    epochs=15
)
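To gauge performance on held-out data, one can build a non-augmented generator over the `test` DataFrame and evaluate the combined model (a sketch, assuming `test` has the same `path`/`class_` columns as `train` and `valid`):

test_datagen = ImageDataGenerator(rescale=1.0 / 255)
test_images = test_datagen.flow_from_dataframe(
    dataframe=test,
    x_col='path',
    y_col='class_',
    batch_size=batch_size,
    target_size=(size, size),
    class_mode='categorical',
    shuffle=False
)
loss, acc = model.evaluate(
    generator_two_img(test_images),
    steps=test_images.n // batch_size
)
print(f"Test accuracy: {acc:.4f}")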
🧾 Parameters Summary
trainable_params = sum([v.numpy().size for v in model.trainable_variables])
non_trainable_params = sum([v.numpy().size for v in model.non_trainable_variables])
print("Trainable parameters:", trainable_params)
print("Non-trainable parameters:", non_trainable_params)