Removed some unnecessary broadcasts in binary ops where only one input needs

broadcasting (which is a fairly common case, even in the fallback path).

PiperOrigin-RevId: 172950493
This commit is contained in:
A. Unique TensorFlower 2017-10-20 16:32:24 -07:00 committed by TensorFlower Gardener
parent 4948379369
commit a5fe66b151

View File

@ -410,10 +410,20 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
} }
} }
// Fallback path. Always work and probably slower. // Fallback path. Always works and probably slower.
auto lhs = in0.broadcast(bcast0); if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
auto rhs = in1.broadcast(bcast1); Assign(dev, out, in0.binaryExpr(in1, func));
Assign(dev, out, lhs.binaryExpr(rhs, func)); } else if (AllOne<NDIMS>(bcast0)) {
auto rhs = in1.broadcast(bcast1);
Assign(dev, out, in0.binaryExpr(rhs, func));
} else if (AllOne<NDIMS>(bcast1)) {
auto lhs = in0.broadcast(bcast0);
Assign(dev, out, lhs.binaryExpr(in1, func));
} else {
auto lhs = in0.broadcast(bcast0);
auto rhs = in1.broadcast(bcast1);
Assign(dev, out, lhs.binaryExpr(rhs, func));
}
} }
}; };