AI Tool Building
Patrick Rodriguez | Posted on Wed 01 February 2017 in programming
Accelerating Deep Learning with Multiprocess Image Augmentation in Keras
from IPython.display import display, Image; display(Image('./results.png'))
Introduction
TLDR: I added multiprocessing support to Keras' ImageDataGenerator and benchmarked it on a 6-core i7-6850K with a 12GB TITAN X Pascal. Training with image augmentation ran 3.5x faster on in-memory datasets and 3.9x faster on datasets streamed from disk.
When exploring Deep Learning models, fast training matters for more than just the final training run. Faster training means more network models can be tried and more hyperparameter settings can be explored in the same amount of time. The more we can experiment, the better our results can become.
In my experience with training a moderately sized network on my home desktop, I found one bottleneck to be creating additional images to augment my dataset. Keras provides an ImageDataGenerator class that can take images, in memory or on disk, and create many different variations based on a set of parameters: rotations, flips, zooms, altering colors, etc. For reference, here is a great tutorial on improving network accuracy with image augmentation.
While training my initial models, I was waiting upwards of an entire day to see enough results to decide what to change, and I noticed that I was nowhere near taking full advantage of my CPU or GPU. As a result, I decided to add Python multiprocessing support to a fork of ImageDataGenerator. This drastically cut my training time, and I was finally able to steer my experiments in the right direction!
For reference, I am using:
- Intel Core i7-6850K
- NVIDIA TITAN X Pascal 12GB
- 96GB RAM
- 64-bit Ubuntu 16.04
- Python 2.7.13 :: Continuum Analytics, Inc.
- Keras 1.2.1
- Tensorflow 0.12.1
You can use the multiprocessing-enabled ImageDataGenerator that is included with this repo as a drop-in replacement for the version that currently ships with Keras. If it makes sense, the code may get incorporated into the main branch at some point.
import numpy as np
import pandas as pd
import keras as K
import matplotlib.pyplot as plt
import multiprocessing
import time
import collections
import sys
import signal
%matplotlib inline
# The original class can be imported like this:
# from keras.preprocessing.image import ImageDataGenerator
# We access the modified version through T.ImageDataGenerator
import tools.image_gen_2 as T
# Useful for checking the output of the generators after code change
try:
    from importlib import reload  # Python 3
except ImportError:
    pass  # Python 2: reload() is a builtin
reload(T)
These are helper functions used throughout the notebook.
def preprocess_img(img):
    # Scale pixel values from [0, 255] to [-1, 1]
    img = img.astype(np.float32) / 255.0
    img -= 0.5
    return img * 2
def plot_images(img_gen, title):
    # Show the 32 images of the first batch in a 6x6 grid
    fig, ax = plt.subplots(6, 6, figsize=(10, 10))
    plt.suptitle(title, size=32)
    plt.setp(ax, xticks=[], yticks=[])
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    for (imgs, labels) in img_gen:
        for i in range(6):
            for j in range(6):
                if i*6 + j < 32:
                    ax[i][j].imshow(imgs[i*6 + j])
        break  # only plot the first batch; the generator loops forever
Benchmark: CIFAR10 - In Memory Performance, Image Generation Only
CIFAR10 is a toy dataset that includes 50,000 training images and 10,000 test images of shape 32x32x3.
It includes the following 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
from keras.datasets.cifar10 import load_data
from keras.utils.np_utils import to_categorical
(X_train, y_train), (X_test, y_test) = load_data()
y_train_cat = to_categorical(y_train)
y_test_cat = to_categorical(y_test)
Here is an example of how to set up a multiprocessing.Pool and pass it as an argument to the ImageDataGenerator constructor. This is the only change to the class's public interface. If you leave out the pool parameter or set it to None, the generator will operate in its original single-process mode.
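The fork's internals are not reproduced in this post, but the idea behind the pool parameter can be sketched with the standard library alone. This is a hypothetical illustration, not the code in tools.image_gen_2: augment_batch and random_flip are made-up names, images are nested lists rather than NumPy arrays, and each image gets its own seed so the result does not depend on which worker happens to process it. Any object with a map() method can act as the pool, and passing None falls back to the builtin map, mirroring the single-process mode.

```python
import random


def random_flip(args):
    """Hypothetical per-image augmentation: horizontal flip with probability 0.5.

    Takes (image, seed) so the result does not depend on which worker runs it.
    Must be a top-level function so multiprocessing can pickle it.
    """
    image, seed = args
    rng = random.Random(seed)
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return [list(row) for row in image]


def augment_batch(images, seed=0, pool=None):
    """Augment a batch, in parallel when a pool (anything with .map) is given."""
    jobs = [(img, seed + i) for i, img in enumerate(images)]
    if pool is None:
        # Original single-process behaviour
        return list(map(random_flip, jobs))
    # Pool.map splits the jobs across the workers and preserves their order
    return pool.map(random_flip, jobs)


if __name__ == "__main__":
    import multiprocessing

    batch = [[[1, 2, 3], [4, 5, 6]] for _ in range(8)]
    serial = augment_batch(batch, seed=42)
    pool = multiprocessing.Pool(processes=4)
    parallel = augment_batch(batch, seed=42, pool=pool)
    pool.close()
    pool.join()
    print(serial == parallel)  # True: same seeds give the same augmentations
```

Seeding per image rather than sharing one random state is what keeps the parallel output identical to the serial one; with a single shared generator, the results would depend on how the work was split across processes.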
# Terminate a pool left over from a previous run of this cell, if any
try:
    pool.terminate()
except (NameError, AttributeError):
    pass
n_process = 4
pool = multiprocessing.Pool(processes=n_process)
start = time.time()
gen = T.ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=45,
    width_shift_range=.1,
    height_shift_range=.1,
    shear_range=0.,
    zoom_range=0,
    channel_shift_range=0,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=True,
    vertical_flip=False,
    rescale=1/255.,
    #preprocessing_function=preprocess_img, # disable for nicer visualization
    pool=pool # <-------------- Only change needed!
)
gen.fit(X_train)
X_train_aug = gen.flow(X_train, y_train_cat, seed=0)
print('{} processes, duration: {}'.format(n_process, time.time() - start))
plot_images(X_train_aug, 'Augmented Images generated with {} processes'.format(n_process))
pool.terminate()
Now that we have verified that the images are being generated correctly with multiple processes, we want to benchmark how the number of processes affects performance. Ideally, the speedup would scale linearly with the number of processes added. However, as Amdahl's Law explains, the part of the work that stays serial limits the achievable speedup, and the overhead of coordinating multiple processes reduces the returns further.
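To make the diminishing returns concrete, here is Amdahl's Law as a small Python function: if a fraction p of the work can be parallelized across n processes, the overall speedup is 1 / ((1 - p) + p / n). The value p = 0.9 below is purely illustrative, not a measurement from these benchmarks, and the formula ignores process-management overhead, so real speedups come in lower still.

```python
def amdahl_speedup(parallel_fraction, n_processes):
    """Ideal speedup predicted by Amdahl's Law.

    parallel_fraction: share of the runtime that can run in parallel (0..1)
    n_processes: number of workers sharing that parallel share
    """
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be between 0 and 1")
    if n_processes < 1:
        raise ValueError("n_processes must be at least 1")
    serial = 1.0 - parallel_fraction
    return 1.0 / (serial + parallel_fraction / n_processes)


# Even with 90% of the work parallelizable, 12 processes give well under 12x
for n in (1, 2, 4, 8, 12):
    print(n, round(amdahl_speedup(0.9, n), 2))
```

With p = 0.9 the printed speedups are roughly 1.0, 1.82, 3.08, 4.71 and 5.71, and no number of processes can push the speedup past 1 / (1 - p) = 10x.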
The following benchmark will first test image augmentation without multiprocessing, then do a test for an increasing number of processes, up to a max of the number of logical CPUs your system has. It does multiple rounds of these tests so that we may average the results.
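The averaging and speedup computation that the pandas cells below perform (mean duration per process count, then the baseline time divided by each mean) can be written out with the standard library. summarize_speedups is a hypothetical helper, sketched under the assumption that its input is shaped like durs: a dict mapping a process count (0 meaning no multiprocessing) to a list of durations in seconds.

```python
def summarize_speedups(durations):
    """Average repeated timings and express them as speedups over the baseline.

    durations: dict mapping process count -> list of durations in seconds,
    where key 0 is the single-process (pool=None) baseline.
    Returns (mean_by_count, speedup_by_count, best_count).
    """
    mean_by_count = {
        count: sum(times) / float(len(times))
        for count, times in durations.items()
    }
    baseline = mean_by_count[0]
    # speedup = baseline time / time with n processes (> 1 means faster)
    speedup_by_count = {
        count: baseline / avg for count, avg in mean_by_count.items()
    }
    best_count = max(speedup_by_count, key=speedup_by_count.get)
    return mean_by_count, speedup_by_count, best_count


# Made-up timings, only to show the call
timings = {0: [20.0, 22.0], 2: [12.0, 10.0], 4: [8.0, 9.0]}
means, speedups, best = summarize_speedups(timings)
print('Best speedup: {0:.2f}x with {1} processes.'.format(speedups[best], best))
# Best speedup: 2.47x with 4 processes.
```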
durs = collections.defaultdict(list)
num_cores = 2  # fallback if the CPU count cannot be determined
try:
    num_cores = multiprocessing.cpu_count()
except NotImplementedError:
    pass
for j in range(10):
    print('Round {}'.format(j))
    # 0 processes = original single-process mode (pool=None)
    for num_p in range(0, num_cores + 1):
        pool = None
        if num_p > 0:
            pool = multiprocessing.Pool(processes=num_p)
        start = time.time()
        gen = T.ImageDataGenerator(
            featurewise_center=False,
            samplewise_center=False,
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,
            zca_whitening=False,
            rotation_range=45,
            width_shift_range=.1,
            height_shift_range=.1,
            shear_range=0.,
            zoom_range=0,
            channel_shift_range=0,
            fill_mode='nearest',
            cval=0.,
            horizontal_flip=True,
            vertical_flip=False,
            rescale=None,
            preprocessing_function=preprocess_img,
            dim_ordering='default',
            pool=pool
        )
        gen.fit(X_train)
        X_train_aug = gen.flow(X_train, y_train_cat, seed=0)
        # Generate (and time) about 1000 augmented batches
        for i, (imgs, labels) in enumerate(X_train_aug):
            if i == 1000:
                break
        dur = time.time() - start
        #print(num_p, dur)
        sys.stdout.write('{}: {} ... '.format(num_p, dur))
        sys.stdout.flush()
        durs[num_p].append(dur)
        if pool:
            pool.terminate()
df = pd.DataFrame(durs)
df
df_mean = pd.DataFrame(df.mean(axis=0))
plt.figure(figsize=(10,5))
plt.plot(df_mean, marker='o')
plt.xlabel('# Processes')
plt.ylabel('Seconds')
plt.title('Image Augmentation time vs. # Processes')
speedups = 1 / (df_mean / df_mean[0][0])
plt.figure(figsize=(10,5))
plt.plot(speedups, marker='o')
plt.xlabel('# Processes')
plt.ylabel('Speedup')
plt.hlines(1, -1, df_mean.shape[0], colors='red', linestyles='dashed')
plt.title('Image Augmentation speedup vs. # Processes')
best_ix = np.argmax(speedups.values)
print('Best speedup: {0:.2f}x with {1} processes.'.format(speedups.values[best_ix][0], best_ix))
As we can see, we are able to cut image generation time in half. But does the speedup hold up when we are also sending the images to the GPU for network training?
Benchmark: CIFAR10 - In Memory Performance, Image Generation with GPU Training
import tools.sysmonitor as SM
reload(SM)
Let us take a model from one of the Keras examples:
from keras.models import Sequential
from keras.layers import Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
model = Sequential()
model.add(Conv2D(32, 3, 3, border_mode='same',
                 input_shape=(32, 32, 3)))
model.add(Activation('relu'))
model.add(Conv2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, 3, 3, border_mode='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))
model.summary()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
When running lengthier training sessions, we may want to interrupt training to try a different approach: tweak hyperparameters, choose a different optimizer, adjust the network architecture, etc. To handle this gracefully with multiprocessing, we need to tell the child processes to ignore the interrupt signal (SIGINT). The parent process then catches the KeyboardInterrupt exception, which allows us to continue working interactively in the notebook. Without this infrastructure, the child processes will remain in limbo, as detailed here.
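Here is a minimal, self-contained sketch of that pattern outside of Keras, using only the standard library. init_worker matches the helper defined in the next cell; interruptible_map and square are hypothetical names for illustration. Waiting with map_async(...).get(timeout) instead of a plain Pool.map() is a commonly cited workaround for Python 2, where a blocking wait without a timeout does not react to Ctrl-C.

```python
import multiprocessing
import signal


def init_worker():
    # Runs once in each worker: ignore SIGINT so that Ctrl-C (or the notebook's
    # "interrupt kernel") only raises KeyboardInterrupt in the parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def square(x):
    # Stand-in for real work such as augmenting one image; top-level so it pickles.
    return x * x


def interruptible_map(func, items, processes=2, timeout=3600):
    """Pool.map that the parent can interrupt without leaving workers in limbo."""
    pool = multiprocessing.Pool(processes=processes, initializer=init_worker)
    try:
        # get() with a timeout keeps the wait interruptible on Python 2
        return pool.map_async(func, items).get(timeout)
    except BaseException:
        # KeyboardInterrupt or an error from a worker: stop all workers now
        pool.terminate()
        raise
    finally:
        pool.close()  # no effect if the pool was already terminated
        pool.join()   # wait for the worker processes to exit


if __name__ == "__main__":
    try:
        print(interruptible_map(square, range(10)))
    except KeyboardInterrupt:
        print("Interrupted; worker processes were terminated")
```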
pool = None
def init_worker():
    # Child processes ignore SIGINT; only the parent sees KeyboardInterrupt
    signal.signal(signal.SIGINT, signal.SIG_IGN)
def setup_generator(processes=None, batch_size=32):
    global pool
    # Shut down any pool left over from a previous run
    if pool is not None:
        pool.terminate()
    if processes:
        pool = multiprocessing.Pool(processes=processes, initializer=init_worker)
    else:
        pool = None
    gen = T.ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        rotation_range=45,
        width_shift_range=.1,
        height_shift_range=.1,
        shear_range=0.,
        zoom_range=[.8, 1],
        channel_shift_range=20,
        fill_mode='nearest',
        cval=0.,
        horizontal_flip=True,
        vertical_flip=False,
        rescale=None,
        preprocessing_function=preprocess_img,
        pool=pool
    )
    # Validation data is only preprocessed, not augmented
    test_gen = T.ImageDataGenerator(
        preprocessing_function=preprocess_img,
        pool=pool
    )
    gen.fit(X_train)
    test_gen.fit(X_train)
    X_train_aug = gen.flow(X_train, y_train_cat, seed=0, batch_size=batch_size)
    X_test_aug = test_gen.flow(X_test, y_test_cat, seed=0, batch_size=batch_size)
    return X_train_aug, X_test_aug
def run_benchmark(processes=None, batch_size=32, vert=True, plot=True):
    X_train_aug, X_test_aug = setup_generator(processes=processes, batch_size=batch_size)
    sys_mon = SM.SysMonitor()
    sys_mon.start()
    try:
        # Integer division so the step counts are ints on Python 2 and 3
        model.fit_generator(X_train_aug, len(X_train) // batch_size, epochs=5,
                            validation_data=X_test_aug,
                            validation_steps=len(X_test) // batch_size)
    except KeyboardInterrupt:
        sys_mon.stop()
        print('\n\nTraining Interrupted\n')
        return None
    sys_mon.stop()
    title = None
    if not processes:
        title = '{0:.2f} seconds of computation, no multiprocessing, batch size = {1}'.format(sys_mon.duration, batch_size)
    else:
        title = '{0:.2f} seconds of computation, using {1} processes, batch size = {2}'.format(sys_mon.duration, processes, batch_size)
    if plot:
        sys_mon.plot(title, vert)
    if not processes:
        processes = 0
    return {
        'processes': processes,
        'batch_size': batch_size,
        'duration': sys_mon.duration,
        'title': title
    }
run_benchmark(processes=None, batch_size=32)
run_benchmark(processes=7, batch_size=32)
Now let's try a variety of different test scenarios:
runs = []
runs.append(run_benchmark(processes=None, batch_size=32))
runs.append(run_benchmark(processes=7, batch_size=32))
runs[0]['duration'] / runs[1]['duration']
As we can see, we can get a 1.8x speedup by using 7 processes. The GPU and CPU utilization is markedly higher and more consistent.
Let's see if batch size affects the outcome:
runs.append(run_benchmark(processes=None, batch_size=256))
runs.append(run_benchmark(processes=7, batch_size=256))
runs[2]['duration'] / runs[3]['duration']
With a batch size of 256, we get an even larger speedup of 3.3x.
runs.append(run_benchmark(processes=None, batch_size=1024))
runs.append(run_benchmark(processes=7, batch_size=1024))
runs[4]['duration'] / runs[5]['duration']
With a batch size of 1024, we get a speedup of 3.48x. We also notice an interesting phenomenon: without multiprocessing, the GPU intermittently drops to 0% utilization, whereas with 7 processes we see consistent GPU utilization above 60%, with a long initial period above 80%. Notice also that with this batch size, we reach lower losses much more quickly than with smaller batch sizes. This pattern will not necessarily continue with additional epochs, but it may be promising in some cases.
runs.append(run_benchmark(processes=None, batch_size=4096))
runs.append(run_benchmark(processes=7, batch_size=4096))
runs[6]['duration'] / runs[7]['duration']
A batch size of 4096 is not necessarily a good choice for training, but for measuring system performance it makes the difference clear: GPU usage is inconsistent in the single-process case, while with 7 processes we get between 80% and 100% GPU utilization.
Let's do a final experiment with this dataset to see how Image Augmentation + GPU Training time scales with process count:
processes_counts = [None]  # None = no multiprocessing
processes_counts.extend(range(1, 13))
results = []
for pc in processes_counts:
    print('process count {}'.format(pc))
    results.append(run_benchmark(processes=pc, batch_size=4096, plot=False))
durs_4096 = pd.DataFrame([x['duration'] for x in results])
plt.figure(figsize=(10,5))
plt.plot(durs_4096, marker='o')
plt.xlabel('# Processes')
plt.ylabel('Seconds')
plt.title('Image Augmentation + GPU Training time vs. # Processes')
speedups_4096 = 1 / (durs_4096 / durs_4096.iloc[0])
plt.figure(figsize=(10,5))
plt.plot(speedups_4096, marker='o')
plt.xlabel('# Processes')
plt.ylabel('Speedup')
plt.hlines(1, -1, speedups_4096.shape[0], colors='red', linestyles='dashed')
plt.title('Image Augmentation + GPU Training speedup vs. # Processes')
best_ix = np.argmax(speedups_4096.values)
print('Best speedup: {0:.2f}x with {1} processes.'.format(speedups_4096.values[best_ix][0], best_ix))
Benchmark: Dogs vs. Cats - On Disk Performance, Image Generation with GPU Training
Using the images in the dogs vs. cats dataset provided by Kaggle, we can test the performance of image augmentation on images loaded from disk on the fly.
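One detail worth knowing before setting up the folders: when no classes argument is passed, flow_from_directory treats each subdirectory as a class and numbers the classes in sorted order, so with class_mode='binary', cat becomes label 0 and dog label 1. The sketch below mimics that mapping with the standard library; infer_class_indices is a hypothetical helper written for illustration, not part of Keras.

```python
import os


def infer_class_indices(directory):
    """Map each subdirectory name to an integer label, in sorted order.

    Meant to mimic how flow_from_directory infers classes when none are given;
    with data/train/cat and data/train/dog this yields {'cat': 0, 'dog': 1}.
    """
    classes = [
        name for name in sorted(os.listdir(directory))
        if os.path.isdir(os.path.join(directory, name))
    ]
    return dict(zip(classes, range(len(classes))))
```

Since the model ends in a single sigmoid unit, its output can then be read as the predicted probability of "dog".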
To follow along, unzip the downloaded training zip file, then create the data/train/cat, data/train/dog, data/validation/cat, and data/validation/dog folders.
Then move the images whose indices start with 8 into the appropriate validation folders:
mv cat.8* data/validation/cat/
mv dog.8* data/validation/dog/
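If you prefer to script the split, here is a sketch using only the standard library. It assumes the Kaggle files are named like cat.123.jpg and dog.123.jpg (the naming the cat.8* and dog.8* globs rely on) and, going slightly beyond the steps above, moves every remaining image into the matching data/train folder, since flow_from_directory expects one subfolder per class there. split_dogs_vs_cats is a hypothetical helper name. With 12,500 images per class numbered 0 to 12,499, the indices starting with 8 are 8, 80-89, 800-899 and 8000-8999: 1,111 per class, so 2,222 validation images and 22,778 training images, matching the counts used in run_cat_dog_benchmark.

```python
import os
import shutil


def split_dogs_vs_cats(src_dir, dst_root="data"):
    """Move Kaggle dogs-vs-cats images into the train/validation layout.

    Images whose numeric index starts with "8" go to validation, the rest to
    train. Returns a dict counting how many files landed in each folder.
    """
    counts = {}
    for name in sorted(os.listdir(src_dir)):
        parts = name.split(".")  # e.g. ["cat", "8123", "jpg"]
        if len(parts) < 2 or parts[0] not in ("cat", "dog"):
            continue  # skip anything that is not a cat/dog image
        label, index = parts[0], parts[1]
        split = "validation" if index.startswith("8") else "train"
        dst_dir = os.path.join(dst_root, split, label)
        if not os.path.isdir(dst_dir):
            os.makedirs(dst_dir)  # Python 2's makedirs has no exist_ok argument
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, name))
        key = split + "/" + label
        counts[key] = counts.get(key, 0) + 1
    return counts
```

For example, assuming train.zip was extracted into ./train, calling split_dogs_vs_cats('train') would populate data/train/cat, data/train/dog, data/validation/cat and data/validation/dog, and return the number of files moved into each.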
import os
# Show the first 25 cat images
paths = sorted(os.listdir('./data/train/cat'))
fig, ax = plt.subplots(5, 5, figsize=(15, 15))
for i in range(5):
    for j in range(5):
        ix = i*5 + j
        img = plt.imread('./data/train/cat/' + paths[ix])
        ax[i][j].imshow(img)
# Show the first 25 dog images
paths = sorted(os.listdir('./data/train/dog'))
fig, ax = plt.subplots(5, 5, figsize=(15, 15))
for i in range(5):
    for j in range(5):
        ix = i*5 + j
        img = plt.imread('./data/train/dog/' + paths[ix])
        ax[i][j].imshow(img)
gen = T.ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=45,
    width_shift_range=.1,
    height_shift_range=.1,
    shear_range=0.,
    zoom_range=[.8, 1],
    channel_shift_range=0,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=True,
    vertical_flip=False,
    rescale=1/255.,
    # preprocessing_function=preprocess_img,
    #dim_ordering='default',
    # pool=None
)
test_gen = T.ImageDataGenerator(
    preprocessing_function=preprocess_img,
    # pool=None
)
train_generator = gen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
# Validation images are only preprocessed, not augmented
test_generator = test_gen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
# Plot the first batch of augmented training images
fig, ax = plt.subplots(6, 6, figsize=(15, 15))
for (imgs, labels) in train_generator:
    for i in range(6):
        for j in range(6):
            if i*6 + j < 32:
                ax[i][j].imshow(imgs[i*6 + j])
    break
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
model = Sequential()
model.add(Convolution2D(32, 3, 3, input_shape=(299, 299, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten()) # this converts our 3D feature maps to 1D feature vectors
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
def setup_cat_dog_generator(processes=None, batch_size=32):
    global pool
    # Shut down any pool left over from a previous run
    if pool is not None:
        pool.terminate()
    if processes:
        pool = multiprocessing.Pool(processes=processes, initializer=init_worker)
    else:
        pool = None
    gen = T.ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        rotation_range=45,
        width_shift_range=.1,
        height_shift_range=.1,
        shear_range=0.,
        zoom_range=[.8, 1],
        channel_shift_range=20,
        fill_mode='nearest',
        cval=0.,
        horizontal_flip=True,
        vertical_flip=False,
        rescale=None,
        preprocessing_function=preprocess_img,
        pool=pool
    )
    test_gen = T.ImageDataGenerator(
        preprocessing_function=preprocess_img,
        pool=pool
    )
    # No fit() call needed: fit() only computes featurewise statistics / ZCA,
    # which are all disabled here (and X_train holds CIFAR10, not this dataset)
    X_train_aug = gen.flow_from_directory(
        'data/train',
        target_size=(299, 299),
        batch_size=batch_size,
        class_mode='binary')
    # Validation images are only preprocessed, not augmented
    X_test_aug = test_gen.flow_from_directory(
        'data/validation',
        target_size=(299, 299),
        batch_size=batch_size,
        class_mode='binary')
    return X_train_aug, X_test_aug
def run_cat_dog_benchmark(processes=None, batch_size=32, vert=True, plot=True):
    X_train_aug, X_test_aug = setup_cat_dog_generator(processes=processes, batch_size=batch_size)
    sys_mon = SM.SysMonitor()
    sys_mon.start()
    try:
        model.fit_generator(
            X_train_aug,
            22778 // batch_size,  # 22,778 training images
            epochs=2,
            validation_data=X_test_aug,
            validation_steps=2222 // batch_size)  # 2,222 validation images
    except KeyboardInterrupt:
        sys_mon.stop()
        print('\n\nTraining Interrupted\n')
        return None
    sys_mon.stop()
    title = None
    if not processes:
        title = '{0:.2f} seconds of computation, no multiprocessing, batch size = {1}'.format(sys_mon.duration, batch_size)
    else:
        title = '{0:.2f} seconds of computation, using {1} processes, batch size = {2}'.format(sys_mon.duration, processes, batch_size)
    if plot:
        sys_mon.plot(title, vert)
    if not processes:
        processes = 0
    return {
        'processes': processes,
        'batch_size': batch_size,
        'duration': sys_mon.duration,
        'title': title
    }
In the following benchmark runs, you can see how inconsistently the GPU is used without multiprocessing. Even with multiprocessing, the CPU struggles to feed the GPU enough data to keep its utilization stable. Still, the average utilization is much higher than before.
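Before running each benchmark, I run the following in the shell (as root, since writing to /proc/sys/vm/drop_caches requires root privileges):
sync; echo 3 > /proc/sys/vm/drop_caches
This flushes pending writes to disk and then drops the page cache along with cached dentries and inodes, so files cached in memory cannot skew the benchmark results.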
runs = []
runs.append(run_cat_dog_benchmark(processes=None, batch_size=64))
runs.append(run_cat_dog_benchmark(processes=7, batch_size=64))
runs.append(run_cat_dog_benchmark(processes=11, batch_size=64))
runs[0]['duration'] / runs[2]['duration']
As we can see, the performance gain is even bigger when flowing images from disk. Using 11 processes, training runs 3.94x faster than the single-process baseline. This will help a lot when working with larger-than-memory datasets.