From 60ac3647968e36be809cdaede086cf8cb8cd8fb5 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 19 May 2020 09:33:05 -0700 Subject: [PATCH] Add a Compare() builder that is compatible with omitting broadcast_dimensions with the same ordering as the other binary ops. This helps reduce boilerplate of generated code that seeks to treat all binary ops generically. PiperOrigin-RevId: 312295743 Change-Id: I7d12b26579ef5375394e5980fec3c11c128318f7 --- tensorflow/compiler/xla/client/xla_builder.cc | 4 ++++ tensorflow/compiler/xla/client/xla_builder.h | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a4e5b936153..58365c0f498 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3188,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { + return Compare(lhs, rhs, {}, direction); +} + XlaOp Dot(const XlaOp lhs, const XlaOp rhs, const PrecisionConfig* precision_config) { return lhs.builder()->Dot(lhs, rhs, precision_config); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b631514248c..426b6d83207 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -889,6 +889,7 @@ class XlaBuilder { friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, @@ -1498,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs, XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -// Enqueues a comparison instruction onto the computation. +// Enqueues a comparison instruction onto the computation (optionally without +// broadcast_dimensions for consistency with others). XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); +XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); // Enqueues a dot instruction onto the computation. XlaOp Dot(XlaOp lhs, XlaOp rhs,