Speeds up Softmax by up to 44% by changing "/ sum" to "* (1/sum)".

Benchmarked using third_party/tensorflow/core/kernels:nn_ops_test.
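
For reference, an invocation along these lines should reproduce the numbers
(flags assumed from TensorFlow's usual benchmark harness, where --benchmarks
takes a name pattern; not copied from this change):

  bazel run -c opt third_party/tensorflow/core/kernels:nn_ops_test -- --benchmarks=BM_ImageNetSoftmaxFwd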

Wall time improves 10-44%:

Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       713325    620705    +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3097766   2782433    +10.2%
BM_ImageNetSoftmaxFwd_32_1008_4      1254561    703238    +43.9%
BM_ImageNetSoftmaxFwd_128_1008_4     3225011   2543525    +21.1%

CPU time improves 4-17%:

Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       711375    618729    +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3087158   2779777    +10.0%
BM_ImageNetSoftmaxFwd_32_1008_4       959016    795579    +17.0%
BM_ImageNetSoftmaxFwd_128_1008_4     3774543   3591573     +4.8%
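
Why this helps: the old expression issues one floating-point division per
element, while the new one issues a single division per row (the inverse())
followed by broadcast multiplications, and divides have far higher latency and
worse vectorization than multiplies on most CPUs. In scalar form the rewrite
looks like this (a hypothetical standalone sketch, not the kernel code):

  // Normalize one row of already-exponentiated logits in place.
  // Hypothetical helper for illustration only.
  void NormalizeRow(float* row, int num_classes) {
    float sum = 0.f;
    for (int i = 0; i < num_classes; ++i) sum += row[i];
    // Before: row[i] /= sum;  => num_classes divisions per row.
    // After:  one division up front, then cheap multiplications.
    const float inv_sum = 1.f / sum;
    for (int i = 0; i < num_classes; ++i) row[i] *= inv_sum;
  }
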
Change: 123074430

Author:    Zongheng Yang
Date:      2016-05-23 21:59:21 -08:00
Committer: TensorFlower Gardener
Parent:    26f54a9fcd
Commit:    989166223c

@@ -63,31 +63,34 @@ struct SoftmaxEigenImpl {
     Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
     one_by_class.set(1, num_classes);
 #endif
-    //shifted_logits = logits - max(logits along classes);
-    auto shifted_logits = (logits - logits.maximum(along_class)
-                                        .eval()
-                                        .reshape(batch_by_one)
-                                        .broadcast(one_by_class));
+    // shifted_logits = logits - max(logits along classes);
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
     if (log) {
       // Calculate the log of the softmax
       // softmax = logits - max(logits along classes);
       softmax.device(d) = shifted_logits;
       // softmax = softmax - log(sum(exp(softmax along classes)));
       softmax.device(d) = (softmax -
-                           softmax.exp().sum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class)
-                               .log());
+                           softmax.exp()
+                               .sum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class)
+                               .log());
     } else {
       // NOTE(touts): If you modify this implementation please run
       // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
       //
       // softmax = exp(logits - max(logits along classes));
       softmax.device(d) = shifted_logits.exp();
-      // softmax = softmax / sum(softmax along classes);
-      softmax.device(d) = (softmax /
+      // softmax = softmax * (1 / sum(softmax along classes));
+      softmax.device(d) = (softmax *
                            softmax.sum(along_class)
+                               .inverse()
                                .eval()
                                .reshape(batch_by_one)
                                .broadcast(one_by_class));
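
The same rewrite can be tried outside TensorFlow with Eigen's unsupported
Tensor module. Below is a minimal standalone sketch (sizes and variable names
invented for illustration) that computes the row-wise softmax both ways; the
.inverse().eval() materializes the per-row reciprocal once so the broadcast
multiply reuses it:

  // Minimal sketch, assuming Eigen with the unsupported Tensor module.
  #include <unsupported/Eigen/CXX11/Tensor>
  #include <iostream>

  int main() {
    const int batch_size = 2, num_classes = 4;
    Eigen::Tensor<float, 2> logits(batch_size, num_classes);
    logits.setRandom();

    Eigen::array<int, 1> along_class{1};
    Eigen::array<int, 2> batch_by_one{batch_size, 1};
    Eigen::array<int, 2> one_by_class{1, num_classes};

    // Subtract the row max for numerical stability, then exponentiate.
    Eigen::Tensor<float, 2> softmax =
        (logits - logits.maximum(along_class)
                      .eval()
                      .reshape(batch_by_one)
                      .broadcast(one_by_class))
            .exp();

    // Old form: one division per element.
    Eigen::Tensor<float, 2> by_division =
        softmax / softmax.sum(along_class)
                      .eval()
                      .reshape(batch_by_one)
                      .broadcast(one_by_class);

    // New form: one division per row, then broadcast multiplies.
    Eigen::Tensor<float, 2> by_reciprocal =
        softmax * softmax.sum(along_class)
                      .inverse()
                      .eval()
                      .reshape(batch_by_one)
                      .broadcast(one_by_class);

    std::cout << by_division << "\n\n" << by_reciprocal << "\n";
    return 0;
  }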