#!/usr/bin/env python

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

# _init_paths must stay first: the project/caffe imports below presumably
# depend on the search paths it configures (TODO confirm against _init_paths).
import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, get_output_dir
from datasets.factory import get_imdb
import caffe
import argparse
import pprint
import numpy as np
import sys

def parse_args():
    """
    Parse command-line arguments for Fast R-CNN training.

    If the script is invoked with no arguments at all, the usage message is
    printed and the process exits with status 1 instead of training with
    defaults.

    Returns:
        argparse.Namespace with attributes:
            gpu_id (int): GPU device id, default 0.
            solver (str or None): solver prototxt path, default None.
            max_iters (int): number of training iterations, default 40000.
            pretrained_model (str or None): initial weights, default None.
            cfg_file (str or None): optional config file, default None.
            imdb_name (str): dataset name, default 'voc_2007_trainval'.
            randomize (bool): True when --rand is given (no fixed seed).
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=40000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')

    # A bare invocation is almost certainly a mistake (e.g. --solver has no
    # default), so show usage rather than starting a run.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    # Script entry point: parse CLI args, configure caffe, load the training
    # roidb for the requested imdb, and hand off to train_net().
    args = parse_args()

    print('Called with args:')
    print(args)

    # Load the optional config file before anything reads cfg, so values
    # such as cfg.RNG_SEED below come from the file when one is given.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # set up caffe
    caffe.set_mode_gpu()
    # parse_args gives --gpu a default of 0, so this branch is normally taken.
    if args.gpu_id is not None:
        caffe.set_device(args.gpu_id)

    imdb = get_imdb(args.imdb_name)
    # Call-form print (single argument) behaves the same on Python 2 and 3,
    # matching the other print calls in this script.
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    print('Output will be saved to `{:s}`'.format(output_dir))

    train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)