Extending Keras' ImageDataGenerator to Support Random Cropping

Quick link to my GitHub code: https://github.com/jkjung-avt/keras-cats-dogs-tutorial

Keras’ ‘ImageDataGenerator’ supports quite a few data augmentation schemes and is pretty easy to use. In the previous post, I took advantage of ImageDataGenerator’s data augmentations and was able to build a Cats vs. Dogs classifier with 99% validation accuracy, trained on relatively little data. However, ‘ImageDataGenerator’ lacks one important feature that I’d really like to use: random cropping.

After crawling the web for a while, I came up with a simple solution. It lets me use all the data augmentation functionality of the original ‘ImageDataGenerator’ while adding random cropping to the mix. Here’s how I implemented it.

To begin with, I was deeply inspired by this StackOverflow discussion: Data Augmentation Image Data Generator Keras Semantic Segmentation. Following the example code there, I developed a crop_generator that takes batches of image data from ‘ImageDataGenerator’ and randomly crops every image in each batch.

import numpy as np


def random_crop(img, random_crop_size):
    # Note: image_data_format is 'channels_last', i.e. img is (height, width, 3)
    assert img.shape[2] == 3
    height, width = img.shape[0], img.shape[1]
    dy, dx = random_crop_size
    # randint's upper bound is exclusive, so "+ 1" lets the crop touch the right/bottom edge
    x = np.random.randint(0, width - dx + 1)
    y = np.random.randint(0, height - dy + 1)
    return img[y:(y+dy), x:(x+dx), :]


def crop_generator(batches, crop_length):
    """Take as input a Keras ImageGen (Iterator) and generate random
    crops from the image batches generated by the original iterator.
    """
    while True:
        batch_x, batch_y = next(batches)
        # Keep the input dtype (float32 from Keras) instead of np.zeros' default float64
        batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3),
                               dtype=batch_x.dtype)
        for i in range(batch_x.shape[0]):
            batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
        yield (batch_crops, batch_y)
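If you want to sanity-check the idea without Keras, here is a hedged, numpy-only sketch. random_crop_any, crop_batches and fake_batches are hypothetical names of mine, not code from the repository: random_crop_any generalizes random_crop to any channel count and rejects crops larger than the image, and fake_batches stands in for the Keras iterator by yielding random (x, y) batches forever.

```python
import numpy as np


def random_crop_any(img, crop_hw, rng=None):
    """Return a random (crop_h, crop_w) window of an H x W x C array."""
    rng = np.random.default_rng() if rng is None else rng
    height, width = img.shape[:2]
    crop_h, crop_w = crop_hw
    if crop_h > height or crop_w > width:
        raise ValueError(f"crop {crop_hw} is larger than image {(height, width)}")
    # integers() excludes the upper bound; "+ 1" lets the crop reach the far edge
    y = rng.integers(0, height - crop_h + 1)
    x = rng.integers(0, width - crop_w + 1)
    return img[y:y + crop_h, x:x + crop_w, ...]


def fake_batches(batch_size=4, size=256, channels=3, seed=0):
    """Hypothetical stand-in for a Keras iterator: yields (x, y) batches forever."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.random((batch_size, size, size, channels), dtype=np.float32)
        # one-hot labels for 2 classes
        y = np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=batch_size)]
        yield x, y


def crop_batches(batches, crop_hw, rng=None):
    """Randomly crop every image in each (x, y) batch; labels pass through."""
    rng = np.random.default_rng() if rng is None else rng
    for batch_x, batch_y in batches:
        crops = np.stack([random_crop_any(img, crop_hw, rng) for img in batch_x])
        yield crops, batch_y


gen = crop_batches(fake_batches(), (224, 224))
x, y = next(gen)
print(x.shape, x.dtype)  # (4, 224, 224, 3) float32
```

np.stack keeps the input dtype, so float32 batches stay float32 without having to pre-allocate an output array.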

In my example train_cropped.py code, I used ImageDataGenerator.flow_from_directory() to resize all input images to (256, 256) and then used my own crop_generator to generate random (224, 224) crops from the resized images. Note that the resized (256, 256) images had already been processed by ‘ImageDataGenerator’, so they had gone through all of its data augmentations, such as random rotation, shifting, shearing, and flipping.

train_datagen = ImageDataGenerator(......)
train_batches = train_datagen.flow_from_directory(DATASET_PATH + '/train',
                                                  target_size=(256,256),
                                                  ......)
train_crops = crop_generator(train_batches, 224)
......
net_final.fit_generator(train_crops, ......)
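To make the 256-to-224 numbers concrete, here is a quick back-of-the-envelope check (my own arithmetic, not code from the repository) of how much positional variety random_crop() can produce from a single resized image.

```python
import numpy as np

RESIZED, CROP = 256, 224  # sizes used with flow_from_directory() and crop_generator

# random_crop() draws each offset with np.random.randint(0, RESIZED - CROP + 1);
# the upper bound is exclusive, so offsets run from 0 to 32 inclusive.
offsets_per_axis = RESIZED - CROP + 1
positions = offsets_per_axis ** 2  # distinct (y, x) top-left corners
area_kept = (CROP * CROP) / (RESIZED * RESIZED)
print(offsets_per_axis, positions, area_kept)  # 33 1089 0.765625

# Empirical check of the offset range using the same call random_crop() makes.
draws = np.random.randint(0, RESIZED - CROP + 1, size=100_000)
print(draws.min(), draws.max())  # smallest and largest sampled offset
```

So each resized image can yield 1089 different crop positions on top of the ‘ImageDataGenerator’ augmentations, and every crop keeps about 77% of the resized image’s area.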

The full train_cropped.py code is in my GitHub repository. I trained a ResNet50 model on the cropped images, and the result was on par with the non-cropped version, i.e. 99% validation accuracy.

$ cd ~/project/keras-cats-dogs-tutorial
$ python3 train_cropped.py
......
Found 2000 images belonging to 2 classes.
Found 800 images belonging to 2 classes.
****************
Class #0 = cats
Class #1 = dogs
****************
......
Epoch 20/20
250/250 [==============================] - 35s 139ms/step - loss: 0.0316 - acc: 0.9870 - val_loss: 0.0276 - val_acc: 0.9900
......
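The “250/250” in the log is the number of steps per epoch. Because crop_generator loops forever (while True), Keras can’t tell where an epoch ends, so fit_generator() needs an explicit steps_per_epoch; the Keras documentation suggests ceil(num_samples / batch_size). The sketch below does that arithmetic. steps_for is a hypothetical helper, and BATCH_SIZE = 8 is my assumption, inferred only from 2000 training images / 250 steps.

```python
import math


def steps_for(num_images, batch_size):
    """Batches needed to cover every image once; the last batch may be partial."""
    return math.ceil(num_images / batch_size)


# Image counts come from the log above. BATCH_SIZE = 8 is an assumption that
# matches "250/250" steps for 2000 training images; it is not read from the repo.
BATCH_SIZE = 8
train_steps = steps_for(2000, BATCH_SIZE)
val_steps = steps_for(800, BATCH_SIZE)
print(train_steps, val_steps)  # 250 100
```

These values would go to steps_per_epoch (and validation_steps, if the validation data also comes from a generator) in the fit_generator() call.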

In the predict_cropped.py script, I used ‘center crop’ for prediction. The full code is also in my GitHub repository.
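predict_cropped.py isn’t reproduced in this post, so the following is only my own sketch of a center crop, assuming the same 256-to-224 sizes as in training; center_crop is a hypothetical helper name. Unlike random_crop, the offset is fixed at the middle of the image, so prediction is deterministic.

```python
import numpy as np


def center_crop(img, crop_hw):
    """Return the central (crop_h, crop_w) window of an H x W x C array."""
    height, width = img.shape[:2]
    crop_h, crop_w = crop_hw
    # Same margin on both sides (the extra pixel goes to the bottom/right if odd)
    y = (height - crop_h) // 2
    x = (width - crop_w) // 2
    return img[y:y + crop_h, x:x + crop_w, ...]


img = np.zeros((256, 256, 3), dtype=np.float32)
img[16:240, 16:240] = 1.0  # mark the expected central 224x224 region
patch = center_crop(img, (224, 224))
print(patch.shape, patch.min())  # (224, 224, 3) 1.0
```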

Finally, I looked at a few images generated by my crop_generator. Note that the crops had been preprocessed by ResNet50’s preprocess_input(), so I had to add pixel_mean back to them before plotting. They looked as expected (cropped)…

image patches generated by crop_generator
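Undoing the preprocessing before plotting can be sketched as follows. It assumes ResNet50’s preprocess_input() uses Keras’ ‘caffe’ mode (convert RGB to BGR, then subtract the ImageNet per-channel mean, with no scaling). caffe_style_preprocess is a hypothetical stand-in written so the example runs without Keras, and PIXEL_MEAN_BGR plays the role of pixel_mean.

```python
import numpy as np

# ImageNet per-channel means in BGR order (assumed 'caffe'-mode values for ResNet50)
PIXEL_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_style_preprocess(rgb):
    """Hypothetical stand-in for preprocess_input: RGB -> BGR, then subtract the mean."""
    return rgb[..., ::-1].astype(np.float32) - PIXEL_MEAN_BGR


def deprocess_for_plot(x):
    """Undo caffe-style preprocessing so a crop can be displayed as uint8 RGB."""
    bgr = x + PIXEL_MEAN_BGR  # add the pixel mean back
    rgb = bgr[..., ::-1]      # BGR -> RGB
    # Round before casting: float32 round-off could otherwise turn 255 into 254
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
crop = rng.integers(0, 256, size=(224, 224, 3)).astype(np.float32)
restored = deprocess_for_plot(caffe_style_preprocess(crop))
print(restored.dtype, restored.shape)  # uint8 (224, 224, 3)
```

The result can be passed directly to matplotlib’s imshow().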
