Commit 54cd7fba authored by Ross Girshick's avatar Ross Girshick
Browse files

fix the rng seed for caffe (not just numpy); option to randomize training

parent 4e50be95
...@@ -225,6 +225,9 @@ Test output is written underneath `$FRCN_ROOT/output`. ...@@ -225,6 +225,9 @@ Test output is written underneath `$FRCN_ROOT/output`.
### Experiment scripts ### Experiment scripts
Scripts to reproduce the experiments in the paper (*up to stochastic variation*) are provided in `$FRCN_ROOT/experiments/scripts`. Log files for experiments are located in `experiments/logs`. Scripts to reproduce the experiments in the paper (*up to stochastic variation*) are provided in `$FRCN_ROOT/experiments/scripts`. Log files for experiments are located in `experiments/logs`.
**Note:** Until recently (commit a566e39), the RNG seed for Caffe was not fixed during training. Now it's fixed, unless `train_net.py` is called with the `--rand` flag.
Results generated before this commit will have some stochastic variation.
### Extra downloads ### Extra downloads
- [Experiment logs](http://www.cs.berkeley.edu/~rbg/fast-rcnn-data/fast_rcnn_experiments.tgz) - [Experiment logs](http://www.cs.berkeley.edu/~rbg/fast-rcnn-data/fast_rcnn_experiments.tgz)
......
Subproject commit 832bbcba5261a34c08916199e4ec91377ff7a4fb Subproject commit cf9235a98680ffa3786957ba82bec9460f7db486
...@@ -24,9 +24,11 @@ def parse_args(): ...@@ -24,9 +24,11 @@ def parse_args():
Parse input arguments Parse input arguments
""" """
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network') parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=0, type=int) default=0, type=int)
parser.add_argument('--solver', dest='solver', help='solver prototxt', parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str) default=None, type=str)
parser.add_argument('--iters', dest='max_iters', parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train', help='number of iterations to train',
...@@ -35,10 +37,14 @@ def parse_args(): ...@@ -35,10 +37,14 @@ def parse_args():
help='initialize with pretrained model weights', help='initialize with pretrained model weights',
default=None, type=str) default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file', parser.add_argument('--cfg', dest='cfg_file',
help='optional config file', default=None, type=str) help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name', parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on', help='dataset to train on',
default='voc_2007_trainval', type=str) default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
if len(sys.argv) == 1: if len(sys.argv) == 1:
parser.print_help() parser.print_help()
...@@ -59,8 +65,10 @@ if __name__ == '__main__': ...@@ -59,8 +65,10 @@ if __name__ == '__main__':
print('Using config:') print('Using config:')
pprint.pprint(cfg) pprint.pprint(cfg)
# fix the random seed for reproducibility if not args.randomize:
np.random.seed(cfg.RNG_SEED) # fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED)
# set up caffe # set up caffe
caffe.set_mode_gpu() caffe.set_mode_gpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment