Commit c69e33db authored by Ross Girshick's avatar Ross Girshick

cleanup; pylint config

parent a8a0f55d
[TYPECHECK]
ignored-modules = numpy, numpy.random, cv2
......@@ -9,7 +9,7 @@ LOG="experiments/logs/svm_caffenet.txt.`date +'%Y-%m-%d_%H-%M-%S'`"
exec 3>&1 4>&2 &> >(tee -a "$LOG")
echo Logging output to "$LOG"
time ./tools/extra/train_svms.py --gpu $1 \
time ./tools/train_svms.py --gpu $1 \
--def models/CaffeNet/test.prototxt \
--net output/default/voc_2007_trainval/caffenet_fast_rcnn_iter_40000.caffemodel \
--imdb voc_2007_trainval \
......
......@@ -9,7 +9,7 @@ LOG="experiments/logs/svm_vgg_cnn_m_1024.txt.`date +'%Y-%m-%d_%H-%M-%S'`"
exec 3>&1 4>&2 &> >(tee -a "$LOG")
echo Logging output to "$LOG"
time ./tools/extra/train_svms.py --gpu $1 \
time ./tools/train_svms.py --gpu $1 \
--def models/VGG_CNN_M_1024/test.prototxt \
--net output/default/voc_2007_trainval/vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel \
--imdb voc_2007_trainval \
......
......@@ -6,10 +6,11 @@
# --------------------------------------------------------
import numpy as np
import numpy.random as npr
import cv2
import matplotlib.pyplot as plt
from fast_rcnn.config import cfg
import utils.blob
from utils.blob import prep_im_for_blob, im_list_to_blob
def get_minibatch(roidb):
"""
......@@ -19,8 +20,8 @@ def get_minibatch(roidb):
# Infer number of classes from the number of columns in gt_overlaps
num_classes = roidb[0]['gt_overlaps'].shape[1]
# Sample random scales to use for each image in this batch
random_scale_inds = \
np.random.randint(0, high=len(cfg.TRAIN.SCALES), size=num_images)
random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),
size=num_images)
assert(cfg.TRAIN.BATCH_SIZE % num_images == 0), \
'num_images ({}) must divide BATCH_SIZE ({})'. \
format(num_images, cfg.TRAIN.BATCH_SIZE)
......@@ -40,7 +41,7 @@ def get_minibatch(roidb):
labels, overlaps, im_rois, bbox_targets, bbox_loss \
= _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image)
# Add to ROIs blob
# Add to RoIs blob
rois = _scale_im_rois(im_rois, im_scales[im_i])
batch_ind = im_i * np.ones((rois.shape[0], 1))
rois_blob_this_image = np.hstack((batch_ind, rois))
......@@ -67,42 +68,42 @@ def get_minibatch(roidb):
def _sample_rois(roidb, fg_rois_per_image, rois_per_image):
"""
Generate a random sample of ROIs comprising foreground and background
Generate a random sample of RoIs comprising foreground and background
examples.
"""
# label = class ROI has max overlap with
# label = class RoI has max overlap with
labels = roidb['max_classes']
overlaps = roidb['max_overlaps']
rois = roidb['boxes']
# Select foreground ROIs as those with >= FG_THRESH overlap
# Select foreground RoIs as those with >= FG_THRESH overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Guard against the case when an image has fewer than fg_rois_per_image
# foreground ROIs
# foreground RoIs
fg_rois_per_this_image = np.minimum(fg_rois_per_image, fg_inds.size)
# Sample foreground regions without replacement
if fg_inds.size > 0:
fg_inds = np.random.choice(fg_inds, size=fg_rois_per_this_image,
replace=False)
fg_inds = npr.choice(fg_inds, size=fg_rois_per_this_image,
replace=False)
# Select background ROIs as those within [BG_THRESH_LO, BG_THRESH_HI)
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# Compute number of background ROIs to take from this image (guarding
# Compute number of background RoIs to take from this image (guarding
# against there being fewer than desired)
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.size)
# Sample foreground regions without replacement
if bg_inds.size > 0:
bg_inds = np.random.choice(bg_inds, size=bg_rois_per_this_image,
replace=False)
bg_inds = npr.choice(bg_inds, size=bg_rois_per_this_image,
replace=False)
# The indices that we're selecting (both fg and bg)
keep_inds = np.append(fg_inds, bg_inds)
# Select sampled values from various arrays:
labels = labels[keep_inds]
# Clamp labels for the background ROIs to 0
# Clamp labels for the background RoIs to 0
labels[fg_rois_per_this_image:] = 0
overlaps = overlaps[keep_inds]
rois = rois[keep_inds]
......@@ -117,8 +118,7 @@ def _sample_rois(roidb, fg_rois_per_image, rois_per_image):
def _get_image_blob(roidb, scale_inds):
"""
Build an input blob from the images in the roidb at the specified
scales.
Builds an input blob from the images in the roidb at the specified scales.
"""
num_images = len(roidb)
processed_ims = []
......@@ -128,14 +128,13 @@ def _get_image_blob(roidb, scale_inds):
if roidb[i]['flipped']:
im = im[:, ::-1, :]
target_size = cfg.TRAIN.SCALES[scale_inds[i]]
im, im_scale = \
utils.blob.prep_im_for_blob(im, cfg.PIXEL_MEANS,
target_size, cfg.TRAIN.MAX_SIZE)
im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
cfg.TRAIN.MAX_SIZE)
im_scales.append(im_scale)
processed_ims.append(im)
# Create a blob to hold the input images
blob = utils.blob.im_list_to_blob(processed_ims)
blob = im_list_to_blob(processed_ims)
return blob, im_scales
......@@ -147,11 +146,13 @@ def _get_bbox_regression_labels(bbox_target_data, num_classes):
"""
Bounding-box regression targets are stored in a compact form in the roidb.
This function expands those targets into the 4-of-4*K representation used
by the network (i.e. only one class has non-zero targets).
The loss weights are similarly expanded.
by the network (i.e. only one class has non-zero targets). The loss weights
are similarly expanded.
Returns:
(N, K * 4) blob of regression targets
(N, K * 4) blob of loss weights
"""
# Return (N, K * 4, 1, 1) blob of regression targets
# Return (N, K * 4, 1, 1) blob of loss weights
clss = bbox_target_data[:, 0]
bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32)
bbox_loss_weights = np.zeros(bbox_targets.shape, dtype=np.float32)
......@@ -165,21 +166,21 @@ def _get_bbox_regression_labels(bbox_target_data, num_classes):
return bbox_targets, bbox_loss_weights
def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
num_images = im_blob.shape[0]
"""Visualize a mini-batch for debugging."""
for i in xrange(rois_blob.shape[0]):
rois = rois_blob[i, :]
im_ind = rois[0]
roi = rois[1:]
im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy()
im += cfg.PIXEL_MEANS
im = im[:, :, (2, 1, 0)]
im = im.astype(np.uint8)
cls = labels_blob[i]
plt.imshow(im)
print 'class: ', cls, ' overlap: ', overlaps[i]
plt.gca().add_patch(
plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
roi[3] - roi[1], fill=False,
edgecolor='r', linewidth=3)
)
plt.show()
rois = rois_blob[i, :]
im_ind = rois[0]
roi = rois[1:]
im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy()
im += cfg.PIXEL_MEANS
im = im[:, :, (2, 1, 0)]
im = im.astype(np.uint8)
cls = labels_blob[i]
plt.imshow(im)
print 'class: ', cls, ' overlap: ', overlaps[i]
plt.gca().add_patch(
plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
roi[3] - roi[1], fill=False,
edgecolor='r', linewidth=3)
)
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment