Fix two bugs in hlo_rematerialization_test.

First bug: copy instructions change physical layouts, not the logicial
dimensions.
Second bug: used the wrong variable in ShapeSizePadMinorTo64.
These were discovered by using VerifiedHloModule.
Also use VerifiedModule for hlo_verifier_tests that should succeed.

PiperOrigin-RevId: 276270423
Change-Id: Idcc946c07f77442026f97d8d25eeeab4ec68c6f9
This commit is contained in:
Adrian Kuegel 2019-10-23 07:28:15 -07:00 committed by TensorFlower Gardener
parent 163ba74db0
commit 5ad2eafac2
2 changed files with 17 additions and 15 deletions

View File

@ -546,7 +546,7 @@ class CompressingRematerializationTest : public RematerializationTestBase {
int64 size =
ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
for (int64 i = 0; i < descending_shape.rank(); ++i) {
int64 dim = shape.dimensions(i);
int64 dim = descending_shape.dimensions(i);
if (i == descending_shape.rank() - 1) {
dim = RoundUpToNearest<int64>(dim, 64);
}
@ -555,8 +555,8 @@ class CompressingRematerializationTest : public RematerializationTestBase {
return size;
}
// Swap the two most-minor dimensions if the second-minor dimension is bigger
// than the most-minor dimension.
// Swap the layout of the two most-minor dimensions if the second-minor
// dimension is bigger than the most-minor dimension.
static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
Shape result = shape;
Layout layout = result.layout();
@ -565,8 +565,10 @@ class CompressingRematerializationTest : public RematerializationTestBase {
int64 most_minor = result.dimensions(most_minor_index);
int64 second_minor = result.dimensions(second_minor_index);
if (most_minor < second_minor) {
result.set_dimensions(most_minor_index, second_minor);
result.set_dimensions(second_minor_index, most_minor);
Layout new_layout = layout;
new_layout.set_minor_to_major(0, second_minor_index);
new_layout.set_minor_to_major(1, most_minor_index);
*result.mutable_layout() = new_layout;
}
return result;
}
@ -603,7 +605,7 @@ ENTRY %entry {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
@ -643,7 +645,7 @@ ENTRY %entry {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(

View File

@ -323,7 +323,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -440,7 +440,7 @@ static const char* const kAddWithLayoutChangeHlo = R"(
TEST_F(HloVerifierTest, AddWithLayoutChange) {
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));
auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
}
@ -462,7 +462,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kScalarIndexDynamicSlice, config));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -488,7 +488,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
config.set_debug_options(debug_options);
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kScalarIndexDynamicSlice, config));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -590,7 +590,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -627,7 +627,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
ParseAndReturnVerifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -794,7 +794,7 @@ TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
}
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
kMapOperandComputationMismatchHlo));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
@ -827,7 +827,7 @@ TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));
ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
auto status = verifier().Run(module.get()).status();
ASSERT_TRUE(status.ok());
}