Commit 630ea249 authored by Ross Girshick's avatar Ross Girshick

command line options

parent 24b91ae0
......@@ -14,7 +14,8 @@ from utils.cython_nms import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import caffe, cPickle, os, cv2
import caffe, cPickle, os, sys, cv2
import argparse
CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
......@@ -23,6 +24,14 @@ CLASSES = ('__background__',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
NETS = {'vgg16': ('VGG16',
'vgg16_fast_rcnn_iter_40000.caffemodel'),
'vgg_cnn_m_1024': ('VGG_CNN_M_1024',
'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),
'caffenet': ('CaffeNet',
'caffenet_fast_rcnn_iter_40000.caffemodel')}
def vis_detections(im, class_name, dets, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
......@@ -90,21 +99,37 @@ def demo(net, image_name, classes):
print 'Close image window (ctrl-w) to continue'
vis_detections(im, cls, dets, thresh=CONF_THRESH)
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--cpu', dest='cpu_mode',
help='Use CPU mode (overrides --gpu)',
action='store_true')
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
choices=NETS.keys(), default='vgg16')
args = parser.parse_args()
return args
if __name__ == '__main__':
gpu_id = 0
prototxt = 'models/VGG16/test.prototxt'
caffemodel = ('data/fast_rcnn_models/'
'vgg16_fast_rcnn_iter_40000.caffemodel')
# prototxt = 'models/VGG_CNN_M_1024/test.prototxt'
# caffemodel = ('data/fast_rcnn_models/'
# 'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel')
args = parse_args()
prototxt = os.path.join('models', NETS[args.demo_net][0], 'test.prototxt')
caffemodel = os.path.join('data', 'fast_rcnn_models',
NETS[args.demo_net][1])
if not os.path.isfile(caffemodel):
raise IOError(('{:s} not found.\nDid you run ./data/script/'
'fetch_fast_rcnn_models.sh?').format(caffemodel))
caffe.set_mode_gpu()
caffe.set_device(gpu_id)
if args.cpu_mode:
caffe.set_mode_cpu()
else:
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id)
net = caffe.Net(prototxt, caffemodel, caffe.TEST)
print '\n\nLoaded network {:s}'.format(caffemodel)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment