Commit 2572d763 authored by Ross Girshick's avatar Ross Girshick

reduce prefetch process overhead

parent 1b12f840
......@@ -15,7 +15,7 @@ from fast_rcnn.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np
import yaml
from multiprocessing import Process, queues
from multiprocessing import Process, Queue
class RoIDataLayer(caffe.Layer):
"""Fast R-CNN data layer used for training."""
......@@ -34,28 +34,17 @@ class RoIDataLayer(caffe.Layer):
self._cur += cfg.TRAIN.IMS_PER_BATCH
return db_inds
@staticmethod
def _prefetch(minibatch_db, num_classes, output_queue):
"""Prefetch minibatch blobs (if enabled cfg.TRAIN.USE_PREFETCH)."""
blobs = get_minibatch(minibatch_db, num_classes)
output_queue.put(blobs)
def _get_next_minibatch(self):
"""Return the blobs to be used for the next minibatch.
If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a
separate process and made available through the self._prefetch_queue
queue.
separate process and made available through self._blob_queue.
"""
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
if cfg.TRAIN.USE_PREFETCH:
self._prefetch_process = Process(target=RoIDataLayer._prefetch,
args=(minibatch_db,
self._num_classes,
self._prefetch_queue))
self._prefetch_process.start()
return self._blob_queue.get()
else:
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
return get_minibatch(minibatch_db, self._num_classes)
def set_roidb(self, roidb):
......@@ -63,13 +52,21 @@ class RoIDataLayer(caffe.Layer):
self._roidb = roidb
self._shuffle_roidb_inds()
if cfg.TRAIN.USE_PREFETCH:
self._get_next_minibatch()
self._blob_queue = Queue(10)
self._prefetch_process = BlobFetcher(self._blob_queue,
self._roidb,
self._num_classes)
self._prefetch_process.start()
# Terminate the child process when the parent exists
def cleanup():
print 'Terminating BlobFetcher'
self._prefetch_process.terminate()
self._prefetch_process.join()
import atexit
atexit.register(cleanup)
def setup(self, bottom, top):
"""Setup the RoIDataLayer."""
if cfg.TRAIN.USE_PREFETCH:
self._prefetch_process = None
self._prefetch_queue = queues.SimpleQueue()
# parse the layer parameter string, which must be valid YAML
layer_params = yaml.load(self.param_str_)
......@@ -106,11 +103,7 @@ class RoIDataLayer(caffe.Layer):
def forward(self, bottom, top):
"""Get blobs and copy them into this layer's top blob vector."""
if cfg.TRAIN.USE_PREFETCH:
blobs = self._prefetch_queue.get()
self._get_next_minibatch()
else:
blobs = self._get_next_minibatch()
blobs = self._get_next_minibatch()
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
......@@ -126,3 +119,40 @@ class RoIDataLayer(caffe.Layer):
def reshape(self, bottom, top):
"""Reshaping happens during the call to forward."""
pass
class BlobFetcher(Process):
"""Experimental class for prefetching blobs in a separate process."""
def __init__(self, queue, roidb, num_classes):
super(BlobFetcher, self).__init__()
self._queue = queue
self._roidb = roidb
self._num_classes = num_classes
self._perm = None
self._cur = 0
self._shuffle_roidb_inds()
# fix the random seed for reproducibility
np.random.seed(cfg.RNG_SEED)
def _shuffle_roidb_inds(self):
"""Randomly permute the training roidb."""
# TODO(rbg): remove duplicated code
self._perm = np.random.permutation(np.arange(len(self._roidb)))
self._cur = 0
def _get_next_minibatch_inds(self):
"""Return the roidb indices for the next minibatch."""
# TODO(rbg): remove duplicated code
if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
self._shuffle_roidb_inds()
db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
self._cur += cfg.TRAIN.IMS_PER_BATCH
return db_inds
def run(self):
print 'BlobFetcher started'
while True:
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
blobs = get_minibatch(minibatch_db, self._num_classes)
self._queue.put(blobs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment