Training Keras Models with TFRecords and the tf.data API

Quick link: jkjung-avt/keras_imagenet

One of the challenges in training CNN models with a large image dataset lies in building an efficient data ingestion pipeline. Without one, the GPUs could be constantly starved for data, and training would go slowly. In this post, I’m sharing my experience of training Keras image classification models with TensorFlow’s TFRecords and the tf.data API. I believe this approach trains the models much more efficiently than reading the original jpg files from the file system.

More specifically, I share the code I used to train Keras ImageNet (ILSVRC2012: 1,000 classes) image classification models, and I try to explain my use of TFRecords and the tf.data.TFRecordDataset API below.

Reference

About TFRecords

Quoting from the TensorFlow documentation:

To read data efficiently it can be helpful to serialize your data and store it in a set of files (100-200MB each) that can each be read linearly. This is especially true if the data is being streamed over a network. This can also be useful for caching any data-preprocessing.

The TFRecord format is a simple format for storing a sequence of binary records.

Check out the last linked article in the ‘Reference’ section. In short, image data (especially a large amount of it) can be read from disk much more efficiently if it is stored as aggregated and serialized database/record file(s), rather than as separate jpg files.

So TFRecords is the format I use for training the Keras models discussed in this post.
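
To make this concrete, below is a minimal, self-contained sketch (not taken from my repository; the file path and record contents are made up) of writing a few serialized records to a TFRecord file and reading them back linearly with tf.data:

  import tensorflow as tf

  # write a few raw byte strings as records into a TFRecord file
  with tf.python_io.TFRecordWriter('/tmp/example.tfrecord') as writer:
      for i in range(3):
          writer.write(('record-%d' % i).encode())

  # read the records back linearly with tf.data
  dataset = tf.data.TFRecordDataset('/tmp/example.tfrecord')
  iterator = dataset.make_one_shot_iterator()
  next_record = iterator.get_next()
  with tf.Session() as sess:
      for _ in range(3):
          print(sess.run(next_record))  # b'record-0', b'record-1', b'record-2'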

Creating TFRecords for ImageNet (ILSVRC2012) training data

Please check out the ‘Step-by-step’ guide in my jkjung-avt/keras_imagenet repository for how to create the TFRecords files. I mainly took and modified the build_imagenet_data.py script from TensorFlow’s inception model code.

The script splits the training set (1,281,167 images) into 1,024 shards and the validation set (50,000 images) into 128 shards, so each shard file contains roughly the same number of images. The image data in the shard files stays jpg-encoded; otherwise, the TFRecords files would take up too much space.

When done, the contents of my ${HOME}/data/ILSVRC2012/tfrecords/ directory are:

  train-00000-of-01024
  train-00001-of-01024
  ...
  train-01023-of-01024
  validation-00000-of-00128
  validation-00001-of-00128
  ...
  validation-00127-of-00128
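
For illustration, the following sketch shows roughly how one image could be serialized into such a shard as a tf.train.Example, with the jpg bytes stored as-is. The feature keys and file names here are simplified assumptions; the actual build_imagenet_data.py stores additional metadata fields:

  import tensorflow as tf

  def _bytes_feature(value):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def _int64_feature(value):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

  def make_example(jpg_path, label):
      """Serialize one image (kept jpg-encoded) and its class label."""
      with open(jpg_path, 'rb') as f:
          encoded_jpg = f.read()  # keep jpg bytes as-is to save space
      return tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': _bytes_feature(encoded_jpg),
          'image/class/label': _int64_feature(label),
      }))

  # append serialized Examples to one shard file (names are illustrative)
  with tf.python_io.TFRecordWriter('train-00000-of-01024') as writer:
      writer.write(make_example('some_image.jpg', 123).SerializeToString())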

Reading TFRecords and creating randomly shuffled data while training

To be more precise, we want to efficiently parallelize the tasks of reading data from the TFRecords files, randomizing (shuffling) the data, and performing data augmentation. TensorFlow’s Data Input Pipeline Performance documentation roughly describes how to do this, but I found the code samples in that document a little confusing.

After some research, I found mrry’s suggestion on GitHub most helpful for achieving what I wanted. I ended up doing the following:

  1. Create a tf.data.Dataset which is a list of the TFRecords (shard) file names: either ‘train-xxxxx-of-01024’ or ‘validation-xxxxx-of-00128’.
  2. Next, shuffle() and repeat() the shards Dataset, so that it generates the shard file names in random order and indefinitely.
  3. interleave() the shards Dataset into a tf.data.TFRecordDataset. This results in a TFRecordDataset which reads from the shard files in interleaved order, randomizing the data further.
  4. Shuffle the TFRecordDataset so that the order of training images within a shard is also randomized in each epoch.
  5. Parse and deserialize the TFRecords, with prefetching. I set num_parallel_calls to 4 on my desktop; you could increase the value if you want more parallel workers doing data generation and augmentation. Finally, note that I implement jpg decoding and data augmentation in the deserialization function (parse_fn_train); see the sketch after the code below.

You can find my full Python implementation here.

  def get_dataset(tfrecords_dir, subset, batch_size):
      """Read TFRecords files and turn them into a TFRecordDataset."""
      files = tf.matching_files(os.path.join(tfrecords_dir, '%s-*' % subset))
      shards = tf.data.Dataset.from_tensor_slices(files)
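      # shuffle all shard file names and repeat them indefinitely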
      shards = shards.shuffle(tf.cast(tf.shape(files)[0], tf.int64))
      shards = shards.repeat()
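      # cycle through 4 shard files at a time, interleaving their records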
      dataset = shards.interleave(tf.data.TFRecordDataset, cycle_length=4)
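      # shuffle individual records with a buffer of 8,192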
      dataset = dataset.shuffle(buffer_size=8192)
      parser = parse_fn_train if subset == 'train' else parse_fn_valid
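      # deserialize (and augment) records, then batch them in parallel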
      dataset = dataset.apply(
          tf.data.experimental.map_and_batch(
              map_func=parser,
              batch_size=batch_size,
              num_parallel_calls=config.NUM_DATA_WORKERS))
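      # prefetch so subsequent batches are prepared while the model trains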
      dataset = dataset.prefetch(batch_size)
      return dataset
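
The actual parse_fn_train and parse_fn_valid functions are in the repository. For illustration, here is a simplified sketch of what such a training parser could look like; the feature keys, image sizes and augmentation steps below are assumptions, not my exact implementation:

  def parse_fn_train(serialized_example):
      """Sketch: deserialize one Example, decode the jpg and augment it."""
      # feature keys assumed to match what the TFRecords creation script wrote
      features = tf.parse_single_example(
          serialized_example,
          features={
              'image/encoded': tf.FixedLenFeature([], tf.string),
              'image/class/label': tf.FixedLenFeature([], tf.int64),
          })
      # decode the jpg here so the parallel map workers share the decoding load
      image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
      # illustrative augmentation: resize, random crop and random flip
      image = tf.image.resize_images(image, [256, 256])
      image = tf.random_crop(image, [224, 224, 3])
      image = tf.image.random_flip_left_right(image)
      image = tf.cast(image, tf.float32) / 127.5 - 1.0  # scale to [-1, 1]
      label = features['image/class/label']
      return image, label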

Training a Keras CNN model with the TFRecordDataset

According to the official documentation, tf.keras.Model’s fit() method can take “a tf.data dataset or a dataset iterator” as input, and the dataset or iterator “should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).”

So, after compiling the training model, I can just call model.fit() with the TFRecordDataset implemented as described above. You could reference my source code here.

      # get training and validation data
      ds_train = get_dataset(dataset_dir, 'train', batch_size)
      ......

      model.fit(
          x=ds_train,
          ......
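
One detail worth pointing out: since the dataset repeat()s indefinitely, fit() has to be told how many batches make up an epoch. A hedged sketch of a fuller call (the variable names are illustrative) could look like this:

  # ILSVRC2012 has 1,281,167 training and 50,000 validation images
  ds_train = get_dataset(dataset_dir, 'train', batch_size)
  ds_valid = get_dataset(dataset_dir, 'validation', batch_size)

  model.fit(
      x=ds_train,
      steps_per_epoch=1281167 // batch_size,
      validation_data=ds_valid,
      validation_steps=50000 // batch_size,
      epochs=epochs)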

Preliminary Results

To recap, I’ve explained how I use sharded TFRecords for efficient disk I/O, as well as how to use tf.data.TFRecordDataset to ingest training data when training Keras CNN models. I take advantage of tf.data’s ability to process data with multiple workers and to shuffle/prefetch data on the fly. Furthermore, I do online data augmentation when deserializing the TFRecords, which again lets multiple workers fetch and process data in parallel. I think I achieve very good data ingestion performance this way.

I’m still in the process of training Keras MobileNetV2 and ResNet50 models with this code, and I hope to share the results when training is done. Although I haven’t done proper benchmarking, I’m pretty sure that using the TFRecordDataset (with 4 parallel data workers) speeds up training quite a bit compared to reading the original jpg files.

2019-11-24 Update: I’ve written a new post about how I visualized and verified training images in TensorBoard: Displaying Images in TensorBoard
