Artificial Intelligence-Driven Deepfake Detection: Hybrid Self-Supervised Learning and Swin Transformer for Explainable Fake Image Classification

Autoencoder with Squeeze-and-Excitation Blocks and Swin Transformer for Image Classification

This blog post presents a hybrid deep learning architecture that integrates a convolutional autoencoder with Squeeze-and-Excitation (SE) blocks and a pre-trained Swin Transformer. The autoencoder is first trained in a self-supervised fashion to reconstruct its inputs; its encoder is then fused with Swin features for improved fake-versus-real image classification.

📦 Data Preprocessing


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 16
size = 224    # input resolution expected by the Swin checkpoint
epoch = 50

# class_mode="input" makes the generators yield (image, image) pairs,
# i.e., self-supervised reconstruction targets for the autoencoder.
train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train,
    x_col='path',
    y_col='class_',
    target_size=(size, size),
    batch_size=batch_size,
    class_mode="input"
)

valid_datagen = ImageDataGenerator(rescale=1./255)
valid_generator = valid_datagen.flow_from_dataframe(
    dataframe=valid,
    x_col='path',
    y_col='class_',
    target_size=(size, size),
    batch_size=batch_size,
    class_mode="input"
)
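
The `train`, `valid`, and `test` dataframes are assumed to already exist, each with a `path` column of image file paths and a `class_` label column. A minimal sketch of how they might be built; the directory layout and the `real`/`fake` folder names here are assumptions, not from the original post:

import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

# Hypothetical layout: /kaggle/input/dataset/<real|fake>/*.jpg
rows = [{'path': str(p), 'class_': p.parent.name}
        for p in Path('/kaggle/input/dataset').glob('*/*.jpg')]
df = pd.DataFrame(rows)

# 70 / 15 / 15 stratified split
train, temp = train_test_split(df, test_size=0.3, stratify=df['class_'], random_state=42)
valid, test = train_test_split(temp, test_size=0.5, stratify=temp['class_'], random_state=42)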

🧠 Building the Autoencoder with SE Blocks


def se_block_enc(inputs, alpha):
    """Squeeze-and-Excitation: reweight channels using globally pooled statistics.
    `alpha` is the width of the bottleneck (squeeze) layer."""
    input_channels = inputs.shape[-1]
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)            # squeeze: (B, C)
    x = tf.keras.layers.Dense(units=alpha, activation="relu")(x)    # bottleneck
    x = tf.keras.layers.Dense(units=input_channels, activation="sigmoid")(x)  # per-channel gates
    x = tf.reshape(x, [-1, 1, 1, input_channels])
    return inputs * x                                               # excite: rescale each channel
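
Since the SE block only rescales channels, its output shape matches its input shape. A quick illustrative sanity check:

dummy = tf.zeros((1, 224, 224, 48))
print(se_block_enc(dummy, 20).shape)  # -> (1, 224, 224, 48)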

from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D,
                                     Dense, BatchNormalization,
                                     GlobalAveragePooling2D, concatenate)
from tensorflow.keras.models import Model

# ENCODER: three conv/SE/pool stages, 224x224x3 -> 28x28x32
input_img = Input(shape=(size, size, 3))
x = Conv2D(48, (3, 3), activation='relu', padding='same')(input_img)
x = se_block_enc(x, 20)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
x = se_block_enc(x, 30)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = se_block_enc(x, 50)
x = MaxPooling2D((2, 2), padding='same')(x)
encoded = Conv2D(32, (1, 1), activation='relu', padding='same')(x)

# Bottleneck: three 2x poolings reduce 224 -> 28; the 1x1 conv sets 32 channels
latentSize = (28, 28, 32)

# DECODER: mirror of the encoder, 28x28x32 -> 224x224x3
direct_input = Input(shape=latentSize)
x = Conv2D(192, (1, 1), activation='relu', padding='same')(direct_input)
x = UpSampling2D((2, 2))(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(48, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)  # sigmoid keeps outputs in [0, 1]

# MODELS: the autoencoder chains the decoder onto the encoder's output
encoder = Model(input_img, encoded)
decoder = Model(direct_input, decoded)
autoencoder = Model(input_img, decoder(encoded))
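
A quick check that the encoder's output really matches the assumed latent shape:

assert encoder.output_shape[1:] == latentSize  # (28, 28, 32)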

🧪 Training the Autoencoder


from tensorflow.keras.callbacks import ModelCheckpoint

# Pixel values are in [0, 1], so binary cross-entropy is a valid reconstruction loss
autoencoder.compile(optimizer=tf.keras.optimizers.Adamax(), loss='binary_crossentropy')

model_checkpoint_callback = ModelCheckpoint(
    filepath='/kaggle/working/autoencoder_checkpoint.h5',
    save_weights_only=False,
    save_best_only=True,      # keep only the epoch with the lowest validation loss
    monitor='val_loss',
    mode='min',
    verbose=1
)

history = autoencoder.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=epoch,
    verbose=2,
    callbacks=[model_checkpoint_callback]
)
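
Fifty epochs is a long run; a common addition (not in the original code) is early stopping alongside the checkpoint:

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)
# Then pass callbacks=[model_checkpoint_callback, early_stop] to fit().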

📈 Visualizing Training Loss


import matplotlib.pyplot as plt

epochs = list(range(len(history.history['loss'])))
train_loss = history.history['loss']
val_loss = history.history['val_loss']

# Plot training and validation loss per epoch on a single axes
fig, ax = plt.subplots()
fig.set_size_inches(10, 10)
ax.plot(epochs, train_loss, 'r-o', linewidth=8, label='Training Loss')
ax.plot(epochs, val_loss, 'g-o', linewidth=8, label='Validation Loss')
ax.legend(fontsize=22)
ax.set_xlabel("Epochs", fontsize=22)
ax.set_ylabel("Loss", fontsize=22)
plt.savefig('/kaggle/working/training_loss_plot.png')

💾 Saving Models


autoencoder.save("auto_encoder.h5")
encoder.save('encoder.h5')
decoder.save('decoder.h5')
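
Because the SE block uses raw TensorFlow ops (tf.reshape and a broadcast multiply), Keras serializes them as op layers; reloading under the same TensorFlow version with a plain load_model call should work. A minimal sketch:

# Reload sketch (same TensorFlow version assumed)
reloaded_encoder = tf.keras.models.load_model('encoder.h5')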

🖼 Testing Encoder Output


import cv2

# Load one test image, convert BGR -> RGB, resize, and scale to [0, 1]
img = cv2.imread(test.iloc[0]['path'])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size)) / 255.0
img = tf.expand_dims(img, axis=0)   # add a batch dimension

encoder_output = encoder.predict(img)   # shape (1, 28, 28, 32)

plt.title("Original")
plt.imshow(img[0])
plt.savefig('/kaggle/working/original_image.png')
plt.show()
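
The snippet above computes `encoder_output` but only displays the input. A small sketch for also inspecting a few of the 32 latent feature maps:

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax_ in enumerate(axes):
    ax_.imshow(encoder_output[0, :, :, i], cmap='viridis')
    ax_.set_title(f"channel {i}")
    ax_.axis('off')
plt.savefig('/kaggle/working/encoder_feature_maps.png')
plt.show()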

📊 Data Augmentation for Classification


# Augmentation for the classification stage; labels are now one-hot (categorical)
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)

train_images = train_datagen.flow_from_dataframe(
    dataframe=train,
    x_col='path',
    y_col='class_',
    batch_size=batch_size,
    target_size=(size, size),
    class_mode='categorical'
)

# Reuse the rescale-only valid_datagen from above: no augmentation on validation
valid_images = valid_datagen.flow_from_dataframe(
    dataframe=valid,
    x_col='path',
    y_col='class_',
    batch_size=batch_size,
    target_size=(size, size),
    class_mode='categorical'
)

🧩 Loading and Using the Swin Transformer


import tensorflow_hub as hub

def get_from_hub(model_url):
    # Wrap the TF Hub module in a Keras model; the backbone stays frozen
    inputs = tf.keras.Input((224, 224, 3))
    hub_module = hub.KerasLayer(model_url, trainable=False)
    outputs = hub_module(inputs)
    return tf.keras.Model(inputs, outputs)

# Swin-Large feature extractor ("_fe"), pre-trained on ImageNet-22k
swin = get_from_hub("https://tfhub.dev/sayakpaul/swin_large_patch4_window7_224_in22k_fe/1")
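
The `_fe` variant returns a feature vector rather than class logits; a quick check of the output shape (for Swin-Large the final feature width should be 1536):

print(swin.output_shape)  # expected: (None, 1536)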

🔗 Merging Swin and Encoder Outputs


# Classification head on top of the (frozen) Swin features
x = swin.output
x = BatchNormalization()(x)
x = Dense(32, activation='relu')(x)
x = Dense(256, activation='relu')(x)
x_swin = Dense(128, activation='relu')(x)

def get_model(base_model):
    # SE-gate the encoder's 28x28x32 feature maps, then pool to a vector.
    # (The original code referenced an undefined se_block; se_block_enc with a
    # bottleneck width of 20 is used here, and that width is an assumption.)
    x = GlobalAveragePooling2D()(se_block_enc(base_model.output, 20))
    x = BatchNormalization()(x)
    x = Dense(512, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    return Dense(128, activation='relu')(x)

x_encoder = get_model(encoder)
concatenated = concatenate([x_encoder, x_swin])
output = Dense(len(train_images.class_indices), activation='softmax')(concatenated)
model = Model([encoder.input, swin.input], output)
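
One design note: as written, the encoder's convolutional weights keep training along with the new heads. If the goal is to keep the self-supervised features fixed, the encoder can be frozen before compiling (an option, not what the original code does):

encoder.trainable = False  # optional: train only the new classification heads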

🧮 Model Training


def generator_two_img(gen):
    # Both model inputs (encoder and Swin) receive the same image batch
    while True:
        X1i = next(gen)
        yield [X1i[0], X1i[0]], X1i[1]

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adamax(),
    metrics=['accuracy']
)

history = model.fit(
    generator_two_img(train_images),
    validation_data=generator_two_img(valid_images),
    validation_steps=valid_images.n // batch_size,
    steps_per_epoch=train_images.n // batch_size,
    epochs=15
)
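
After training, the same two-input generator pattern works for evaluation. A sketch assuming the `test` dataframe used earlier:

test_images = valid_datagen.flow_from_dataframe(
    dataframe=test, x_col='path', y_col='class_',
    batch_size=batch_size, target_size=(size, size),
    class_mode='categorical', shuffle=False
)
loss, acc = model.evaluate(
    generator_two_img(test_images),
    steps=test_images.n // batch_size
)
print(f"Test accuracy: {acc:.4f}")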

🧾 Parameters Summary


# Count parameters by summing the element counts of each weight tensor
trainable_params = sum([v.numpy().size for v in model.trainable_variables])
non_trainable_params = sum([v.numpy().size for v in model.non_trainable_variables])
print("Trainable parameters:", trainable_params)
print("Non-trainable parameters:", non_trainable_params)
