From a5fe66b1519668505c0daf5f2d93a4d532cedda1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 20 Oct 2017 16:32:24 -0700
Subject: [PATCH] Removed some unnecessary broadcasts in binary ops where only one input needs broadcasting (which is a fairly common case, even in the fallback path).

PiperOrigin-RevId: 172950493
---
 tensorflow/core/kernels/cwise_ops_common.h | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 9a05e1500f5..2454620776f 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -410,10 +410,20 @@ struct BinaryFunctor {
       }
     }
 
-    // Fallback path. Always work and probably slower.
-    auto lhs = in0.broadcast(bcast0);
-    auto rhs = in1.broadcast(bcast1);
-    Assign(dev, out, lhs.binaryExpr(rhs, func));
+    // Fallback path. Always works and probably slower.
+    if (AllOne(bcast0) && AllOne(bcast1)) {
+      Assign(dev, out, in0.binaryExpr(in1, func));
+    } else if (AllOne(bcast0)) {
+      auto rhs = in1.broadcast(bcast1);
+      Assign(dev, out, in0.binaryExpr(rhs, func));
+    } else if (AllOne(bcast1)) {
+      auto lhs = in0.broadcast(bcast0);
+      Assign(dev, out, lhs.binaryExpr(in1, func));
+    } else {
+      auto lhs = in0.broadcast(bcast0);
+      auto rhs = in1.broadcast(bcast1);
+      Assign(dev, out, lhs.binaryExpr(rhs, func));
+    }
   }
 };
 