Add a Compare() builder that is compatible with omitting broadcast_dimensions with the same ordering as the other binary ops.

This helps reduce boilerplate of generated code that seeks to treat all binary ops generically.

PiperOrigin-RevId: 312295743
Change-Id: I7d12b26579ef5375394e5980fec3c11c128318f7
This commit is contained in:
Stella Laurenzo 2020-05-19 09:33:05 -07:00 committed by TensorFlower Gardener
parent 3114f6b980
commit 60ac364796
2 changed files with 8 additions and 1 deletions

View File

@ -3188,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
broadcast_dimensions, direction);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
return Compare(lhs, rhs, {}, direction);
}
XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
const PrecisionConfig* precision_config) {
return lhs.builder()->Dot(lhs, rhs, precision_config);

View File

@ -889,6 +889,7 @@ class XlaBuilder {
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);
friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
@ -1498,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs,
XlaOp Le(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a comparison instruction onto the computation.
// Enqueues a comparison instruction onto the computation (optionally without
// broadcast_dimensions for consistency with others).
XlaOp Compare(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);
XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,