[XLA] Add commas between 0s and 1s when printing PRED literals.
And support parsing array shaped pred literals. PiperOrigin-RevId: 219889665
This commit is contained in:
parent
5f915f4dc5
commit
4d83992d1e
@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
|||||||
|
|
||||||
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
|
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
|
||||||
PrimitiveType element_type = subshape.element_type();
|
PrimitiveType element_type = subshape.element_type();
|
||||||
if (element_type == PRED) {
|
// We display predicates as 0s and 1s so that the string is more dense.
|
||||||
// We display predicates in a densely packed form.
|
string elem = element_type == PRED
|
||||||
return literal.Get<bool>(indices, shape_index) ? "1" : "0";
|
? literal.Get<bool>(indices, shape_index) ? "1" : "0"
|
||||||
}
|
: literal.GetAsString(indices, shape_index);
|
||||||
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
|
return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem;
|
||||||
literal.GetAsString(indices, shape_index);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (ShapeUtil::Rank(subshape) == 0) {
|
if (ShapeUtil::Rank(subshape) == 0) {
|
||||||
|
|||||||
@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
|||||||
|
|
||||||
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
||||||
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
|
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
|
||||||
EXPECT_EQ("{101}", pred_vec.ToString());
|
EXPECT_EQ("{1, 0, 1}", pred_vec.ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, R2ToString) {
|
TEST_F(LiteralUtilTest, R2ToString) {
|
||||||
|
|||||||
@ -1806,6 +1806,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value,
|
|||||||
case U64:
|
case U64:
|
||||||
return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
|
return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
|
||||||
literal);
|
literal);
|
||||||
|
case PRED:
|
||||||
|
// Bool type literals with rank >= 1 are printed in 0s and 1s.
|
||||||
|
return SetValueInLiteralHelper<bool>(static_cast<bool>(value),
|
||||||
|
linear_index, literal);
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unknown integral primitive type "
|
LOG(FATAL) << "unknown integral primitive type "
|
||||||
<< PrimitiveType_Name(shape.element_type());
|
<< PrimitiveType_Name(shape.element_type());
|
||||||
@ -2060,14 +2064,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
|
|||||||
}
|
}
|
||||||
if (lexer_.GetKind() == TokKind::kw_true ||
|
if (lexer_.GetKind() == TokKind::kw_true ||
|
||||||
lexer_.GetKind() == TokKind::kw_false) {
|
lexer_.GetKind() == TokKind::kw_false) {
|
||||||
// TODO(congliu): bool type literals with rank >= 1 are actually
|
|
||||||
// printed in a compact form instead of "true" or "false". Fix that.
|
|
||||||
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
|
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
|
||||||
linear_index++, literal)) {
|
linear_index++, literal)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
lexer_.Lex();
|
lexer_.Lex();
|
||||||
} else if (primitive_util::IsIntegralType(shape.element_type())) {
|
} else if (primitive_util::IsIntegralType(shape.element_type()) ||
|
||||||
|
shape.element_type() == PRED) {
|
||||||
LocTy loc = lexer_.GetLoc();
|
LocTy loc = lexer_.GetLoc();
|
||||||
tensorflow::int64 value;
|
tensorflow::int64 value;
|
||||||
if (!ParseInt64(&value)) {
|
if (!ParseInt64(&value)) {
|
||||||
|
|||||||
@ -75,6 +75,18 @@ ENTRY %constant_pred () -> pred[] {
|
|||||||
|
|
||||||
)"
|
)"
|
||||||
},
|
},
|
||||||
|
// pred array constant
|
||||||
|
{
|
||||||
|
"ConstantPredArray",
|
||||||
|
R"(HloModule module
|
||||||
|
|
||||||
|
ENTRY %constant_pred_array () -> pred[2,3] {
|
||||||
|
ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } })
|
||||||
|
}
|
||||||
|
|
||||||
|
)"
|
||||||
|
},
|
||||||
|
|
||||||
// s32 constant
|
// s32 constant
|
||||||
{
|
{
|
||||||
"ConstantS32",
|
"ConstantS32",
|
||||||
|
|||||||
@ -2478,8 +2478,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
|
|||||||
Ne(v, m, /*broadcast_dimensions=*/{1});
|
Ne(v, m, /*broadcast_dimensions=*/{1});
|
||||||
|
|
||||||
const string expected = R"(pred[2,2] {
|
const string expected = R"(pred[2,2] {
|
||||||
{ 00 },
|
{ 0, 0 },
|
||||||
{ 01 }
|
{ 0, 1 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
@ -2492,8 +2492,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
|
|||||||
Ge(v, m, /*broadcast_dimensions=*/{1});
|
Ge(v, m, /*broadcast_dimensions=*/{1});
|
||||||
|
|
||||||
const string expected = R"(pred[2,4] {
|
const string expected = R"(pred[2,4] {
|
||||||
{ 1100 },
|
{ 1, 1, 0, 0 },
|
||||||
{ 0001 }
|
{ 0, 0, 0, 1 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
@ -2506,8 +2506,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
|
|||||||
Gt(v, m, /*broadcast_dimensions=*/{1});
|
Gt(v, m, /*broadcast_dimensions=*/{1});
|
||||||
|
|
||||||
const string expected = R"(pred[2,4] {
|
const string expected = R"(pred[2,4] {
|
||||||
{ 0100 },
|
{ 0, 1, 0, 0 },
|
||||||
{ 0000 }
|
{ 0, 0, 0, 0 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
@ -2520,8 +2520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
|
|||||||
Le(v, m, /*broadcast_dimensions=*/{1});
|
Le(v, m, /*broadcast_dimensions=*/{1});
|
||||||
|
|
||||||
const string expected = R"(pred[2,4] {
|
const string expected = R"(pred[2,4] {
|
||||||
{ 1011 },
|
{ 1, 0, 1, 1 },
|
||||||
{ 1111 }
|
{ 1, 1, 1, 1 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
@ -2534,8 +2534,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
|
|||||||
Lt(v, m, /*broadcast_dimensions=*/{1});
|
Lt(v, m, /*broadcast_dimensions=*/{1});
|
||||||
|
|
||||||
const string expected = R"(pred[2,4] {
|
const string expected = R"(pred[2,4] {
|
||||||
{ 0011 },
|
{ 0, 0, 1, 1 },
|
||||||
{ 1110 }
|
{ 1, 1, 1, 0 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
@ -2744,12 +2744,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
|
|||||||
Array3D<int> expected_3d(
|
Array3D<int> expected_3d(
|
||||||
{{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
|
{{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
|
||||||
const string expected = R"(pred[2,3,2] {
|
const string expected = R"(pred[2,3,2] {
|
||||||
{ { 01 },
|
{ { 0, 1 },
|
||||||
{ 00 },
|
{ 0, 0 },
|
||||||
{ 00 } },
|
{ 0, 0 } },
|
||||||
{ { 01 },
|
{ { 0, 1 },
|
||||||
{ 10 },
|
{ 1, 0 },
|
||||||
{ 01 } }
|
{ 0, 1 } }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) {
|
|||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
ConstantR2<bool>(&builder, {{false, true, true}, {true, false, false}});
|
ConstantR2<bool>(&builder, {{false, true, true}, {true, false, false}});
|
||||||
const string expected = R"(pred[2,3] {
|
const string expected = R"(pred[2,3] {
|
||||||
{ 011 },
|
{ 0, 1, 1 },
|
||||||
{ 100 }
|
{ 1, 0, 0 }
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
EXPECT_EQ(expected, ExecuteToString(&builder, {}));
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user