credit: imgaug
by Cristian Duguet, Adrian Hutter, Nandika Kalra, Arshak Navruzyan
Overview
Data augmentation of images is a widely adopted technique to improve a model's generalization [3]. It works by applying transformations (or sets of combined transformations) during training and, optionally, at test time. Practitioners often search for the best set of transformations manually, either by relying on domain expertise or by making assumptions about generally useful transformations. This manual search for optimal augmentations can be time-consuming and compute-intensive.
The fast.ai library does a great job of providing smart defaults in areas such as hyperparameters and model architecture (e.g. the custom head), and augmentations are no exception. After a lot of experimentation, the fast.ai team found a standard set of augmentations that is applied indiscriminately to every dataset and has proven to work well most of the time. However, there may be room for improvement if appropriate augmentations can be found quickly for a given dataset.
Over the past few months, our team has been working on a method to automate the search for the best augmentation set in a computationally efficient manner, with as little domain-specific input as possible. The purpose of this research aligns with the goals of platform.ai, which seeks to bring deep learning to a wider, non-engineer audience.
Current approaches
There are a few approaches to automatically finding an optimal set of augmentations, but the most common paradigm is to generate various augmented datasets and fully or partially train “child models” on each one to determine its impact on performance. These approaches focus on making the search as efficient as possible through Gaussian processes, reinforcement learning [1] and other search methods such as augmented random search [2].
One obvious disadvantage of this formulation is that it is computationally expensive: to achieve even a modest improvement, thousands or tens of thousands of child models have to be trained and evaluated.
A more novel approach comes from Ratner et al. [5], who use a GAN to find transformation parameters that create augmented data lying within a defined distribution of interest, representative of the training set. By treating the augmentation search as a sequence modeling problem, this approach attempts to find not only the right augmentations but also their parameterization. Because the transformations can be non-differentiable, it requires a reinforcement learning loop in addition to the generative and discriminative model training.
We are looking for a more computationally efficient way to automatically find a set of parameters for the image transform functions provided by the fast.ai library, even if this means giving up any performance gains that come from composing transformations.
Methodology
The fast.ai library provides different types of image transformations that are used for data augmentation. These transformations can be grouped into affine, pixel, cropping, lighting and coordinate transformations. The default augmentation set, which is obtained by calling get_transforms(), has the following parameters:
Table 1: List of transformations available in fast.ai. More transforms are available in the fast.ai documentation, but we focused on the ones listed here.
Validation
The performance of a network trained with the chosen augmentations will be compared to that of a network trained with the default fast.ai augmentation set, with both networks trained on the same schedule. This comparison is repeated for each dataset, and the performance metric is the error rate on the validation set.
Datasets
A variety of datasets and domains were selected to assess the robustness of this method. We gathered a group of datasets representative of very different use cases and selected the following six:
- Oxford-IIIT Pet Dataset
- Stanford Dogs Dataset
- Planet: Understanding the Amazon from Space
- CIFAR-10
- Kuzushiji-MNIST
- Food-101
TTA Search
Test Time Augmentation (TTA) is a technique that uses augmentation at prediction time rather than during training. With TTA, the predicted probability for an image combines two different predictions:
$$\hat{y} = \beta\,\hat{y}_{o} + (1 - \beta)\,\hat{y}_{a}$$
where $\hat{y}_{o}$ and $\hat{y}_{a}$ are the predicted values for the original and the augmented image, respectively, and $\beta$ is a weight factor.
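To make the weighting concrete, here is a minimal plain-Python sketch, assuming per-class prediction scores (probabilities or logits) are stored as lists. The helper names `tta_blend` and `error_rate` are hypothetical illustrations, not fast.ai functions; `beta` is the weight given to the prediction on the original image, and the scores in the toy example are made up.

```python
from typing import List, Sequence

def tta_blend(pred_orig: Sequence[float], pred_aug: Sequence[float], beta: float) -> List[float]:
    """Blend per-class scores: beta * original + (1 - beta) * augmented."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must be in [0, 1]")
    return [beta * o + (1.0 - beta) * a for o, a in zip(pred_orig, pred_aug)]

def error_rate(preds: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class differs from the label."""
    wrong = sum(
        1 for p, y in zip(preds, labels)
        if max(range(len(p)), key=lambda k: p[k]) != y
    )
    return wrong / len(labels)

# Toy example with made-up scores for 2 samples and 2 classes.
orig = [[0.6, 0.4], [0.3, 0.7]]
aug = [[0.2, 0.8], [0.1, 0.9]]
labels = [0, 1]
blended = [tta_blend(o, a, beta=0.5) for o, a in zip(orig, aug)]
print(error_rate(orig, labels), error_rate(blended, labels))  # prints 0.0 0.5
```

In this toy case the augmented view pushes the first sample toward the wrong class, so TTA with this augmentation is worse than plain prediction, which is the kind of signal used to reject an augmentation.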
We observed that TTA can be a good indicator of whether a certain augmentation is a good candidate for training. That is, if the TTA predictions using a certain augmentation are more accurate on a given set than the normal (unaugmented) predictions, then training the network with that augmentation will most likely improve its performance.
Figure 1: Normalized error rate of a network trained with a certain augmentation vs. the normalized TTA error rate for the same augmentation. Normalization is relative to the baseline case without augmentation. Each point is a different augmentation tried on the network; each color is an experiment on a different dataset.
We found that transformations that are very harmful for training, characterized by a high err/err_none, also had a very high TTA_err/TTA_err_none. For the Pets and Dogs datasets, for example, this was the dihedral transform, while for Planet it was resize_crop. Most transforms in our set were detrimental for CIFAR-10; this may be due to the composability benefit discussed earlier [5], which evaluating transforms one at a time does not capture.
Based on this observed behaviour, we search for an augmentation set using TTA with the following procedure:
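The normalized quantities err/err_none and TTA_err/TTA_err_none can be computed as below. This is a minimal sketch: `normalized_ratios` and `likely_harmful` are hypothetical helper names, the `limit` cut-off of 1.0 is our assumption for illustration, and the error rates are made up rather than taken from our experiments.

```python
from typing import Dict, List, Tuple

def normalized_ratios(
    err: Dict[str, float], err_none: float,
    tta_err: Dict[str, float], tta_err_none: float,
) -> Dict[str, Tuple[float, float]]:
    """Map each augmentation name to (err / err_none, TTA_err / TTA_err_none)."""
    return {name: (err[name] / err_none, tta_err[name] / tta_err_none) for name in err}

def likely_harmful(ratios: Dict[str, Tuple[float, float]], limit: float = 1.0) -> List[str]:
    """Augmentations whose TTA ratio exceeds `limit`, i.e. TTA error above the no-augmentation TTA error."""
    return sorted(name for name, (_, tta_ratio) in ratios.items() if tta_ratio > limit)

# Illustrative numbers only, not measurements from our experiments.
ratios = normalized_ratios(
    err={"dihedral": 0.12, "zoom": 0.08}, err_none=0.10,
    tta_err={"dihedral": 0.30, "zoom": 0.09}, tta_err_none=0.10,
)
print(likely_harmful(ratios))  # ['dihedral']
```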
- Split the training set into two subsets of size 80% and 20%, respectively.
- Train the last layer group on 80% of the training set for EPOCHS_HEAD epochs, without any data augmentation.
- Calculate the error rate ERR_NONE on the remaining 20% of the training set.
- For each kind of transformation and each possible magnitude, calculate the TTA error rate on the remaining 20% of the training set. For TTA, we base predictions on WEIGHT_UNTRANSFORMED * LOGITS_UNTRANSFORMED + (1 - WEIGHT_UNTRANSFORMED) * LOGITS_TRANSFORMED, where WEIGHT_UNTRANSFORMED is the weight given to the prediction on the untransformed image, so 1 - WEIGHT_UNTRANSFORMED sets how much influence the augmentation has on the prediction.
- For each kind of transformation, choose the magnitude which leads to the lowest TTA error rate, if that error rate is lower than THRESHOLD * ERR_NONE; otherwise, don't include that kind of transformation in the final set of augmentations.
- With the chosen set of augmentations, train the head for EPOCHS_HEAD epochs and the full network for EPOCHS_FULL epochs.
- As a baseline, train the network for the same number of epochs using the transforms provided by get_transforms().
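The selection logic of the procedure above (the split and the per-transform choice of magnitude) can be sketched in plain Python. This is a minimal sketch under stated assumptions: the training steps are left out because they depend on the fast.ai training loop, `tta_error` is a hypothetical callback standing in for "TTA error rate on the 20% holdout", `threshold=1.0` is only a placeholder since the value of THRESHOLD is not given here, and the error values are invented.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

def split_indices(n: int, holdout_frac: float = 0.2, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle range(n) and split it into (80%, 20%) index lists by default."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = round(n * (1 - holdout_frac))
    return idx[:cut], idx[cut:]

def select_augmentations(
    search_space: Dict[str, Sequence[float]],
    tta_error: Callable[[str, float], float],
    err_none: float,
    threshold: float = 1.0,
) -> Dict[str, float]:
    """Per transform kind, keep the magnitude with the lowest TTA error,
    but only if that error is below threshold * err_none."""
    chosen: Dict[str, float] = {}
    for name, magnitudes in search_space.items():
        best_mag, best_err = min(
            ((m, tta_error(name, m)) for m in magnitudes), key=lambda pair: pair[1]
        )
        if best_err < threshold * err_none:
            chosen[name] = best_mag
    return chosen

# Hypothetical stand-in for TTA errors measured on the 20% holdout.
FAKE_TTA_ERRORS = {
    ("rotate", 10.0): 0.09, ("rotate", 30.0): 0.12,
    ("zoom", 1.1): 0.15,
}

train_idx, holdout_idx = split_indices(1000)
print(len(train_idx), len(holdout_idx))  # 800 200
print(select_augmentations(
    {"rotate": [10.0, 30.0], "zoom": [1.1]},
    tta_error=lambda name, mag: FAKE_TTA_ERRORS[(name, mag)],
    err_none=0.10,
))  # {'rotate': 10.0}
```

Here "zoom" is dropped because even its best TTA error (0.15) is not below THRESHOLD * ERR_NONE, mirroring the rule that a transform kind is excluded entirely when no magnitude helps.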
Of the transformations available in the fast.ai library, we tested our method with the following transforms and parameters:
Table 2: List of augmentations tested with our search method, including the tested parameters.
* This is the probability with which each transform was applied during final training. For the TTA search, the probability was 1.0.
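To illustrate what the probability column means: during final training a transform with probability p is applied to an image only some of the time, whereas during the TTA search it was always applied. The sketch below uses integers as stand-ins for images; `apply_randomly` is a hypothetical helper written for illustration, not the mechanism fast.ai uses internally.

```python
import random
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")

def apply_randomly(x: T, tfms: Sequence[Tuple[Callable[[T], T], float]], rng: random.Random) -> T:
    """Apply each (transform, p) pair in order; each fires independently with probability p."""
    for tfm, p in tfms:
        # rng.random() lies in [0, 1), so p = 1.0 always fires and p = 0.0 never does.
        if rng.random() < p:
            x = tfm(x)
    return x

rng = random.Random(0)
# With p = 1.0, as during the TTA search, the transform is applied every time.
print(apply_randomly(3, [(lambda v: v + 1, 1.0)], rng))  # 4
```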
Findings
We found that TTA with a given transformation provides information about the performance improvement that training with that augmentation will bring, and helps to quickly identify image transformations that help the network generalize better.
Table 3 shows the performance improvement achieved by this method compared to the baseline case. The augmentations picked for each dataset are detailed in Table 4.
Table 3: Top-1 error rates for the found augmentation sets.
Table 4: List of selected augmentations and their parameters for each dataset. Values in brackets represent a uniform distribution in the RandTransform class.
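For readers unfamiliar with the bracket notation, the sketch below shows what it means for a parameter to be drawn from a uniform distribution each time the transform is sampled, instead of being fixed. It is a simplified stand-in written for illustration; `resolve_param` is a hypothetical name and is not how RandTransform is implemented internally.

```python
import random
from typing import Sequence, Union

ParamSpec = Union[float, Sequence[float]]

def resolve_param(spec: ParamSpec, rng: random.Random) -> float:
    """A bracketed [lo, hi] spec is drawn uniformly from that range; a plain number is used as-is."""
    if isinstance(spec, (int, float)):
        return float(spec)
    lo, hi = spec
    return rng.uniform(lo, hi)

rng = random.Random(0)
print(resolve_param(1.1, rng))  # 1.1
print(-10.0 <= resolve_param([-10.0, 10.0], rng) <= 10.0)  # True
```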
It is worth noting that the augmentations picked by our method seem qualitatively reasonable. For example, for the Planet dataset it chooses dihedral flips (which can include upside-down flips), while for Kuzushiji-MNIST it chooses neither left-right nor dihedral flips, since any of these flips would be damaging for character recognition.
With more time, it may be worth investigating how consistent the differences between the error rates are: the differences between the TTA-based selection of augmentations and get_transforms() might be smaller than the run-to-run variation for the same set of augmentations.
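One way to run that check is to train each augmentation set several times with different seeds and compare the mean difference in error rate with its standard error. The sketch below is a minimal illustration under that assumption; `diff_and_stderr` is a hypothetical helper and the error rates are made up. As a rough heuristic, a difference well below about two standard errors is hard to distinguish from run-to-run noise.

```python
import math
from statistics import mean, stdev
from typing import Sequence, Tuple

def diff_and_stderr(errs_a: Sequence[float], errs_b: Sequence[float]) -> Tuple[float, float]:
    """Mean error-rate difference (A - B) across repeated runs and its standard error.

    Each sequence holds the final error rates of independent runs (e.g. different seeds)
    for one augmentation set. Uses the unpooled (Welch-style) standard error.
    """
    if len(errs_a) < 2 or len(errs_b) < 2:
        raise ValueError("need at least two runs per augmentation set")
    diff = mean(errs_a) - mean(errs_b)
    se = math.sqrt(stdev(errs_a) ** 2 / len(errs_a) + stdev(errs_b) ** 2 / len(errs_b))
    return diff, se

# Made-up error rates from 4 runs each, for illustration only.
tta_based = [0.081, 0.079, 0.085, 0.083]
default = [0.084, 0.080, 0.086, 0.082]
d, se = diff_and_stderr(tta_based, default)
print(f"difference {d:+.4f} +/- {se:.4f}")
```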
Acknowledgements
We would like to thank Jeremy Howard for his insightful guidance and David Tedaldi for his mentorship during this project.
References
[1] Cubuk, Ekin D., et al. AutoAugment: Learning Augmentation Policies from Data. arXiv preprint arXiv:1805.09501 (2018).
[2] Geng, Mingyang, et al. Learning Data Augmentation Policies Using Augmented Random Search. arXiv preprint arXiv:1811.04768 (2018).
[3] Perez, Luis, et al. The Effectiveness of Data Augmentation in Image Classification using Deep Learning. arXiv preprint arXiv:1712.04621 (2017).
[4] Krizhevsky, Alex, et al. ImageNet Classification with Deep Convolutional Neural Networks. Advances in Neural Information Processing Systems 25 (2012). doi:10.1145/3065386.
[5] Ratner, Alexander J., et al. Learning to Compose Domain-Specific Transformations for Data Augmentation. Advances in Neural Information Processing Systems 30 (2017).