Removed some unnecessary broadcasts in binary ops where only one input needs
broadcasting (which is a fairly common case, even in the fallback path).

PiperOrigin-RevId: 172950493
This commit is contained in:
A. Unique TensorFlower 2017-10-20 16:32:24 -07:00 committed by TensorFlower Gardener
parent 4948379369
commit a5fe66b151

View File

@ -410,10 +410,20 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
}
}
// Fallback path. Always work and probably slower.
auto lhs = in0.broadcast(bcast0);
auto rhs = in1.broadcast(bcast1);
Assign(dev, out, lhs.binaryExpr(rhs, func));
// Fallback path. Always works and probably slower.
if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
Assign(dev, out, in0.binaryExpr(in1, func));
} else if (AllOne<NDIMS>(bcast0)) {
auto rhs = in1.broadcast(bcast1);
Assign(dev, out, in0.binaryExpr(rhs, func));
} else if (AllOne<NDIMS>(bcast1)) {
auto lhs = in0.broadcast(bcast0);
Assign(dev, out, lhs.binaryExpr(in1, func));
} else {
auto lhs = in0.broadcast(bcast0);
auto rhs = in1.broadcast(bcast1);
Assign(dev, out, lhs.binaryExpr(rhs, func));
}
}
};