[XLA] Add commas between 0s and 1s when printing PRED literals.
And support parsing array shaped pred literals. PiperOrigin-RevId: 219889665
This commit is contained in:
parent
5f915f4dc5
commit
4d83992d1e
@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
|
||||
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
|
||||
PrimitiveType element_type = subshape.element_type();
|
||||
if (element_type == PRED) {
|
||||
// We display predicates in a densely packed form.
|
||||
return literal.Get<bool>(indices, shape_index) ? "1" : "0";
|
||||
}
|
||||
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
|
||||
literal.GetAsString(indices, shape_index);
|
||||
// We display predicates as 0s and 1s so that the string is more dense.
|
||||
string elem = element_type == PRED
|
||||
? literal.Get<bool>(indices, shape_index) ? "1" : "0"
|
||||
: literal.GetAsString(indices, shape_index);
|
||||
return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem;
|
||||
};
|
||||
|
||||
if (ShapeUtil::Rank(subshape) == 0) {
|
||||
|
||||
@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
||||
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
|
||||
EXPECT_EQ("{101}", pred_vec.ToString());
|
||||
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, R2ToString) {
|
||||
|
||||
@ -1806,6 +1806,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value,
|
||||
case U64:
|
||||
return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
|
||||
literal);
|
||||
case PRED:
|
||||
// Bool type literals with rank >= 1 are printed in 0s and 1s.
|
||||
return SetValueInLiteralHelper<bool>(static_cast<bool>(value),
|
||||
linear_index, literal);
|
||||
default:
|
||||
LOG(FATAL) << "unknown integral primitive type "
|
||||
<< PrimitiveType_Name(shape.element_type());
|
||||
@ -2060,14 +2064,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
|
||||
}
|
||||
if (lexer_.GetKind() == TokKind::kw_true ||
|
||||
lexer_.GetKind() == TokKind::kw_false) {
|
||||
// TODO(congliu): bool type literals with rank >= 1 are actually
|
||||
// printed in a compact form instead of "true" or "false". Fix that.
|
||||
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
|
||||
linear_index++, literal)) {
|
||||
return false;
|
||||
}
|
||||
lexer_.Lex();
|
||||
} else if (primitive_util::IsIntegralType(shape.element_type())) {
|
||||
} else if (primitive_util::IsIntegralType(shape.element_type()) ||
|
||||
shape.element_type() == PRED) {
|
||||
LocTy loc = lexer_.GetLoc();
|
||||
tensorflow::int64 value;
|
||||
if (!ParseInt64(&value)) {
|
||||
|
||||
@ -75,6 +75,18 @@ ENTRY %constant_pred () -> pred[] {
|
||||
|
||||
)"
|
||||
},
|
||||
// pred array constant
|
||||
{
|
||||
"ConstantPredArray",
|
||||
R"(HloModule module
|
||||
|
||||
ENTRY %constant_pred_array () -> pred[2,3] {
|
||||
ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } })
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
|
||||
// s32 constant
|
||||
{
|
||||
"ConstantS32",
|
||||
|
||||
@ -2478,8 +2478,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
|
||||
Ne(v, m, /*broadcast_dimensions=*/{1});
|
||||
|
||||
const string expected = R"(pred[2,2] {
|
||||
{ 00 },
|
||||
{ 01 }
|
||||
{ 0, 0 },
|
||||
{ 0, 1 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
@ -2492,8 +2492,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
|
||||
Ge(v, m, /*broadcast_dimensions=*/{1});
|
||||
|
||||
const string expected = R"(pred[2,4] {
|
||||
{ 1100 },
|
||||
{ 0001 }
|
||||
{ 1, 1, 0, 0 },
|
||||
{ 0, 0, 0, 1 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
@ -2506,8 +2506,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
|
||||
Gt(v, m, /*broadcast_dimensions=*/{1});
|
||||
|
||||
const string expected = R"(pred[2,4] {
|
||||
{ 0100 },
|
||||
{ 0000 }
|
||||
{ 0, 1, 0, 0 },
|
||||
{ 0, 0, 0, 0 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
@ -2520,8 +2520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
|
||||
Le(v, m, /*broadcast_dimensions=*/{1});
|
||||
|
||||
const string expected = R"(pred[2,4] {
|
||||
{ 1011 },
|
||||
{ 1111 }
|
||||
{ 1, 0, 1, 1 },
|
||||
{ 1, 1, 1, 1 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
@ -2534,8 +2534,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
|
||||
Lt(v, m, /*broadcast_dimensions=*/{1});
|
||||
|
||||
const string expected = R"(pred[2,4] {
|
||||
{ 0011 },
|
||||
{ 1110 }
|
||||
{ 0, 0, 1, 1 },
|
||||
{ 1, 1, 1, 0 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
@ -2744,12 +2744,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
|
||||
Array3D<int> expected_3d(
|
||||
{{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
|
||||
const string expected = R"(pred[2,3,2] {
|
||||
{ { 01 },
|
||||
{ 00 },
|
||||
{ 00 } },
|
||||
{ { 01 },
|
||||
{ 10 },
|
||||
{ 01 } }
|
||||
{ { 0, 1 },
|
||||
{ 0, 0 },
|
||||
{ 0, 0 } },
|
||||
{ { 0, 1 },
|
||||
{ 1, 0 },
|
||||
{ 0, 1 } }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
|
||||
@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) {
|
||||
XlaBuilder builder(TestName());
|
||||
ConstantR2<bool>(&builder, {{false, true, true}, {true, false, false}});
|
||||
const string expected = R"(pred[2,3] {
|
||||
{ 011 },
|
||||
{ 100 }
|
||||
{ 0, 1, 1 },
|
||||
{ 1, 0, 0 }
|
||||
})";
|
||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user