Removed some unnecessary broadcasts in binary ops where only one input needs
broadcasting (which is a fairly common case, even in the fallback path).

PiperOrigin-RevId: 172950493
This commit is contained in:
parent
4948379369
commit
a5fe66b151
@ -410,10 +410,20 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback path. Always work and probably slower.
|
||||
auto lhs = in0.broadcast(bcast0);
|
||||
auto rhs = in1.broadcast(bcast1);
|
||||
Assign(dev, out, lhs.binaryExpr(rhs, func));
|
||||
// Fallback path. Always works and probably slower.
|
||||
if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
|
||||
Assign(dev, out, in0.binaryExpr(in1, func));
|
||||
} else if (AllOne<NDIMS>(bcast0)) {
|
||||
auto rhs = in1.broadcast(bcast1);
|
||||
Assign(dev, out, in0.binaryExpr(rhs, func));
|
||||
} else if (AllOne<NDIMS>(bcast1)) {
|
||||
auto lhs = in0.broadcast(bcast0);
|
||||
Assign(dev, out, lhs.binaryExpr(in1, func));
|
||||
} else {
|
||||
auto lhs = in0.broadcast(bcast0);
|
||||
auto rhs = in1.broadcast(bcast1);
|
||||
Assign(dev, out, lhs.binaryExpr(rhs, func));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user