Add a Compare() builder that is compatible with omitting broadcast_dimensions with the same ordering as the other binary ops.
This helps reduce boilerplate of generated code that seeks to treat all binary ops generically. PiperOrigin-RevId: 312295743 Change-Id: I7d12b26579ef5375394e5980fec3c11c128318f7
This commit is contained in:
parent
3114f6b980
commit
60ac364796
|
@ -3188,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
|||
broadcast_dimensions, direction);
|
||||
}
|
||||
|
||||
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
|
||||
return Compare(lhs, rhs, {}, direction);
|
||||
}
|
||||
|
||||
XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
|
||||
const PrecisionConfig* precision_config) {
|
||||
return lhs.builder()->Dot(lhs, rhs, precision_config);
|
||||
|
|
|
@ -889,6 +889,7 @@ class XlaBuilder {
|
|||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
|
||||
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||
const PrecisionConfig* precision_config);
|
||||
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
|
||||
|
@ -1498,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
|||
XlaOp Le(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a comparison instruction onto the computation.
|
||||
// Enqueues a comparison instruction onto the computation (optionally without
|
||||
// broadcast_dimensions for consistency with others).
|
||||
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
|
||||
|
||||
// Enqueues a dot instruction onto the computation.
|
||||
XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||
|
|
Loading…
Reference in New Issue