[XLA] Verify the comparison type for comparisons
PiperOrigin-RevId: 343577367 Change-Id: Ic39711697188f4927d0b23eb7f4b6cda6a14630a
This commit is contained in:
parent
1ef1b443fa
commit
91c1504965
@ -234,6 +234,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:shape_inference",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -46,6 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -3350,6 +3352,11 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
|
||||
}
|
||||
|
||||
if (!need_rewrite) {
|
||||
if (opcode == HloOpcode::kCompare) {
|
||||
CHECK(!instr_proto->comparison_type().empty());
|
||||
new_instr->set_comparison_type(
|
||||
ComparisonTypeToString(Comparison::DefaultComparisonType(PRED)));
|
||||
}
|
||||
*new_instr->mutable_name() =
|
||||
GetFullName(instr_proto->opcode(), kNameSeparator, id);
|
||||
return Status::OK();
|
||||
@ -4009,11 +4016,26 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
||||
}
|
||||
|
||||
static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection comparison_direction) {
|
||||
auto b = lhs.builder();
|
||||
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto operand_shape, b->GetShape(lhs));
|
||||
auto operand_element_type = operand_shape.element_type();
|
||||
auto compare_type =
|
||||
primitive_util::IsFloatingPointType(operand_element_type)
|
||||
? Comparison::Type::kFloatTotalOrder
|
||||
: Comparison::DefaultComparisonType(operand_element_type);
|
||||
return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
|
||||
compare_type);
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
|
||||
compare_type);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kEq);
|
||||
}
|
||||
|
||||
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4023,9 +4045,8 @@ XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
|
||||
|
||||
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
|
||||
compare_type);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kNe);
|
||||
}
|
||||
|
||||
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4035,9 +4056,8 @@ XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
|
||||
|
||||
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
|
||||
compare_type);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kGe);
|
||||
}
|
||||
|
||||
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4047,9 +4067,8 @@ XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
|
||||
|
||||
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
|
||||
compare_type);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kGt);
|
||||
}
|
||||
|
||||
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4059,10 +4078,10 @@ XlaOp Le(const XlaOp lhs, const XlaOp rhs,
|
||||
|
||||
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
|
||||
compare_type);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kLe);
|
||||
}
|
||||
|
||||
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
|
||||
@ -4070,8 +4089,8 @@ XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
|
||||
|
||||
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
|
||||
Comparison::Type::kFloatTotalOrder);
|
||||
return CompareTotalOrder(lhs, rhs, broadcast_dimensions,
|
||||
ComparisonDirection::kLt);
|
||||
}
|
||||
|
||||
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
||||
|
@ -3721,6 +3721,7 @@ cc_library(
|
||||
":hlo_casting_utils",
|
||||
":hlo_pass",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
@ -1746,6 +1747,31 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
|
||||
ShapeUtil::HumanString(operand_shape));
|
||||
}
|
||||
}
|
||||
if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
|
||||
const Shape& operand_shape = comparison->operand(1)->shape();
|
||||
PrimitiveType operand_element_type = operand_shape.element_type();
|
||||
Comparison::Type default_comparison_type =
|
||||
Comparison::DefaultComparisonType(operand_element_type);
|
||||
if (primitive_util::IsFloatingPointType(operand_element_type)) {
|
||||
if (comparison->type() != Comparison::Type::kFloat &&
|
||||
comparison->type() != Comparison::Type::kFloatTotalOrder) {
|
||||
return FailedPrecondition(
|
||||
"Expected comparison type %s or %s.\n"
|
||||
"actual: %s\noperand: %s\n",
|
||||
ComparisonTypeToString(Comparison::Type::kFloat),
|
||||
ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
|
||||
ComparisonTypeToString(comparison->type()),
|
||||
ShapeUtil::HumanString(operand_shape));
|
||||
}
|
||||
} else if (comparison->type() != default_comparison_type) {
|
||||
return FailedPrecondition(
|
||||
"Expected comparison type %s.\n"
|
||||
"actual: %s\noperand: %s\n",
|
||||
ComparisonTypeToString(default_comparison_type),
|
||||
ComparisonTypeToString(comparison->type()),
|
||||
ShapeUtil::HumanString(operand_shape));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1220,5 +1220,77 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
|
||||
"needs to be collective-permute-start, found tuple"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY RngOperandElementTypesNotMatch {
|
||||
p0 = f32[] parameter(0)
|
||||
ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY RngOperandElementTypesNotMatch {
|
||||
p0 = s32[] parameter(0)
|
||||
ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Expected comparison type SIGNED"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY RngOperandElementTypesNotMatch {
|
||||
p0 = u32[] parameter(0)
|
||||
ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Expected comparison type UNSIGNED"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, ComparisonTypePred) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY RngOperandElementTypesNotMatch {
|
||||
p0 = pred[] parameter(0)
|
||||
ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Expected comparison type UNSIGNED"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -554,7 +554,7 @@ ENTRY jit_broken.874 {
|
||||
abs.129 = f32[4]{0} abs(subtract.126)
|
||||
constant.130 = f32[] constant(inf)
|
||||
broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
|
||||
compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
|
||||
compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
|
||||
not.133 = pred[4]{0} not(compare.132)
|
||||
and.134 = pred[4]{0} and(not.128, not.133)
|
||||
add.135 = f32[4]{0} add(add.124, add.89)
|
||||
@ -577,7 +577,7 @@ ENTRY jit_broken.874 {
|
||||
abs.219 = f32[4]{0} abs(subtract.216)
|
||||
constant.220 = f32[] constant(inf)
|
||||
broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
|
||||
compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
|
||||
compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
|
||||
not.223 = pred[4]{0} not(compare.222)
|
||||
and.224 = pred[4]{0} and(not.218, not.223)
|
||||
add.225 = f32[4]{0} add(add.214, add.179)
|
||||
@ -600,7 +600,7 @@ ENTRY jit_broken.874 {
|
||||
abs.309 = f32[4]{0} abs(subtract.306)
|
||||
constant.310 = f32[] constant(inf)
|
||||
broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
|
||||
compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
|
||||
compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
|
||||
not.313 = pred[4]{0} not(compare.312)
|
||||
and.314 = pred[4]{0} and(not.308, not.313)
|
||||
add.315 = f32[4]{0} add(add.304, add.269)
|
||||
@ -623,7 +623,7 @@ ENTRY jit_broken.874 {
|
||||
abs.399 = f32[4]{0} abs(subtract.396)
|
||||
constant.400 = f32[] constant(inf)
|
||||
broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
|
||||
compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
|
||||
compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
|
||||
not.403 = pred[4]{0} not(compare.402)
|
||||
and.404 = pred[4]{0} and(not.398, not.403)
|
||||
add.405 = f32[4]{0} add(add.394, add.359)
|
||||
@ -646,7 +646,7 @@ ENTRY jit_broken.874 {
|
||||
abs.489 = f32[4]{0} abs(subtract.486)
|
||||
constant.490 = f32[] constant(inf)
|
||||
broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
|
||||
compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
|
||||
compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
|
||||
not.493 = pred[4]{0} not(compare.492)
|
||||
and.494 = pred[4]{0} and(not.488, not.493)
|
||||
add.495 = f32[4]{0} add(add.484, add.449)
|
||||
@ -669,7 +669,7 @@ ENTRY jit_broken.874 {
|
||||
abs.579 = f32[4]{0} abs(subtract.576)
|
||||
constant.580 = f32[] constant(inf)
|
||||
broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
|
||||
compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
|
||||
compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
|
||||
not.583 = pred[4]{0} not(compare.582)
|
||||
and.584 = pred[4]{0} and(not.578, not.583)
|
||||
add.585 = f32[4]{0} add(add.574, add.539)
|
||||
@ -692,7 +692,7 @@ ENTRY jit_broken.874 {
|
||||
abs.669 = f32[4]{0} abs(subtract.666)
|
||||
constant.670 = f32[] constant(inf)
|
||||
broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
|
||||
compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
|
||||
compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
|
||||
not.673 = pred[4]{0} not(compare.672)
|
||||
and.674 = pred[4]{0} and(not.668, not.673)
|
||||
add.675 = f32[4]{0} add(add.664, add.629)
|
||||
@ -715,7 +715,7 @@ ENTRY jit_broken.874 {
|
||||
abs.759 = f32[4]{0} abs(subtract.756)
|
||||
constant.760 = f32[] constant(inf)
|
||||
broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
|
||||
compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
|
||||
compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
|
||||
not.763 = pred[4]{0} not(compare.762)
|
||||
and.764 = pred[4]{0} and(not.758, not.763)
|
||||
add.765 = f32[4]{0} add(add.754, add.719)
|
||||
@ -738,7 +738,7 @@ ENTRY jit_broken.874 {
|
||||
abs.849 = f32[4]{0} abs(subtract.846)
|
||||
constant.850 = f32[] constant(inf)
|
||||
broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
|
||||
compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
|
||||
compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
|
||||
not.853 = pred[4]{0} not(compare.852)
|
||||
and.854 = pred[4]{0} and(not.848, not.853)
|
||||
add.855 = f32[4]{0} add(add.844, add.809)
|
||||
|
Loading…
x
Reference in New Issue
Block a user