Removed some unnecessary broadcasts in binary ops where only one input needs
broadcasting (which is a fairly common case, even in the fallback path).

PiperOrigin-RevId: 172950493
This commit is contained in:
parent
4948379369
commit
a5fe66b151
@@ -410,11 +410,21 @@ struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback path. Always work and probably slower.
|
// Fallback path. Always works and probably slower.
|
||||||
|
if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
|
||||||
|
Assign(dev, out, in0.binaryExpr(in1, func));
|
||||||
|
} else if (AllOne<NDIMS>(bcast0)) {
|
||||||
|
auto rhs = in1.broadcast(bcast1);
|
||||||
|
Assign(dev, out, in0.binaryExpr(rhs, func));
|
||||||
|
} else if (AllOne<NDIMS>(bcast1)) {
|
||||||
|
auto lhs = in0.broadcast(bcast0);
|
||||||
|
Assign(dev, out, lhs.binaryExpr(in1, func));
|
||||||
|
} else {
|
||||||
auto lhs = in0.broadcast(bcast0);
|
auto lhs = in0.broadcast(bcast0);
|
||||||
auto rhs = in1.broadcast(bcast1);
|
auto rhs = in1.broadcast(bcast1);
|
||||||
Assign(dev, out, lhs.binaryExpr(rhs, func));
|
Assign(dev, out, lhs.binaryExpr(rhs, func));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Version of BinaryFunctor with error handling.
|
// Version of BinaryFunctor with error handling.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user