Speeds up Softmax by up to 43% by replacing the broadcast division "/ sum" with a multiplication by the broadcast reciprocal, "* (1 / sum)".
Benchmarked using third_party/tensorflow/core/kernels:nn_ops_test.

Wall time improves 10-43%:

Benchmark                          Base (ns)    New (ns)  Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       713325      620705       +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3097766     2782433       +10.2%
BM_ImageNetSoftmaxFwd_32_1008_4      1254561      703238       +43.9%
BM_ImageNetSoftmaxFwd_128_1008_4     3225011     2543525       +21.1%

CPU time improves 4-17%:

Benchmark                          Base (ns)    New (ns)  Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       711375      618729       +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3087158     2779777       +10.0%
BM_ImageNetSoftmaxFwd_32_1008_4       959016      795579       +17.0%
BM_ImageNetSoftmaxFwd_128_1008_4     3774543     3591573        +4.8%

Change: 123074430
This commit is contained in:
parent
26f54a9fcd
commit
989166223c
@ -63,31 +63,34 @@ struct SoftmaxEigenImpl {
|
||||
Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
|
||||
one_by_class.set(1, num_classes);
|
||||
#endif
|
||||
//shifted_logits = logits - max(logits along classes);
|
||||
auto shifted_logits = (logits - logits.maximum(along_class)
|
||||
.eval()
|
||||
.reshape(batch_by_one)
|
||||
.broadcast(one_by_class));
|
||||
// shifted_logits = logits - max(logits along classes);
|
||||
auto shifted_logits = (logits -
|
||||
logits.maximum(along_class)
|
||||
.eval()
|
||||
.reshape(batch_by_one)
|
||||
.broadcast(one_by_class));
|
||||
if (log) {
|
||||
// Calculate the log of the softmax
|
||||
// softmax = logits - max(logits along classes);
|
||||
softmax.device(d) = shifted_logits;
|
||||
// softmax = softmax - log(sum(exp(softmax along classes)));
|
||||
softmax.device(d) = (softmax -
|
||||
softmax.exp().sum(along_class)
|
||||
.eval()
|
||||
.reshape(batch_by_one)
|
||||
.broadcast(one_by_class)
|
||||
.log());
|
||||
softmax.exp()
|
||||
.sum(along_class)
|
||||
.eval()
|
||||
.reshape(batch_by_one)
|
||||
.broadcast(one_by_class)
|
||||
.log());
|
||||
} else {
|
||||
// NOTE(touts): If you modify this implementation please run
|
||||
// the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
|
||||
//
|
||||
// softmax = exp(logits - max(logits along classes));
|
||||
softmax.device(d) = shifted_logits.exp();
|
||||
// softmax = softmax / sum(softmax along classes);
|
||||
softmax.device(d) = (softmax /
|
||||
// softmax = softmax * (1 / sum(softmax along classes));
|
||||
softmax.device(d) = (softmax *
|
||||
softmax.sum(along_class)
|
||||
.inverse()
|
||||
.eval()
|
||||
.reshape(batch_by_one)
|
||||
.broadcast(one_by_class));
|
||||
|
Loading…
Reference in New Issue
Block a user