[XLA] Verify the comparison type for comparisons

PiperOrigin-RevId: 343577367
Change-Id: Ic39711697188f4927d0b23eb7f4b6cda6a14630a
This commit is contained in:
David Majnemer 2020-11-20 15:34:50 -08:00 committed by TensorFlower Gardener
parent 1ef1b443fa
commit 91c1504965
7 changed files with 146 additions and 26 deletions

View File

@ -234,6 +234,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -46,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
@ -3350,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
}
if (!need_rewrite) {
if (opcode == HloOpcode::kCompare) {
CHECK(!instr_proto->comparison_type().empty());
new_instr->set_comparison_type(
ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
}
*new_instr->mutable_name() =
GetFullName(instr_proto->opcode(), kNameSeparator, id);
return Status::OK();
@ -4009,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}
static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection comparison_direction) {
auto b = lhs.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
auto operand_element_type = operand_shape.element_type();
auto compare_type =
primitive_util::IsFloatingPointType(operand_element_type)
? Comparison::Type::kFloatTotalOrder
: Comparison::DefaultComparisonType(operand_element_type);
return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
compare_type);
});
}
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
compare_type);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kEq);
}
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
@ -4023,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
compare_type);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kNe);
}
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
@ -4035,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
compare_type);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kGe);
}
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
@ -4047,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
compare_type);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kGt);
}
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
@ -4059,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs,
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
auto compare_type = Comparison::Type::kFloatTotalOrder;
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
compare_type);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kLe);
}
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
@ -4070,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> broadcast_dimensions) {
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
Comparison::Type::kFloatTotalOrder);
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
ComparisonDirection::kLt);
}
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,

View File

@ -3721,6 +3721,7 @@ cc_library(
":hlo_casting_utils",
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@ -1746,6 +1747,31 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
ShapeUtil::HumanString(operand_shape));
}
}
if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
const Shape& operand_shape = comparison->operand(1)->shape();
PrimitiveType operand_element_type = operand_shape.element_type();
Comparison::Type default_comparison_type =
Comparison::DefaultComparisonType(operand_element_type);
if (primitive_util::IsFloatingPointType(operand_element_type)) {
if (comparison->type() != Comparison::Type::kFloat &&
comparison->type() != Comparison::Type::kFloatTotalOrder) {
return FailedPrecondition(
"Expected comparison type %s or %s.\n"
"actual: %s\noperand: %s\n",
ComparisonTypeToString(Comparison::Type::kFloat),
ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
ComparisonTypeToString(comparison->type()),
ShapeUtil::HumanString(operand_shape));
}
} else if (comparison->type() != default_comparison_type) {
return FailedPrecondition(
"Expected comparison type %s.\n"
"actual: %s\noperand: %s\n",
ComparisonTypeToString(default_comparison_type),
ComparisonTypeToString(comparison->type()),
ShapeUtil::HumanString(operand_shape));
}
}
return Status::OK();
}

View File

@ -1220,5 +1220,77 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
"needs to be collective-permute-start, found tuple"));
}
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
const char* const hlo_string = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = f32[] parameter(0)
ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
const char* const hlo_string = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = s32[] parameter(0)
ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected comparison type SIGNED"));
}
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
const char* const hlo_string = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = u32[] parameter(0)
ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected comparison type UNSIGNED"));
}
TEST_F(HloVerifierTest, ComparisonTypePred) {
const char* const hlo_string = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = pred[] parameter(0)
ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Expected comparison type UNSIGNED"));
}
} // namespace
} // namespace xla

View File

@ -554,7 +554,7 @@ ENTRY jit_broken.874 {
abs.129 = f32[4]{0} abs(subtract.126)
constant.130 = f32[] constant(inf)
broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
not.133 = pred[4]{0} not(compare.132)
and.134 = pred[4]{0} and(not.128, not.133)
add.135 = f32[4]{0} add(add.124, add.89)
@ -577,7 +577,7 @@ ENTRY jit_broken.874 {
abs.219 = f32[4]{0} abs(subtract.216)
constant.220 = f32[] constant(inf)
broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
not.223 = pred[4]{0} not(compare.222)
and.224 = pred[4]{0} and(not.218, not.223)
add.225 = f32[4]{0} add(add.214, add.179)
@ -600,7 +600,7 @@ ENTRY jit_broken.874 {
abs.309 = f32[4]{0} abs(subtract.306)
constant.310 = f32[] constant(inf)
broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
not.313 = pred[4]{0} not(compare.312)
and.314 = pred[4]{0} and(not.308, not.313)
add.315 = f32[4]{0} add(add.304, add.269)
@ -623,7 +623,7 @@ ENTRY jit_broken.874 {
abs.399 = f32[4]{0} abs(subtract.396)
constant.400 = f32[] constant(inf)
broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
not.403 = pred[4]{0} not(compare.402)
and.404 = pred[4]{0} and(not.398, not.403)
add.405 = f32[4]{0} add(add.394, add.359)
@ -646,7 +646,7 @@ ENTRY jit_broken.874 {
abs.489 = f32[4]{0} abs(subtract.486)
constant.490 = f32[] constant(inf)
broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
not.493 = pred[4]{0} not(compare.492)
and.494 = pred[4]{0} and(not.488, not.493)
add.495 = f32[4]{0} add(add.484, add.449)
@ -669,7 +669,7 @@ ENTRY jit_broken.874 {
abs.579 = f32[4]{0} abs(subtract.576)
constant.580 = f32[] constant(inf)
broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
not.583 = pred[4]{0} not(compare.582)
and.584 = pred[4]{0} and(not.578, not.583)
add.585 = f32[4]{0} add(add.574, add.539)
@ -692,7 +692,7 @@ ENTRY jit_broken.874 {
abs.669 = f32[4]{0} abs(subtract.666)
constant.670 = f32[] constant(inf)
broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
not.673 = pred[4]{0} not(compare.672)
and.674 = pred[4]{0} and(not.668, not.673)
add.675 = f32[4]{0} add(add.664, add.629)
@ -715,7 +715,7 @@ ENTRY jit_broken.874 {
abs.759 = f32[4]{0} abs(subtract.756)
constant.760 = f32[] constant(inf)
broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
not.763 = pred[4]{0} not(compare.762)
and.764 = pred[4]{0} and(not.758, not.763)
add.765 = f32[4]{0} add(add.754, add.719)
@ -738,7 +738,7 @@ ENTRY jit_broken.874 {
abs.849 = f32[4]{0} abs(subtract.846)
constant.850 = f32[] constant(inf)
broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
not.853 = pred[4]{0} not(compare.852)
and.854 = pred[4]{0} and(not.848, not.853)
add.855 = f32[4]{0} add(add.844, add.809)