Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Zahra Rajabi
py-faster-rcnn
Commits
630ea249
Commit
630ea249
authored
Apr 29, 2015
by
Ross Girshick
Browse files
command line options
parent
24b91ae0
Changes
1
Hide whitespace changes
Inline
Side-by-side
tools/demo.py
View file @
630ea249
...
...
@@ -14,7 +14,8 @@ from utils.cython_nms import nms
from
utils.timer
import
Timer
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
caffe
,
cPickle
,
os
,
cv2
import
caffe
,
cPickle
,
os
,
sys
,
cv2
import
argparse
CLASSES
=
(
'__background__'
,
'aeroplane'
,
'bicycle'
,
'bird'
,
'boat'
,
...
...
@@ -23,6 +24,14 @@ CLASSES = ('__background__',
'motorbike'
,
'person'
,
'pottedplant'
,
'sheep'
,
'sofa'
,
'train'
,
'tvmonitor'
)
NETS
=
{
'vgg16'
:
(
'VGG16'
,
'vgg16_fast_rcnn_iter_40000.caffemodel'
),
'vgg_cnn_m_1024'
:
(
'VGG_CNN_M_1024'
,
'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'
),
'caffenet'
:
(
'CaffeNet'
,
'caffenet_fast_rcnn_iter_40000.caffemodel'
)}
def
vis_detections
(
im
,
class_name
,
dets
,
thresh
=
0.5
):
"""Draw detected bounding boxes."""
inds
=
np
.
where
(
dets
[:,
-
1
]
>=
thresh
)[
0
]
...
...
@@ -90,21 +99,37 @@ def demo(net, image_name, classes):
print
'Close image window (ctrl-w) to continue'
vis_detections
(
im
,
cls
,
dets
,
thresh
=
CONF_THRESH
)
def
parse_args
():
"""Parse input arguments."""
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a Fast R-CNN network'
)
parser
.
add_argument
(
'--gpu'
,
dest
=
'gpu_id'
,
help
=
'GPU device id to use [0]'
,
default
=
0
,
type
=
int
)
parser
.
add_argument
(
'--cpu'
,
dest
=
'cpu_mode'
,
help
=
'Use CPU mode (overrides --gpu)'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--net'
,
dest
=
'demo_net'
,
help
=
'Network to use [vgg16]'
,
choices
=
NETS
.
keys
(),
default
=
'vgg16'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
gpu_id
=
0
prototxt
=
'models/VGG16/test.prototxt'
caffemodel
=
(
'data/fast_rcnn_models/'
'vgg16_fast_rcnn_iter_40000.caffemodel'
)
# prototxt = 'models/VGG_CNN_M_1024/test.prototxt'
# caffemodel = ('data/fast_rcnn_models/'
# 'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel')
args
=
parse_args
()
prototxt
=
os
.
path
.
join
(
'models'
,
NETS
[
args
.
demo_net
][
0
],
'test.prototxt'
)
caffemodel
=
os
.
path
.
join
(
'data'
,
'fast_rcnn_models'
,
NETS
[
args
.
demo_net
][
1
])
if
not
os
.
path
.
isfile
(
caffemodel
):
raise
IOError
((
'{:s} not found.
\n
Did you run ./data/script/'
'fetch_fast_rcnn_models.sh?'
).
format
(
caffemodel
))
caffe
.
set_mode_gpu
()
caffe
.
set_device
(
gpu_id
)
if
args
.
cpu_mode
:
caffe
.
set_mode_cpu
()
else
:
caffe
.
set_mode_gpu
()
caffe
.
set_device
(
args
.
gpu_id
)
net
=
caffe
.
Net
(
prototxt
,
caffemodel
,
caffe
.
TEST
)
print
'
\n\n
Loaded network {:s}'
.
format
(
caffemodel
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment