diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index ee6f7d5956e..07bf937547f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -57,6 +57,24 @@ xla_proto_library( ], ) +cc_library( + name = "comparison_util", + srcs = [ + "comparison_util.cc", + ], + hdrs = [ + "comparison_util.h", + ], + visibility = [":friends"], + deps = [ + ":statusor", + ":types", + ":util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "execution_options_util", srcs = [ diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f5d56e8a9e1..ae1a459d96d 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -212,6 +212,7 @@ cc_library( ":padding", ":sharding_builder", ":xla_computation", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 6c6d1a9bd3a..2f574366694 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -480,7 +480,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { } XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::optional direction) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -489,6 +490,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); + if (binop == HloOpcode::kCompare) { + if 
(!direction.has_value()) { + return InvalidArgument( + "kCompare expects a ComparisonDirection, but none provided."); + } + instr.set_comparison_direction(ComparisonDirectionToString(*direction)); + } else if (direction.has_value()) { + return InvalidArgument( + "A comparison direction is provided for a non-compare opcode: %s.", + HloOpcodeString(binop)); + } const int64 lhs_rank = lhs_shape.rank(); const int64 rhs_rank = rhs_shape.rank(); @@ -2908,38 +2920,39 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt); } XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe); } XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); +} + +XlaOp 
Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction) { + return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs, + broadcast_dimensions, direction); } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 56e85e394c5..80f93a8b6de 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" @@ -596,9 +597,11 @@ class XlaBuilder { // Internal helper method that does the building for an arbitrary binary op. // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. + // when the operation is between tensors of different ranks. The direction is + // only used if opcode is kCompare. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::optional direction = absl::nullopt); // Internal helper method that does the building for an arbitrary ternary op. 
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, @@ -767,6 +770,9 @@ class XlaBuilder { absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); + friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -1279,6 +1285,11 @@ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Enqueues a comparison instruction onto the computation. +XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc new file mode 100644 index 00000000000..de34ad678e7 --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/comparison_util.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +std::string ComparisonDirectionToString(ComparisonDirection direction) { + switch (direction) { + case ComparisonDirection::kEq: + return "EQ"; + case ComparisonDirection::kNe: + return "NE"; + case ComparisonDirection::kGe: + return "GE"; + case ComparisonDirection::kGt: + return "GT"; + case ComparisonDirection::kLe: + return "LE"; + case ComparisonDirection::kLt: + return "LT"; + } +} + +StatusOr StringToComparisonDirection( + absl::string_view direction_name) { + static auto* direction_map = + new absl::flat_hash_map({ + {"EQ", ComparisonDirection::kEq}, + {"NE", ComparisonDirection::kNe}, + {"GE", ComparisonDirection::kGe}, + {"GT", ComparisonDirection::kGt}, + {"LE", ComparisonDirection::kLe}, + {"LT", ComparisonDirection::kLt}, + }); + auto it = direction_map->find(direction_name); + if (it == direction_map->end()) { + return InvalidArgument("Unknown comparison direction: %s", direction_name); + } + return it->second; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h new file mode 100644 index 00000000000..8b150c3cfad --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ + +#include "absl/base/macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Represents different comparison operations. +enum class ComparisonDirection : uint8 { + kEq, + kNe, + kGe, + kGt, + kLe, + kLt, +}; + +string ComparisonDirectionToString(ComparisonDirection direction); + +StatusOr StringToComparisonDirection( + absl::string_view direction_name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ad65cb291f6..22826918da8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -334,6 +334,7 @@ cc_library( ":hlo_proto", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -3301,6 +3302,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 069320a24cd..b223fc8b1b5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -873,9 +873,9 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, 
computation, a->shape().element_type(), a->shape().dimensions()); auto* dividend_is_negative = - computation->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, - zero_like_a)); + computation->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, + ComparisonDirection::kLt)); auto* negated_dividend = computation->AddInstruction( HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); @@ -2475,9 +2475,9 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, computation, a->shape().element_type(), a->shape().dimensions()); auto* dividend_is_negative = - computation->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, - zero_like_a)); + computation->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, + ComparisonDirection::kLt)); auto* negated_dividend = computation->AddInstruction( HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 9c9db74fd2f..b972b1289b9 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -221,7 +221,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -258,7 +258,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + 
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -296,7 +296,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index a9b5d9916e4..357d38a5548 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = @@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto cond_dot = builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); - auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), 
builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest, auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( @@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest, ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, - {1, 1})))))); + {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); auto cond_dot = builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); - 
builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); auto cond0_dot = builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); - builder_cond0.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond0.AddInstruction( @@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond0_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, 
{1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. @@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); auto cond1_dot = builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); - builder_cond1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond1.AddInstruction( @@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond1_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 580bc2f4338..704585033f0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -190,8 +190,9 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index, + const4, ComparisonDirection::kLt)); return builder.Build(); } @@ -1863,8 +1864,8 @@ class WhileBufferAssignmentTest : public HloTestBase { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt)); return builder.Build(); } @@ -2135,8 +2136,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + const4, ComparisonDirection::kLt)); return builder.Build(); }; @@ -2530,7 +2532,7 @@ while_condition { state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) get-tuple-element = s32[] get-tuple-element(state), index=0 get-tuple-element.1 = s32[] 
constant(3) - ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1) + ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT } ENTRY entry_computation { diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 5de724f8924..458aef14999 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index f11f9e5fc29..434bbe9ffd5 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask( // linspace to create a diagonal predicate. 
Shape predicate_shape = ShapeUtil::MakeShape( PRED, AsInt64Slice(expanded_filter_shape.dimensions())); - return add_instruction(HloInstruction::CreateBinary( - predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); + return add_instruction(HloInstruction::CreateCompare( + predicate_shape, broadcasted_mask1, broadcasted_mask2, + ComparisonDirection::kEq)); } // This function handles batch_group_counts which are relevant only for diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4391bdcba53..6fa3161e578 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } @@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokensShouldNotBeCopied () -> s32[] { @@ -2060,7 +2060,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), 
direction=EQ } _functionalize_body_1__.v28 { @@ -2070,7 +2070,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.68, constant.7) get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) + less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2087,7 +2087,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2110,7 +2110,7 @@ cond_wrapper.v3.2 { inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT } ENTRY TestComputation { @@ -2142,7 +2142,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ } if-body.v5.1 { @@ -2159,7 +2159,7 @@ if-condition.v4.1 { p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 constant.6 = s32[] constant(1) - ROOT equal-to.1 = pred[] 
equal-to(get-tuple-element.71, constant.6) + ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ } _functionalize_body_1__.v28 { @@ -2169,7 +2169,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.72, constant.7) get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2187,7 +2187,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2210,7 +2210,7 @@ cond_wrapper.v3.2 { inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT } ENTRY TestComputation { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 762ee67db9a..e07ac9edc89 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -29,7 +29,7 @@ HloModule KeyValueSort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = 
pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY main { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 70173d43d79..bd638917ccf 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -75,7 +75,7 @@ ENTRY TestComputation { broadcast = f32[42] broadcast(add), dimensions={} slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) - eq = pred[] equal-to(arg, gte) + eq = pred[] compare(arg, gte), direction=EQ neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index b3d93737427..2cd2ef27279 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -68,8 +68,8 @@ class DynamicDimensionInferenceTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {}), "lhs")); auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( 1, ShapeUtil::MakeShape(F32, {}), "rhs")); - embedded_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + embedded_builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe)); return module_->AddEmbeddedComputation(embedded_builder.Build()); } diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index a982dad95c7..9e1efa44299 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -160,9 +160,10 @@ StatusOr DynamicPadder::Run(HloModule* module) 
{ HloInstruction* broadcasted_effective_size = computation->AddInstruction(HloInstruction::CreateBroadcast( mask_shape, dynamic_size, {})); - HloInstruction* pred = computation->AddInstruction( - HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, - broadcasted_effective_size)); + HloInstruction* pred = + computation->AddInstruction(HloInstruction::CreateCompare( + pred_shape, iota, broadcasted_effective_size, + ComparisonDirection::kLt)); HloInstruction* broadcasted_identity_value = computation->AddInstruction(HloInstruction::CreateBroadcast( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index df556a9d6c7..2e54a278fa8 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -719,25 +719,28 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. 
- case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, - rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, - rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, - rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, - rhs_value, b_); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, + rhs_value, b_); + } + } case HloOpcode::kMaximum: return EmitFloatMax(lhs_value, rhs_value); case HloOpcode::kMinimum: @@ -839,21 +842,28 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. 
- case HloOpcode::kEq: - return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - case HloOpcode::kNe: - return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + case ComparisonDirection::kNe: + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + default: + return Unimplemented( + "complex comparison '%s'", + ComparisonDirectionToString(op->comparison_direction())); + } + } case HloOpcode::kPower: { auto a = EmitExtractReal(lhs_value); auto b = EmitExtractImag(lhs_value); @@ -1278,28 +1288,32 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); - case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison( - is_signed ? 
llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, - lhs_value, rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, - lhs_value, rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, - lhs_value, rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, - lhs_value, rhs_value, b_); + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison( + is_signed ? 
llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, + lhs_value, rhs_value, b_); + } + } case HloOpcode::kMinimum: return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: @@ -2197,17 +2211,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kAdd: case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kOr: case HloOpcode::kXor: case HloOpcode::kPower: diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8eeb930b481..ef35311b08b 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } @@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kEq, param0, false_constant)); + builder.AddInstruction(HloInstruction::CreateCompare( + 
ShapeUtil::MakeShape(PRED, {}), param0, false_constant, + ComparisonDirection::kEq)); cond_computation = module->AddEmbeddedComputation(builder.Build()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index 15d4ee206ce..ee64b3a7596 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { @@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest, p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, broadcast), direction=GT select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) } @@ -507,7 +507,7 @@ TEST_F(GpuFusibleTest, p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f32[2,2,2]{2,1,0} parameter(0) ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc 
b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 391029e5746..3630c3e38c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -374,7 +374,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { p.0.rhs = f32[] parameter(1) p.1.lhs = f32[] parameter(2) p.1.rhs = f32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 40b87b16a19..4b78d48210a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -437,7 +437,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f32[2,2,2]{2,1,0} parameter(0) ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) } @@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest, p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f16[2,2,2]{2,1,0} parameter(0) ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) 
} @@ -548,7 +548,7 @@ TEST_F(MultiOutputFusionTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index 6814be779e0..963716e7050 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { HloInstruction::CreateParameter(0, param_shape, "x")); HloInstruction* param_y = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "y")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y, + ComparisonDirection::kGe)); auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { x = f32[5,7,2]{2,1,0} parameter(0) y = f32[5,14]{1,0} parameter(1) reshape = f32[5,7,2]{2,1,0} reshape(y) - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE })", config) .ValueOrDie(); @@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { y = f32[14]{0} parameter(1) reshape = f32[7,2]{1,0} reshape(y) broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + ROOT 
gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE })", config) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 77e49f0e46b..64a5fe5fdd2 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index dc40b9446ad..2f162803820 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} // %reshape = f32[] reshape(f32[1]{0} %slice) // %constant = f32[] constant(0) - // 
ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant), + // direction=NE // } // ENTRY %SubcomputationAccounting () -> f32[2,4] { @@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); HloInstruction* zero = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - HloInstruction* cond_comparison = - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + HloInstruction* cond_comparison = cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape, + zero, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 @@ -703,8 +704,8 @@ TEST_F(HeapSimulatorTest, WholeModule) { HloInstruction* cond_data = cond_builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = tracker.module()->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1413ce3062d..54ee92943cc 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. 
-// Next ID: 63 +// Next ID: 64 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -146,6 +146,9 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + // Comparison direction only used for kCompare. + string comparison_direction = 63; + // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_slice_sizes = 34; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index fe37ca6b396..3fa6f80b1b9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -509,8 +509,9 @@ TEST_F(HloComputationTest, CloneWithReplacements) { HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); - auto lt = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto lt = builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + param1, ComparisonDirection::kLt)); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b5d9e8e7f1a..d9c5f7c66de 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -42,6 +42,18 @@ StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); } +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape binary_op_shape, + 
ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction)); +} + StatusOr MakePadHlo(HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 17b7a2da6a9..f163112f7ff 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -32,6 +32,12 @@ namespace xla { StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs); +// Creates a compare HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs); + // Creates a pad HLO instruction and adds it to the computation containing // `operand` and `padding_value` (`operand` and `padding_value` must be in the // same computation). 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 13027fd5463..12fbcdb4a0e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2239,8 +2239,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -2563,8 +2563,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index b5d72b386f8..d0073237ac2 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); + 
cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + constant, ComparisonDirection::kLt)); } auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 8e323068d18..d920484973b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -56,43 +56,40 @@ namespace xla { namespace { template -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el != rhs_el; }; break; - case HloOpcode::kGe: + case ComparisonDirection::kGe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el >= rhs_el; }; break; - case HloOpcode::kGt: + case ComparisonDirection::kGt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el > rhs_el; }; break; - case HloOpcode::kLe: + case ComparisonDirection::kLe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el <= rhs_el; }; break; - case HloOpcode::kLt: + case ComparisonDirection::kLt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el < rhs_el; }; break; - default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); } Literal result(shape); @@ -106,24 +103,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, 
LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -137,24 +135,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex128 lhs_el, complex128 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex128 lhs_el, complex128 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -671,7 +670,7 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) { } Status HloEvaluator::HandleCompare(HloInstruction* compare) { - HloOpcode opcode = compare->opcode(); + ComparisonDirection direction = compare->comparison_direction(); auto lhs = compare->operand(0); auto rhs = compare->operand(1); DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && @@ -687,76 +686,76 @@ Status 
HloEvaluator::HandleCompare(HloInstruction* compare) { case PRED: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case U8: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S8: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case S16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, 
lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F16: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case F32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case C64: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case C128: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; default: diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 383921fde22..79cb6055fd5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -2848,8 +2848,15 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) { {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); - TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), - std::move(rhs)); + + HloComputation::Builder b(TestName()); + auto c1 = 
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2, + ComparisonDirection::kGe)); + m_->AddEntryComputation(b.Build()); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate())); } TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 9623edcf5eb..89e8fe3960e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -258,14 +258,16 @@ optional MatchTrivialComputation(const HloComputation* computation) { // param0), check that the operation being performed is commutative. if (root->operand(0) == param1) { CHECK_EQ(root->operand(1), param0); - switch (root->opcode()) { - case HloOpcode::kLe: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLt: - return nullopt; - default: - break; + if (root->opcode() == HloOpcode::kCompare) { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + case ComparisonDirection::kGe: + case ComparisonDirection::kGt: + case ComparisonDirection::kLt: + return nullopt; + default: + break; + } } } @@ -279,18 +281,22 @@ optional MatchTrivialComputation(const HloComputation* computation) { return "min"; case HloOpcode::kMaximum: return "max"; - case HloOpcode::kLe: - return "less-or-equal"; - case HloOpcode::kGe: - return "greater-or-equal"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kNe: - return "not-equal-to"; + case HloOpcode::kCompare: { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + return "less-or-equal"; + case ComparisonDirection::kGe: + return "greater-or-equal"; + case ComparisonDirection::kGt: + return 
"greater-than"; + case ComparisonDirection::kLt: + return "less-than"; + case ComparisonDirection::kEq: + return "equal-to"; + case ComparisonDirection::kNe: + return "not-equal-to"; + } + } default: return nullopt; } @@ -922,27 +928,22 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 064c53252c0..f92759c56f6 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -31,6 +32,8 @@ namespace { using absl::StrCat; using ::testing::HasSubstr; +using HloGraphDumperTest = HloTestBase; + string TestName() { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } @@ -48,7 +51,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { XLA_REGISTER_GRAPH_RENDERER(DotRenderer); -TEST(HloGraphDumperTest, NestedFusion) { +TEST_F(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); // Build param0 + param1 + param2 + param3 + param4. @@ -118,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) { HasSubstr(inner_sum->name())); } -TEST(HloGraphDumperTest, Constant) { +TEST_F(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); @@ -132,7 +135,7 @@ TEST(HloGraphDumperTest, Constant) { EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } -TEST(HloGraphDumperTest, TupleConstant) { +TEST_F(HloGraphDumperTest, TupleConstant) { Shape tuple_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); HloComputation::Builder b("b"); @@ -150,5 +153,21 @@ TEST(HloGraphDumperTest, TupleConstant) { EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); } +TEST_F(HloGraphDumperTest, Compare) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0) + param.1 = f32[10] parameter(1) + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + string graph = 
hlo_graph_dumper::DumpGraph(*module->entry_computation(), + /*label=*/"comp", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("direction=LT")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 55108b3d772..fe8a178f80f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -64,7 +64,35 @@ StatusOr> HloInstruction::CreateFromProto( const absl::flat_hash_map& instruction_map, const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + HloOpcode opcode; + auto opcode_or = StringToHloOpcode(proto.opcode()); + absl::optional comparison_direction; + if (opcode_or.ok()) { + opcode = opcode_or.ConsumeValueOrDie(); + } else { + // Unknown opcode. Try auto-upgrading deprecated "less-than", + // "greater-than", etc opcodes, which are now rolled into the kCompare + // opcode. 
+ if (proto.opcode() == "equal-to") { + comparison_direction = ComparisonDirection::kEq; + } else if (proto.opcode() == "not-equal-to") { + comparison_direction = ComparisonDirection::kNe; + } else if (proto.opcode() == "greater-than-or-equal-to") { + comparison_direction = ComparisonDirection::kGe; + } else if (proto.opcode() == "greater-than") { + comparison_direction = ComparisonDirection::kGt; + } else if (proto.opcode() == "less-than-or-equal-to") { + comparison_direction = ComparisonDirection::kLe; + } else if (proto.opcode() == "less-than") { + comparison_direction = ComparisonDirection::kLt; + } + if (comparison_direction) { + opcode = HloOpcode::kCompare; + } else { + return InvalidArgument("Unknown opcode: %s", proto.opcode()); + } + } + TF_RET_CHECK(proto.has_shape()); std::unique_ptr instruction; @@ -136,6 +164,17 @@ StatusOr> HloInstruction::CreateFromProto( absl::Span(fft_length)); break; } + case HloOpcode::kCompare: { + // Auto-upgraded from deprecated opcode skips the following. 
+ if (!comparison_direction) { + TF_ASSIGN_OR_RETURN( + comparison_direction, + StringToComparisonDirection(proto.comparison_direction())); + } + instruction = + CreateCompare(shape, operands(0), operands(1), *comparison_direction); + break; + } case HloOpcode::kTriangularSolve: { instruction = CreateTriangularSolve(shape, operands(0), operands(1), proto.triangular_solve_options()); @@ -688,15 +727,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -761,6 +794,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, fft_length); } +/* static */ std::unique_ptr HloInstruction::CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction) { + return absl::make_unique(shape, lhs, rhs, direction); +} + /* static */ std::unique_ptr HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, HloInstruction* b, @@ -1311,6 +1350,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -1384,12 +1424,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDivide: case HloOpcode::kMultiply: case HloOpcode::kSubtract: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1705,26 +1739,20 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: 
case HloOpcode::kDivide: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: case HloOpcode::kXor: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: case HloOpcode::kReal: @@ -1772,6 +1800,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -2119,17 +2148,12 @@ bool HloInstruction::IsElementwiseImpl( // Binary elementwise operations, the same as in IsElementwiseBinary(). case HloOpcode::kAdd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -2472,12 +2496,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGetTupleElement(this); case HloOpcode::kParameter: return visitor->HandleParameter(this); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: + case HloOpcode::kCompare: return visitor->HandleCompare(this); case HloOpcode::kComplex: return visitor->HandleComplex(this); @@ -3519,6 +3538,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const { return 
Cast(this)->user_side_metadata(); } +ComparisonDirection HloInstruction::comparison_direction() const { + return Cast(this)->direction(); +} + const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { return Cast(this)->triangular_solve_options(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b28521bf8d4..6f6a1b8505e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -444,6 +445,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + // Creates a compare op, performing the comparison specified in direction. + static std::unique_ptr CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction); + static std::unique_ptr CreateTriangularSolve( const Shape& shape, HloInstruction* a, HloInstruction* b, const TriangularSolveOptions& options); @@ -1600,6 +1606,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloCompareInstruction::direction(). + ComparisonDirection comparison_direction() const; + // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). 
const TriangularSolveOptions& triangular_solve_options() const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 35f031f29a7..85f2ddba8d3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1655,7 +1655,7 @@ body (bparam: s32[]) -> s32[] { condition (cparam: s32[]) -> pred[] { xconstant = s32[] constant(5) cparam = s32[] parameter(0) - ROOT greater-than = pred[] greater-than(xconstant, cparam) + ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT } ENTRY entry (param: s32[]) -> s32[] { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7d18b35c2bb..41b4ba21380 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -202,6 +202,42 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloCompareInstruction::HloCompareInstruction(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction) + : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloCompareInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_comparison_direction(ComparisonDirectionToString(direction_)); + return proto; +} + +std::vector HloCompareInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("direction=", ComparisonDirectionToString(direction()))}; +} + +bool HloCompareInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return direction() == casted_other.direction(); +} + +std::unique_ptr 
HloCompareInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique(shape, new_operands[0], + new_operands[1], direction()); +} + namespace { // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 43aa12c10f2..0bc0db41c0a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -131,6 +131,28 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloCompareInstruction : public HloInstruction { + public: + explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction); + ComparisonDirection direction() const { return direction_; } + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + ComparisonDirection direction_; +}; + class HloTriangularSolveInstruction : public HloInstruction { public: explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 436cccb1fb9..45d3e9c460e 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) 
parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(add.1, constant.2) + ROOT less-than = pred[] compare(add.1, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { loop_var.2 = (s32[], s32[], s32[]) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.1 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT } ENTRY SimpleLoop { constant.2 = s32[] constant(0) @@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } OuterWhileCondition { cond_param.2 = (s32[]) parameter(0) 
get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 constant.5 = s32[] constant(5) - ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT } OuterWhileBody { body_param.2 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index d28e79d41ad..47ed85be196 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain( return true; } +bool HloComparisonMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->comparison_direction() != direction_) { + *listener << "has wrong comparison direction (got " + << ComparisonDirectionToString( + instruction->comparison_direction()) + << ", want " << ComparisonDirectionToString(direction_) << ")"; + return false; + } + return true; +} + bool HloGetTupleElementMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 67488a6a9a0..756f4d2c6bc 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher { int64 parameter_number_; }; +// Custom matcher for comparisons, which accepts a comparison direction. 
+class HloComparisonMatcher : public HloMatcher { + public: + explicit HloComparisonMatcher( + ComparisonDirection direction, + std::vector<::testing::Matcher> operands) + : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + ComparisonDirection direction_; +}; + // Custom matcher for get-tuple-element instructions, which accepts a tuple // index to match. class HloGetTupleElementMatcher : public HloMatcher { @@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); +HLO_MATCHER(Compare); HLO_MATCHER(Concatenate); HLO_MATCHER(Conditional); HLO_MATCHER(Constant); @@ -184,28 +200,22 @@ HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); -HLO_MATCHER(Eq); HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); -HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); -HLO_MATCHER(Gt); HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); -HLO_MATCHER(Le); HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); HLO_MATCHER(Xor); -HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); HLO_MATCHER(Minimum); HLO_MATCHER(Multiply); -HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); @@ -256,6 +266,38 @@ inline ::testing::Matcher Parameter() { new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); } +// Comparison matchers below do not require any additional arguments. +template +inline ::testing::Matcher Eq(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kEq, {operands...})); +} +template +inline ::testing::Matcher Ne(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kNe, {operands...})); +} +template +inline ::testing::Matcher Ge(M... 
operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGe, {operands...})); +} +template +inline ::testing::Matcher Gt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGt, {operands...})); +} +template +inline ::testing::Matcher Le(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLe, {operands...})); +} +template +inline ::testing::Matcher Lt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLt, {operands...})); +} + // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th // tuple element of operand, while GetTupleElement(operand) matches any GTE // operation on operand, and GetTupleElement() matches any GTE operation at all. diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7961aece541..549fc603c70 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion { "rhs_contracting_dimensions (got {0} want {1})"); } +TEST(HloMatchersTest, ComparisonMatcher) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_THAT(eq.get(), op::Compare()); + EXPECT_THAT(eq.get(), op::Eq()); + 
EXPECT_THAT(ne.get(), op::Compare()); + EXPECT_THAT(ne.get(), op::Ne()); + EXPECT_THAT(le.get(), + op::Compare(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + EXPECT_THAT(le.get(), op::Le(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + + EXPECT_THAT(Explain(eq.get(), op::Add()), Eq("")); + EXPECT_THAT(Explain(eq.get(), op::Ne()), + Eq("has wrong comparison direction (got EQ, want NE)")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index bc0d7e2bc00..200d08c562e 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -254,8 +254,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0, 0, 0, 0}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param, + zero_vector, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index f6e28662049..84988a9ecb3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, 
constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { loop_var.2 = (s32[], f32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.3 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT } ENTRY SimpleLoop { constant.4 = s32[] constant(0) @@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { loop_var.2 = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] 
constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.2 = (s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { cond_param = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { p0 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc 
b/tensorflow/compiler/xla/service/hlo_opcode.cc index 548fbb873aa..8f459107b32 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -44,21 +44,8 @@ StatusOr StringToHloOpcode(const string& opcode_name) { return it->second; } -#define CHECK_DEFAULT(property_name, opcode_name) false -#define CHECK_PROPERTY(property_name, opcode_name, value) \ - (value & property_name) -#define RESOLVE(_1, _2, target, ...) target -#define HAS_PROPERTY(property, ...) \ - RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) - bool HloOpcodeIsComparison(HloOpcode opcode) { - switch (opcode) { -#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \ - case HloOpcode::enum_name: \ - return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); - HLO_OPCODE_LIST(CASE_IS_COMPARISON) -#undef CASE_IS_COMPARISON - } + return opcode == HloOpcode::kCompare; } bool HloOpcodeIsVariadic(HloOpcode opcode) { @@ -82,9 +69,4 @@ absl::optional HloOpcodeArity(HloOpcode opcode) { } } -#undef HAS_PROPERTY -#undef RESOLVE -#undef CHECK_DEFAULT -#undef CHECK_PROPERTY - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3e144c4472f..c5ccd49552a 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -19,8 +19,10 @@ limitations under the License. 
#include #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,6 +67,7 @@ namespace xla { V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ V(kClz, "count-leading-zeros", 1) \ + V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ V(kConditional, "conditional", kHloOpcodeIsVariadic) \ @@ -79,34 +82,28 @@ namespace xla { V(kDot, "dot", 2) \ V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ - V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \ V(kExp, "exponential", 1) \ V(kExpm1, "exponential-minus-one", 1) \ V(kFft, "fft", 1) \ V(kFloor, "floor", 1) \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather", 2) \ - V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \ V(kGetDimensionSize, "get-dimension-size", 1) \ V(kGetTupleElement, "get-tuple-element", 1) \ - V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \ V(kImag, "imag", 1) \ V(kInfeed, "infeed", 1) \ V(kIota, "iota", 0) \ V(kIsFinite, "is-finite", 1) \ - V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \ V(kLog, "log", 1) \ V(kLog1p, "log-plus-one", 1) \ V(kAnd, "and", 2) \ V(kNot, "not", 1) \ V(kOr, "or", 2) \ V(kXor, "xor", 2) \ - V(kLt, "less-than", 2, kHloOpcodeIsComparison) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum", 2) \ V(kMinimum, "minimum", 2) \ V(kMultiply, "multiply", 2) \ - V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \ V(kNegate, "negate", 1) \ V(kOutfeed, "outfeed", 2) \ V(kPad, "pad", 2) \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 910cc25a591..136e6702b21 100644 --- 
a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -42,12 +42,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { // Test some properties. switch (opcode) { - case HloOpcode::kEq: - case HloOpcode::kNe: - case HloOpcode::kGt: - case HloOpcode::kLt: - case HloOpcode::kGe: - case HloOpcode::kLe: + case HloOpcode::kCompare: EXPECT_TRUE(HloOpcodeIsComparison(opcode)); break; default: diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 3ca77e60cd5..8e8b9d663ea 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -306,7 +306,7 @@ condition.v4 { constant.2 = s32[] constant(2) prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 - ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) + ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT } fused_computation { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index fd55d92c04a..8e76a1f262e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -183,6 +183,7 @@ class HloParser { kHloComputation, kBracedHloComputationList, kFftType, + kComparisonDirection, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -300,6 +301,7 @@ class HloParser { bool ParseTiles(std::vector* tiles); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); + bool ParseComparisonDirection(ComparisonDirection* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParsePrecision(PrecisionConfig::Precision* result); @@ -763,12 +765,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case 
HloOpcode::kSubtract: case HloOpcode::kAtan2: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1133,6 +1129,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], operands[1], options)); break; } + case HloOpcode::kCompare: { + optional direction; + attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection, + &direction}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCompare( + shape, operands[0], operands[1], *direction)); + break; + } case HloOpcode::kCholesky: { CholeskyOptions options; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -2728,6 +2736,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kComparisonDirection: { + ComparisonDirection result; + if (!ParseComparisonDirection(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { @@ -3756,6 +3773,22 @@ bool HloParser::ParseFftType(FftType* result) { return true; } +bool HloParser::ParseComparisonDirection(ComparisonDirection* result) { + VLOG(1) << "ParseComparisonDirection"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects comparison direction"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToComparisonDirection(val); + if (!status_or_result.ok()) { + return TokenError( + StrFormat("expects comparison direction but sees: %s", val)); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* 
result) { VLOG(1) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 1ba2d718ecc..6f4171bca82 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -222,7 +222,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} - %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} + %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated} ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } @@ -292,7 +292,7 @@ R"(HloModule WhileWithScalarS32Result_module %condition.v3 (prev.2: s32[]) -> pred[] { %constant.1 = s32[] constant(5) %prev.2 = s32[] parameter(0) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT } ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { @@ -474,7 +474,7 @@ R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -500,7 +500,7 @@ R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = 
pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -1037,7 +1037,7 @@ R"(HloModule TupleReduce max_argmax { value = f32[] parameter(2) prev_max = f32[] parameter(0) - is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + is_next_larger = pred[] compare(value, prev_max), direction=GE max = f32[] select(is_next_larger, value, prev_max) index = s32[] parameter(3) prev_argmax = s32[] parameter(1) @@ -1106,7 +1106,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1126,7 +1126,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1145,7 +1145,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1165,7 +1165,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1190,7 +1190,7 @@ compare { p.3.rhs = f32[] parameter(7) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1211,7 +1211,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1469,7 +1469,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - 
ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1656,7 +1656,7 @@ TEST_F(HloParserTest, WrongOperandsSize) { ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x) + %eq = pred[]{} compare(f32[]{} %x), direction=EQ } )"; @@ -1668,7 +1668,7 @@ TEST_F(HloParserTest, OperandNotFound) { const string original = R"(HloModule operand_not_found: ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) + %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ } )"; auto result = ParseHloString(original); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 0e56e6f760e..ecc8dbe6560 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { @@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e4a78af7c72..4868cf961aa 100644 --- 
a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: @@ -72,21 +73,15 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kFloor: - case HloOpcode::kGe: case HloOpcode::kGetTupleElement: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kInfeed: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 29b2503c05b..bff7c961d07 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2001,6 +2001,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConditional: @@ -2012,24 +2013,18 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case 
HloOpcode::kMap: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c8cf3c47d38..efca6be331e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -1084,7 +1084,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) counter.1 = s32[] get-tuple-element(tup.1), index=0 five = s32[] constant(5) - ROOT lt = pred[] less-than(counter.1, five) + ROOT lt = pred[] compare(counter.1, five), direction=LT } body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index b6ae4932f57..db900856993 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -46,7 +46,7 @@ condition { condition.state = f32[] parameter(0) addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget" add = f32[] add(addend, condition.state) - ROOT greater-than = pred[] greater-than(const.100, add) + ROOT greater-than = pred[] compare(const.100, add), direction=GT } ENTRY while3 { diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 7164bfc4cd4..ae1df60d350 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -67,6 +67,7 @@ namespace xla { // - WithOneUse: Instruction is used as an operand exactly once. 
// - WithOneUser: Instruction is used by exactly one other instruction, but // is possibly used more than once as an operand (e.g. multiply(x,x)). +// - WithComparisonDirection: instr has the given direction // // Shape(): // - EqualTo @@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl } }; +class HloInstructionPatternComparisonDirectionImpl { + public: + explicit constexpr HloInstructionPatternComparisonDirectionImpl( + ComparisonDirection direction) + : direction_(direction) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has comparison direction " + << ComparisonDirectionToString(direction_); + } + + private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kCompare || + inst->comparison_direction() != direction_) { + EXPLAIN << "HloInstruction is not comparison " + << ComparisonDirectionToString(direction_); + return false; + } + return true; + } + + ComparisonDirection direction_; +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -1956,6 +1991,14 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOneUserImpl()); } + // Modifies the pattern to match only if the instruction has the given + // comparison direction. 
+ auto WithComparisonDirection(ComparisonDirection direction) const + -> decltype(this->AppendImpl( + HloInstructionPatternComparisonDirectionImpl(direction))) { + return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -2118,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Compare) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) -XLA_BINOP_PATTERN(Ge) -XLA_BINOP_PATTERN(Gt) -XLA_BINOP_PATTERN(Le) -XLA_BINOP_PATTERN(Lt) XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) -XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) @@ -2242,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce); XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); +// Helpers for comparison instructions. 
+#define XLA_COMPARE_PATTERN(NAME) \ + inline auto NAME()->decltype( \ + Op().WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } + +#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ + XLA_COMPARE_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ + } + 
+XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); +XLA_COMMUTATIVE_COMPARE_PATTERN(Ne); +XLA_COMPARE_PATTERN(Ge); +XLA_COMPARE_PATTERN(Gt); +XLA_COMPARE_PATTERN(Le); +XLA_COMPARE_PATTERN(Lt); + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 5c3c009a68b..cbe8c4a2410 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) { "in p0 = f32[] parameter(0)"); } +TEST(HloMatchersTest, Comparison) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_TRUE(Match(eq.get(), m::Compare())); + EXPECT_TRUE(Match(eq.get(), m::Eq())); + EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1)))); + EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0)))); + EXPECT_TRUE(Match(ne.get(), m::Compare())); + EXPECT_TRUE(Match(ne.get(), m::Ne())); + EXPECT_TRUE(Match( + le.get(), + m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0), + m::Add(m::Parameter(0), m::Parameter(1))))); + + EXPECT_FALSE(Match(eq.get(), m::Add())); + EXPECT_FALSE(Match(eq.get(), m::Ne())); + EXPECT_FALSE( + Match(le.get(), + m::Eq(m::Parameter(0), 
m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0)))); + EXPECT_DESC_AND_EXPLANATION( + eq, m::Ne().WithOneUser(), + "an HloInstruction:\n" + " * with opcode compare AND\n" + " * which has comparison direction NE AND\n" + " * which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction is not comparison NE\n" + "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), " + "direction=EQ"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index fdb5cd91fd0..e3a3feb8640 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -181,8 +181,9 @@ static StatusOr CheckIndexValidity( HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), AsInt64Slice(index->shape().dimensions())); - TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, - MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * negative_index_check, + MakeCompareHlo(ComparisonDirection::kLe, zero_index, index)); // Check if the index is OOB w.r.t. the operand dimensions and window sizes. std::vector max_valid_index(operand_dims.size()); @@ -193,9 +194,9 @@ static StatusOr CheckIndexValidity( HloInstruction * max_valid_index_constant, MakeR1ConstantHlo(computation, index->shape().element_type(), max_valid_index)); - TF_ASSIGN_OR_RETURN( - HloInstruction * oob_index_check, - MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check, + MakeCompareHlo(ComparisonDirection::kGe, + max_valid_index_constant, index)); // Combine the results of the two checks above. 
TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 53b5d18a065..b1e7d2bc727 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -988,12 +988,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: { + case HloOpcode::kCompare: { TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6f8cc6136bb..a9cab3f3e69 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -918,55 +918,10 @@ TEST_F(ShapeInferenceTest, InferPowShape) { ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } -TEST_F(ShapeInferenceTest, InferCompareShapeEq) { +TEST_F(ShapeInferenceTest, InferCompareShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGt) { - auto ten_floats = 
ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLt) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeNe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kCompare, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc index 696ac1b4658..284d5095277 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -39,7 +39,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } 
ENTRY sort_computation { @@ -73,7 +73,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { p.1.rhs = s32[] parameter(3) p.2.lhs = u32[] parameter(4) p.2.rhs = u32[] parameter(5) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -109,7 +109,7 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -134,7 +134,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs) + ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT } ENTRY sort_computation { diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc index 1aa7e5fe7c0..ae4ce32569a 100644 --- a/tensorflow/compiler/xla/service/stable_sort_expander.cc +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -180,13 +180,13 @@ StatusOr StableSortExpander::ExpandInstruction( CHECK_NE(cloned_root, nullptr); Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); HloInstruction* same = - comparator->AddInstruction(HloInstruction::CreateBinary( - scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, old_root, cloned_root, ComparisonDirection::kEq)); HloInstruction* tie_breaker = - comparator->AddInstruction(HloInstruction::CreateBinary( - scalar_pred, HloOpcode::kLt, - comparator->parameter_instruction(2 * iota_index), - comparator->parameter_instruction(2 * iota_index + 1))); + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, comparator->parameter_instruction(2 * iota_index), + 
comparator->parameter_instruction(2 * iota_index + 1), + ComparisonDirection::kLt)); HloInstruction* new_root = comparator->AddInstruction(HloInstruction::CreateTernary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc index a62d953e6e8..61fb4392a32 100644 --- a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -65,7 +65,8 @@ void CheckComputationHasTieBreaker(const HloInstruction* root, // the copied comparison function where the parameters are reversed. Lt() is // the tie breaker comparison using the Iota operand. ASSERT_EQ(root->opcode(), HloOpcode::kSelect); - ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare); + ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq); // Check that the tie breaker instruction is correct. 
EXPECT_THAT(root->operand(1), @@ -88,7 +89,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -126,15 +127,15 @@ TEST_F(StableSortExpanderTest, lhs.unsigned = u32[] bitcast-convert(p.0.lhs) lhs.flipped = u32[] subtract(max, lhs.unsigned) lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) - lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) rhs.signed = s32[] bitcast-convert(p.0.rhs) rhs.unsigned = u32[] bitcast-convert(p.0.rhs) rhs.flipped = u32[] subtract(max, rhs.unsigned) rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) - rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) - ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT } ENTRY sort_computation { @@ -165,7 +166,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -200,7 +201,7 @@ TEST_F(StableSortExpanderTest, HonorIsStableFlag) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -227,7 +228,7 @@ TEST_F(StableSortExpanderTest, p.0.rhs = f32[] 
parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -264,7 +265,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { p.0.rhs = f32[] parameter(1) p.1.lhs = f32[] parameter(2) p.1.rhs = f32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -302,7 +303,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1) { mask = s32[] constant(65535) lhs = s32[] and(p.0.lhs, mask) rhs = s32[] and(p.0.rhs, mask) - ROOT lt = pred[] less-than(lhs, rhs) + ROOT lt = pred[] compare(lhs, rhs), direction=LT } ENTRY sort_computation { @@ -332,7 +333,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { mask = s32[] constant(65535) lhs = s32[] and(p.0.lhs, mask) rhs = s32[] and(p.0.rhs, mask) - ROOT lt = pred[] less-than(lhs, rhs) + ROOT lt = pred[] compare(lhs, rhs), direction=LT } ENTRY sort_computation { diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 6f61fc44166..61b98673cbe 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -934,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -1185,8 +1185,8 @@ 
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index ac52a5fc2a4..ffa89b6a797 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -286,7 +286,7 @@ static optional PatternMatchLoopTripCount(HloInstruction* while_op, // Handle `i = K; i < N; ++i`. if (Match(while_cond_root, m::Op() - .WithOpcode(HloOpcode::kLt) + .WithComparisonDirection(ComparisonDirection::kLt) .WithOperand(0, m::Op().Is(while_cond_indvar)))) { VLOG(2) << "Pattern-match succeeded: loop condition is i < N: " << while_cond_root->ToString(); @@ -303,7 +303,7 @@ static optional PatternMatchLoopTripCount(HloInstruction* while_op, // Handle `i = K; i <= N; ++i`. 
if (Match(while_cond_root, m::Op() - .WithOpcode(HloOpcode::kLe) + .WithComparisonDirection(ComparisonDirection::kLe) .WithOperand(0, m::Op().Is(while_cond_indvar)))) { VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: " << while_cond_root->ToString(); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc index 1da0fbeac89..5a5dc742c03 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] less-than(gte, const) + ROOT result = pred[] compare(gte, const), direction=LT } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 3bcf5c38309..8ab5e433e0f 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -260,7 +260,7 @@ condition { p_cond = (f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), 
index=1 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -300,7 +300,7 @@ condition { p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 - ROOT result = pred[] less-than(p_c.0, p_c.1.1) + ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT } ENTRY entry { @@ -342,7 +342,7 @@ condition { p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -389,10 +389,10 @@ condition { p_cond = (f32[],f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.0 = pred[] less-than(p_cond.0, p_cond.2) + lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT ROOT result = pred[] and(lt.0, lt.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 3587c016b44..f0bb646d9c0 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { p_cond = (f32[2], f32[2], f32[2], s32[]) 
parameter(0) gte = s32[] get-tuple-element(p_cond), index=3 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index ecca76b1e86..65175fb6ab3 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant({{LOOP_BOUND}}) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 - ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT } ENTRY SimpleLoopWithIndirectLoopBound { constant.3 = s32[] constant(42) @@ -237,7 +237,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { NonTupleShapedLoop.condition { loop_var = s32[] parameter(0) constant = s32[] constant(100) - ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT } ENTRY INonTupleShapedLoop { constant.2 = s32[] constant(42) @@ -387,7 +387,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { param0 = (s32[], s32[], s32[]) parameter(0) 
get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), index=2 - ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) + ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ } ENTRY RemoveUnusedOperands { x = s32[] parameter(0) @@ -471,7 +471,7 @@ TEST_F(WhileLoopSimplifierTest, loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(44) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -503,7 +503,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(47) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -679,7 +679,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"( b = TYPE[] get-tuple-element(param), index=1 sum = TYPE[] power(a, b) ten = TYPE[] constant(10) - ROOT cond = pred[] less-than(sum, ten) + ROOT cond = pred[] compare(sum, ten), direction=LT } ENTRY Loop { a = TYPE[] constant(10) diff --git a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc index 5c19cbc015d..a1e18bbdef6 100644 --- a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc @@ -41,7 +41,7 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(10) - ROOT 
done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -77,7 +77,7 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -113,7 +113,7 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -149,7 +149,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than-or-equal-to(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LE } ENTRY test { @@ -188,7 +188,7 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) { param = (s64[]) parameter(0) i = s64[] get-tuple-element(param), index=0 trip_count = s64[] constant(9223372036854775807) // 2^63-1 - ROOT done = pred[] less-than-or-equal-to(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LE } ENTRY test { diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index d77386497a1..b6f65c763ea 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, TF_ASSIGN_OR_RETURN( HloInstruction * compare, - MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); return std::move(cond_computation); } diff --git 
a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index d98a644fd4d..189736effb1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = { class FusionTest : public HloTestBase { protected: template - void TestElementwise2D(HloOpcode opcode) { + void TestElementwise2D( + HloOpcode opcode, + absl::optional direction = absl::nullopt) { + // Create a variable for comparisons since they require the direction. + bool is_compare = std::is_same::value; Array2D operand_data[Arity]; for (int i = 0; i < Arity; ++i) { new (&operand_data[i]) Array2D(test_width, test_height); @@ -76,7 +80,11 @@ class FusionTest : public HloTestBase { xs[k] = test_float_vals[k][i][j]; operand_data[k](i, j) = xs[k]; } - answer_data(i, j) = ComputeElementwiseAnswer(opcode, xs); + if (is_compare) { + answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs); + } else { + answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs); + } } } @@ -98,8 +106,13 @@ class FusionTest : public HloTestBase { root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); break; case 2: - root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], - hlos[2]); + if (is_compare) { + root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1], + hlos[2], *direction); + } else { + root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], + hlos[2]); + } break; case 3: root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], @@ -124,13 +137,14 @@ class FusionTest : public HloTestBase { } private: - template - T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); + float ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs); + bool ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs); }; -template <> -float 
FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { +float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -153,24 +167,21 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, } } -template <> -bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { - switch (opcode) { - case HloOpcode::kEq: +bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs) { + switch (direction) { + case ComparisonDirection::kEq: return xs[0] == xs[1]; - case HloOpcode::kNe: + case ComparisonDirection::kNe: return xs[0] != xs[1]; - case HloOpcode::kGt: + case ComparisonDirection::kGt: return xs[0] > xs[1]; - case HloOpcode::kLt: + case ComparisonDirection::kLt: return xs[0] < xs[1]; - case HloOpcode::kGe: + case ComparisonDirection::kGe: return xs[0] >= xs[1]; - case HloOpcode::kLe: + case ComparisonDirection::kLe: return xs[0] <= xs[1]; - default: - LOG(FATAL) << "No comparatory opcode: " << opcode; } } @@ -740,24 +751,28 @@ XLA_TEST_F(FusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } +XLA_TEST_F(FusionTest, Equal2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kEq); +} XLA_TEST_F(FusionTest, Inequal2D) { - TestElementwise2D(HloOpcode::kNe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kNe); } XLA_TEST_F(FusionTest, Greater2D) { - TestElementwise2D(HloOpcode::kGt); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGt); } -XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kLt); } +XLA_TEST_F(FusionTest, Lesser2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLt); +} XLA_TEST_F(FusionTest, GreaterOrEqual2D) { - TestElementwise2D(HloOpcode::kGe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGe); } 
XLA_TEST_F(FusionTest, LesserOrEqual2D) { - TestElementwise2D(HloOpcode::kLe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLe); } XLA_TEST_F(FusionTest, Clamp2D) { diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 1fd9cb055c0..73c9d7ed4b0 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -227,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { fused_computation { p = f32[4] parameter(0) multiply = f32[4] multiply(p, p) - less-than = pred[4] less-than(p, multiply) + less-than = pred[4] compare(p, multiply), direction=LT ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) } @@ -252,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { fused_computation { p = f32[] parameter(0) multiply = f32[] multiply(p, p) - less-than = pred[] less-than(p, multiply) + less-than = pred[] compare(p, multiply), direction=LT ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index f68ee04565f..4337aa4bf9a 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -143,7 +143,7 @@ compare { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { @@ -174,7 +174,7 @@ compare { p.0.rhs = s32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: 
s32[1048576]) -> (s32[1048576], s32[1048576]) { @@ -205,7 +205,7 @@ compare { p.0.rhs = bf16[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b77cf38ed8e..38a2a9b8fba 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -129,7 +129,7 @@ HloModule TokenInWhileLoop %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokenInWhileLoop () -> s32[] {