diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 28f76829f14..171afa42351 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -234,6 +234,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 9d2d190eea9..33506535ddf 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { @@ -3350,6 +3352,11 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { } if (!need_rewrite) { + if (opcode == HloOpcode::kCompare) { + CHECK(!instr_proto->comparison_type().empty()); + new_instr->set_comparison_type( + ComparisonTypeToString(Comparison::DefaultComparisonType(PRED))); + } *new_instr->mutable_name() = GetFullName(instr_proto->opcode(), kNameSeparator, id); return Status::OK(); @@ -4009,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs, return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } +static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection comparison_direction) { + auto b = lhs.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs)); + auto operand_element_type = operand_shape.element_type(); + auto compare_type = + primitive_util::IsFloatingPointType(operand_element_type) + ? Comparison::Type::kFloatTotalOrder + : Comparison::DefaultComparisonType(operand_element_type); + return Compare(lhs, rhs, broadcast_dimensions, comparison_direction, + compare_type); + }); +} + XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kEq); } XlaOp Ne(const XlaOp lhs, const XlaOp rhs, @@ -4023,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs, XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kNe); } XlaOp Ge(const XlaOp lhs, const XlaOp rhs, @@ -4035,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs, XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kGe); } XlaOp Gt(const XlaOp lhs, const XlaOp rhs, @@ -4047,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs, XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kGt); } XlaOp Le(const XlaOp lhs, const XlaOp rhs, @@ -4059,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs, XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - auto compare_type = Comparison::Type::kFloatTotalOrder; - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe, - compare_type); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kLe); } + XlaOp Lt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); @@ -4070,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs, XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt, - Comparison::Type::kFloatTotalOrder); + return CompareTotalOrder(lhs, rhs, broadcast_dimensions, + ComparisonDirection::kLt); } XlaOp Compare(const XlaOp lhs, const XlaOp rhs, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b1034bf6ae5..5fe89e84995 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3721,6 +3721,7 @@ cc_library( ":hlo_casting_utils", ":hlo_pass", ":shape_inference", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index f2ea03f063a..e43f68fd257 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8c66d00cd85..84e4fe6e3fd 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -1746,6 +1747,31 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape)); } } + if (auto* comparison = DynCast(instruction)) { + const Shape& operand_shape = comparison->operand(1)->shape(); + PrimitiveType operand_element_type = operand_shape.element_type(); + Comparison::Type default_comparison_type = + Comparison::DefaultComparisonType(operand_element_type); + if (primitive_util::IsFloatingPointType(operand_element_type)) { + if (comparison->type() != Comparison::Type::kFloat && + comparison->type() != Comparison::Type::kFloatTotalOrder) { + return FailedPrecondition( + "Expected comparison type %s or %s.\n" + "actual: %s\noperand: %s\n", + ComparisonTypeToString(Comparison::Type::kFloat), + ComparisonTypeToString(Comparison::Type::kFloatTotalOrder), + ComparisonTypeToString(comparison->type()), + ShapeUtil::HumanString(operand_shape)); + } + } else if (comparison->type() != default_comparison_type) { + return FailedPrecondition( + "Expected comparison type %s.\n" + "actual: %s\noperand: %s\n", + ComparisonTypeToString(default_comparison_type), + ComparisonTypeToString(comparison->type()), + ShapeUtil::HumanString(operand_shape)); + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0df30166a1c..c6c09e3dee1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1220,5 +1220,77 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { "needs to be collective-permute-start, found tuple")); } +TEST_F(HloVerifierTest, ComparisonTypeFloat) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = f32[] parameter(0) + ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type FLOAT or TOTALORDER")); +} + +TEST_F(HloVerifierTest, ComparisonTypeSigned) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = s32[] parameter(0) + ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type SIGNED")); +} + +TEST_F(HloVerifierTest, ComparisonTypeUnsigned) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = u32[] parameter(0) + ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type UNSIGNED")); +} + +TEST_F(HloVerifierTest, ComparisonTypePred) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + p0 = pred[] parameter(0) + ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected comparison type UNSIGNED")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index fe27a8c6963..69916f6abc5 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -554,7 +554,7 @@ ENTRY jit_broken.874 { abs.129 = f32[4]{0} abs(subtract.126) constant.130 = f32[] constant(inf) broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={} - compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED + compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ not.133 = pred[4]{0} not(compare.132) and.134 = pred[4]{0} and(not.128, not.133) add.135 = f32[4]{0} add(add.124, add.89) @@ -577,7 +577,7 @@ ENTRY jit_broken.874 { abs.219 = f32[4]{0} abs(subtract.216) constant.220 = f32[] constant(inf) broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={} - compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED + compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ not.223 = pred[4]{0} not(compare.222) and.224 = pred[4]{0} and(not.218, not.223) add.225 = f32[4]{0} add(add.214, add.179) @@ -600,7 +600,7 @@ ENTRY jit_broken.874 { abs.309 = f32[4]{0} abs(subtract.306) constant.310 = f32[] constant(inf) broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={} - compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED + compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ not.313 = pred[4]{0} not(compare.312) and.314 = pred[4]{0} and(not.308, not.313) add.315 = f32[4]{0} add(add.304, add.269) @@ -623,7 +623,7 @@ ENTRY jit_broken.874 { abs.399 = f32[4]{0} abs(subtract.396) constant.400 = f32[] constant(inf) broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={} - compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED + compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ not.403 = pred[4]{0} not(compare.402) and.404 = pred[4]{0} and(not.398, not.403) add.405 = f32[4]{0} add(add.394, add.359) @@ -646,7 +646,7 @@ ENTRY jit_broken.874 { abs.489 = f32[4]{0} abs(subtract.486) constant.490 = f32[] constant(inf) broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={} - compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED + compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ not.493 = pred[4]{0} not(compare.492) and.494 = pred[4]{0} and(not.488, not.493) add.495 = f32[4]{0} add(add.484, add.449) @@ -669,7 +669,7 @@ ENTRY jit_broken.874 { abs.579 = f32[4]{0} abs(subtract.576) constant.580 = f32[] constant(inf) broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={} - compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED + compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ not.583 = pred[4]{0} not(compare.582) and.584 = pred[4]{0} and(not.578, not.583) add.585 = f32[4]{0} add(add.574, add.539) @@ -692,7 +692,7 @@ ENTRY jit_broken.874 { abs.669 = f32[4]{0} abs(subtract.666) constant.670 = f32[] constant(inf) broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={} - compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED + compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ not.673 = pred[4]{0} not(compare.672) and.674 = pred[4]{0} and(not.668, not.673) add.675 = f32[4]{0} add(add.664, add.629) @@ -715,7 +715,7 @@ ENTRY jit_broken.874 { abs.759 = f32[4]{0} abs(subtract.756) constant.760 = f32[] constant(inf) broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={} - compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED + compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ not.763 = pred[4]{0} not(compare.762) and.764 = pred[4]{0} and(not.758, not.763) add.765 = f32[4]{0} add(add.754, add.719) @@ -738,7 +738,7 @@ ENTRY jit_broken.874 { abs.849 = f32[4]{0} abs(subtract.846) constant.850 = f32[] constant(inf) broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={} - compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED + compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ not.853 = pred[4]{0} not(compare.852) and.854 = pred[4]{0} and(not.848, not.853) add.855 = f32[4]{0} add(add.844, add.809)