# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os

from caffe.proto import caffe_pb2
import google.protobuf as pb2
# text_format is a submodule of google.protobuf; import it explicitly so that
# pb2.text_format resolves without relying on another module having loaded it.
import google.protobuf.text_format

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        solver_prototxt  -- path to the text-format solver definition; it is
                            used to build the SGD solver and is also parsed
                            into self.solver_param
        roidb            -- training roidb; bounding-box regression targets
                            are added to it and it is passed to the first
                            layer of the network
        output_dir       -- directory that snapshots are written to
        pretrained_model -- optional path to weights that are copied into the
                            network before training starts
        """
        self.output_dir = output_dir

        print('Computing bounding-box regression targets...')
        self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
        print('done')

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print(('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Keep a parsed copy of the solver settings; snapshot() and
        # train_model() read snapshot_prefix and display from it.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        self.solver.net.layers[0].set_roidb(roidb)

    def snapshot(self):
        """Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.

        The file is written to self.output_dir and is named
        <snapshot_prefix>[_<SNAPSHOT_INFIX>]_iter_<iter>.caffemodel.
        The in-memory weights are put back to their normalized values
        afterwards, even if writing the snapshot fails.
        """
        net = self.solver.net
        unnormalize = cfg.TRAIN.BBOX_REG

        if unnormalize:
            # save original values
            bbox_pred = net.params['bbox_pred']
            orig_0 = bbox_pred[0].data.copy()
            orig_1 = bbox_pred[1].data.copy()

        try:
            if unnormalize:
                # scale and shift with bbox reg unnormalization; then save
                # snapshot
                bbox_pred[0].data[...] = \
                        (bbox_pred[0].data * self.bbox_stds[:, np.newaxis])
                bbox_pred[1].data[...] = \
                        (bbox_pred[1].data * self.bbox_stds + self.bbox_means)

            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)

            infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                     if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
            filename = (self.solver_param.snapshot_prefix + infix +
                        '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
            filename = os.path.join(self.output_dir, filename)

            net.save(str(filename))
            print('Wrote snapshot to: {:s}'.format(filename))
        finally:
            if unnormalize:
                # restore net to original state; done in a finally clause so
                # a failed save cannot leave unnormalized weights in the
                # network that is still being trained
                bbox_pred[0].data[...] = orig_0
                bbox_pred[1].data[...] = orig_1

    def train_model(self, max_iters):
        """Network training loop.

        Steps the solver one iteration at a time until it reaches max_iters,
        taking a snapshot every cfg.TRAIN.SNAPSHOT_ITERS iterations and a
        final one at the end if the last iteration was not already saved.
        """
        last_snapshot_iter = -1
        timer = Timer()
        # An unset `display` reads as 0, which would make the modulo below
        # divide by zero; a non-positive value simply disables the report.
        report_every = 10 * self.solver_param.display
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if report_every > 0 and self.solver.iter % report_every == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                self.snapshot()

        if last_snapshot_iter != self.solver.iter:
            self.snapshot()

103
def get_training_roidb(imdb):
    """Build and return the roidb (Region of Interest database) of `imdb`
    for training.

    When cfg.TRAIN.USE_FLIPPED is set, horizontally-flipped examples are
    appended to `imdb` first. The roidb is then prepared through
    rdl_roidb.prepare_roidb and `imdb.roidb` is returned.
    """
    use_flipped = cfg.TRAIN.USE_FLIPPED
    if use_flipped:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    training_roidb = imdb.roidb
    return training_roidb
Ross Girshick's avatar
Ross Girshick committed
115

116
117
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network.

    Builds a SolverWrapper from the given solver definition, roidb, output
    directory and optional pretrained weights, then runs its training loop
    until the solver reaches `max_iters` iterations.
    """
    wrapper = SolverWrapper(
        solver_prototxt,
        roidb,
        output_dir,
        pretrained_model=pretrained_model,
    )

    print('Solving...')
    wrapper.train_model(max_iters)
    print('done solving')