[XLA] Add commas between 0s and 1s when printing PRED literals.

And support parsing array shaped pred literals.

PiperOrigin-RevId: 219889665
This commit is contained in:
Cong Liu 2018-11-02 17:39:51 -07:00 committed by TensorFlower Gardener
parent 5f915f4dc5
commit 4d83992d1e
6 changed files with 42 additions and 28 deletions

View File

@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
PrimitiveType element_type = subshape.element_type();
if (element_type == PRED) {
// We display predicates in a densely packed form.
return literal.Get<bool>(indices, shape_index) ? "1" : "0";
}
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
literal.GetAsString(indices, shape_index);
// We display predicates as 0s and 1s so that the string is more dense.
string elem = element_type == PRED
? literal.Get<bool>(indices, shape_index) ? "1" : "0"
: literal.GetAsString(indices, shape_index);
return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem;
};
if (ShapeUtil::Rank(subshape) == 0) {

View File

@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
EXPECT_EQ("{101}", pred_vec.ToString());
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {

View File

@ -1806,6 +1806,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value,
case U64:
return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
literal);
case PRED:
// Bool type literals with rank >= 1 are printed in 0s and 1s.
return SetValueInLiteralHelper<bool>(static_cast<bool>(value),
linear_index, literal);
default:
LOG(FATAL) << "unknown integral primitive type "
<< PrimitiveType_Name(shape.element_type());
@ -2060,14 +2064,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
}
if (lexer_.GetKind() == TokKind::kw_true ||
lexer_.GetKind() == TokKind::kw_false) {
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
linear_index++, literal)) {
return false;
}
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type())) {
} else if (primitive_util::IsIntegralType(shape.element_type()) ||
shape.element_type() == PRED) {
LocTy loc = lexer_.GetLoc();
tensorflow::int64 value;
if (!ParseInt64(&value)) {

View File

@ -75,6 +75,18 @@ ENTRY %constant_pred () -> pred[] {
)"
},
// pred array constant
{
"ConstantPredArray",
R"(HloModule module
ENTRY %constant_pred_array () -> pred[2,3] {
ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } })
}
)"
},
// s32 constant
{
"ConstantS32",

View File

@ -2478,8 +2478,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
Ne(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,2] {
{ 00 },
{ 01 }
{ 0, 0 },
{ 0, 1 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
@ -2492,8 +2492,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
Ge(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 1100 },
{ 0001 }
{ 1, 1, 0, 0 },
{ 0, 0, 0, 1 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
@ -2506,8 +2506,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
Gt(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 0100 },
{ 0000 }
{ 0, 1, 0, 0 },
{ 0, 0, 0, 0 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
@ -2520,8 +2520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
Le(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 1011 },
{ 1111 }
{ 1, 0, 1, 1 },
{ 1, 1, 1, 1 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
@ -2534,8 +2534,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
Lt(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 0011 },
{ 1110 }
{ 0, 0, 1, 1 },
{ 1, 1, 1, 0 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
@ -2744,12 +2744,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
Array3D<int> expected_3d(
{{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
const string expected = R"(pred[2,3,2] {
{ { 01 },
{ 00 },
{ 00 } },
{ { 01 },
{ 10 },
{ 01 } }
{ { 0, 1 },
{ 0, 0 },
{ 0, 0 } },
{ { 0, 1 },
{ 1, 0 },
{ 0, 1 } }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

View File

@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) {
XlaBuilder builder(TestName());
ConstantR2<bool>(&builder, {{false, true, true}, {true, false, false}});
const string expected = R"(pred[2,3] {
{ 011 },
{ 100 }
{ 0, 1, 1 },
{ 1, 0, 0 }
})";
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}