
Commit

Merge pull request apache#180 from antinucleon/master
[example] simple_bind
tqchen committed Sep 29, 2015
2 parents b3c0988 + 10af9b9 commit cdeb822
Showing 7 changed files with 604 additions and 14 deletions.
2 changes: 1 addition & 1 deletion example/README.md
@@ -8,7 +8,7 @@ Notebooks
* [cifar-10 recipe](notebooks/cifar-recipe.ipynb) gives you a step-by-step demo of how to use MXNet
* [cifar-100](notebooks/cifar-100.ipynb) gives you a demo of how to train a 75.68% accuracy CIFAR-100 model
* [predict with pretrained model](notebooks/predict-with-pretrained-model.ipynb) gives you a demo of using a pretrained Inception-BN network

* [simple bind](notebooks/simple_bind.ipynb) gives you a demo of some details in the ```mx.model``` module, built on the lower-level bind API sketched below.
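
For orientation, the core mechanism that notebook walks through looks roughly like the following. This is a minimal sketch, not code from the notebook: it assumes `simple_bind` accepts the input shape as a keyword argument and that the executor exposes `forward()` and `outputs`; exact names may differ between MXNet versions.

```python
import mxnet as mx

# a tiny symbolic graph: one fully connected layer + softmax
data = mx.symbol.Variable('data')
fc = mx.symbol.FullyConnected(data=data, num_hidden=10, name='fc')
net = mx.symbol.Softmax(data=fc, name='softmax')

# simple_bind infers every argument shape from the input shape
# and allocates the NDArrays the executor needs
executor = net.simple_bind(ctx=mx.cpu(), data=(32, 784))

# after the caller fills the argument arrays (input batch, weights):
executor.forward()
print(executor.outputs[0].asnumpy())  # softmax output, shape (32, 10)
```

This is exactly the machinery ```mx.model``` hides behind its training loop.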

Contents
--------
27 changes: 18 additions & 9 deletions example/imagenet/README.md
@@ -2,12 +2,17 @@

## Prepare Dataset

TODO
We use RecordIO to pack images together. By packing images into RecordIO, we can reach 3000 images/second on a normal HDD. This includes the cost of cropping from (3 x 256 x 256) to (3 x 224 x 224), random flipping, and other augmentation.

Please read the documentation on [How to Create Dataset Using RecordIO](https://mxnet.readthedocs.org/en/latest/python/io.html#create-dataset-using-recordio).

Note: A common mistake is forgetting to shuffle the image list. This will cause training to fail, e.g. ```accuracy``` stays at 0.001 for several rounds.
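
Once the list is shuffled and the images are packed, a record iterator can stream and augment them on the fly. Below is a minimal sketch; the path and parameter values are placeholders rather than files shipped with this example (`data.py` in this directory defines the iterators the scripts actually use):

```python
import mxnet as mx

# stream (3 x 256 x 256) images from the packed RecordIO file,
# randomly cropping to (3 x 224 x 224) and mirroring for augmentation
train_iter = mx.io.ImageRecordIter(
    path_imgrec="data/train.rec",  # placeholder path to the packed dataset
    data_shape=(3, 224, 224),
    batch_size=128,
    rand_crop=True,
    rand_mirror=True)
```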

## Neural Networks

- [alexnet.py](alexnet.py) : AlexNet with 5 convolution layers followed by 3
  fully connected layers
- [inception.py](inception.py) : Inception + Batch Norm network

## Results

@@ -16,17 +21,21 @@ Machine: Dual Xeon E5-2680 2.8GHz, GTX 980, Ubuntu 14.0, GCC 4.8, MKL, CUDA

* AlexNet

| | 1 x GTX 980 | 2 x GTX 980 | 4 x GTX 980 |
| ---------------- | ----------- | ------------ | ------------ |
| ```alexnet.py``` | 527 img/sec | 1030 img/sec | 1413 img/sec |
| cxxnet | 256 img/sec | 492 img/sec | 914 img/sec |

For AlexNet, single model + single center-crop test top-5 accuracy will be around 81%.

|              | val accuracy | 1 x GTX 980 | 2 x GTX 980  | 4 x GTX 980  |
| ------------ | -----------: | ----------: | -----------: | -----------: |
| `alexnet.py` | ?            | 527 img/sec | 1030 img/sec | 1413 img/sec |
| cxxnet       | ?            | 256 img/sec | 492 img/sec  | 914 img/sec  |

* Inception-BN

|                | val accuracy | 1 x GTX 980           | 2 x GTX 980            | 4 x GTX 980             |
| -------------- | -----------: | --------------------: | ---------------------: | ----------------------: |
| `inception.py` | ?            | 97 img/sec (batch 32) | 178 img/sec (batch 64) | 357 img/sec (batch 128) |
| cxxnet         | ?            | 57 img/sec (batch 16) | 112 img/sec (batch 32) | 224 img/sec (batch 64)  |
| | 1 x GTX 980 | 2 x GTX 980 | 4 x GTX 980 |
| ------------------ | --------------------- | ---------------------- | ----------------------- |
| ```inception.py``` | 97 img/sec (batch 32) | 178 img/sec (batch 64) | 357 img/sec (batch 128) |
| cxxnet | 57 img/sec (batch 16) | 112 img/sec (batch 32) | 224 img/sec (batch 64) |

For the Inception-BN network, single model + single center-crop test top-5 accuracy will be around 90%.

Note: MXNet is much more memory efficient than cxxnet, so we are able to train with larger batches.
98 changes: 98 additions & 0 deletions example/imagenet/inception.py
@@ -0,0 +1,98 @@
# pylint: skip-file
import sys
import mxnet as mx
import logging
from data import ilsvrc12_iterator


logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
    bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))
    act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix))
    return act

def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat

def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool (this downsampling unit has no projection branch)
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat

def inception(nhidden, grad_scale):
    # data
    data = mx.symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, "avg", 32, '3a')
    in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
    in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = mx.symbol.Flatten(data=avg, name='flatten')
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=nhidden, name='fc1')
    softmax = mx.symbol.Softmax(data=fc1, name='softmax')
    return softmax

batch_size = 128
num_gpu = 4
gpus = [mx.gpu(i) for i in range(num_gpu)]
input_shape = (3, 224, 224)
softmax = inception(1000, 1.0)

train, val = ilsvrc12_iterator(batch_size=batch_size, input_shape=input_shape)

model_prefix = "model/Inception"
num_round = 40


# data-parallel training: FeedForward splits each batch across the GPUs in ctx
model = mx.model.FeedForward(ctx=gpus, symbol=softmax, num_round=num_round,
                             learning_rate=0.05, momentum=0.9, wd=0.00001)

model.fit(X=train, eval_data=val,
          eval_metric="acc",
          epoch_end_callback=mx.callback.Speedometer(batch_size))
11 changes: 8 additions & 3 deletions example/notebooks/cifar-recipe.ipynb
@@ -35,7 +35,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's make some helper function to let us build a simplified Inception Network. More details about how to composite symbol into component can be found at [component demo](composite_symbol.ipynb)"
"First, let's make some helper function to let us build a simplified Inception Network. More details about how to composite symbol into component can be found at [composite_symbol](composite_symbol.ipynb)"
]
},
{
@@ -48,7 +48,12 @@
"source": [
"# Basic Conv + BN + ReLU factory\n",
"def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type=\"relu\"):\n",
" conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)\n",
" # there is an optional parameter ```wrokshpace``` may influece convolution performance\n",
" # default, the workspace is set to 256(MB)\n",
" # you may set larger value, but convolution layer only requires its needed but not exactly\n",
" # MXNet will handle reuse of workspace without parallelism conflict\n",
" conv = mx.symbol.Convolution(data=data, workspace=256,\n",
" num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)\n",
" bn = mx.symbol.BatchNorm(data=conv)\n",
" act = mx.symbol.Activation(data = bn, act_type=act_type)\n",
" return act"
@@ -448,7 +453,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.2"
"version": "3.4.0"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion example/notebooks/predict-with-pretrained-model.ipynb
@@ -233,7 +233,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.2"
"version": "3.4.0"
}
},
"nbformat": 4,