Fix two bugs in hlo_rematerialization_test.
First bug: copy instructions change physical layouts, not the logicial dimensions. Second bug: used the wrong variable in ShapeSizePadMinorTo64. These were discovered by using VerifiedHloModule. Also use VerifiedModule for hlo_verifier_tests that should succeed. PiperOrigin-RevId: 276270423 Change-Id: Idcc946c07f77442026f97d8d25eeeab4ec68c6f9
This commit is contained in:
parent
163ba74db0
commit
5ad2eafac2
@ -546,7 +546,7 @@ class CompressingRematerializationTest : public RematerializationTestBase {
|
|||||||
int64 size =
|
int64 size =
|
||||||
ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
|
ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
|
||||||
for (int64 i = 0; i < descending_shape.rank(); ++i) {
|
for (int64 i = 0; i < descending_shape.rank(); ++i) {
|
||||||
int64 dim = shape.dimensions(i);
|
int64 dim = descending_shape.dimensions(i);
|
||||||
if (i == descending_shape.rank() - 1) {
|
if (i == descending_shape.rank() - 1) {
|
||||||
dim = RoundUpToNearest<int64>(dim, 64);
|
dim = RoundUpToNearest<int64>(dim, 64);
|
||||||
}
|
}
|
||||||
@ -555,8 +555,8 @@ class CompressingRematerializationTest : public RematerializationTestBase {
|
|||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Swap the two most-minor dimensions if the second-minor dimension is bigger
|
// Swap the layout of the two most-minor dimensions if the second-minor
|
||||||
// than the most-minor dimension.
|
// dimension is bigger than the most-minor dimension.
|
||||||
static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
|
static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
|
||||||
Shape result = shape;
|
Shape result = shape;
|
||||||
Layout layout = result.layout();
|
Layout layout = result.layout();
|
||||||
@ -565,8 +565,10 @@ class CompressingRematerializationTest : public RematerializationTestBase {
|
|||||||
int64 most_minor = result.dimensions(most_minor_index);
|
int64 most_minor = result.dimensions(most_minor_index);
|
||||||
int64 second_minor = result.dimensions(second_minor_index);
|
int64 second_minor = result.dimensions(second_minor_index);
|
||||||
if (most_minor < second_minor) {
|
if (most_minor < second_minor) {
|
||||||
result.set_dimensions(most_minor_index, second_minor);
|
Layout new_layout = layout;
|
||||||
result.set_dimensions(second_minor_index, most_minor);
|
new_layout.set_minor_to_major(0, second_minor_index);
|
||||||
|
new_layout.set_minor_to_major(1, most_minor_index);
|
||||||
|
*result.mutable_layout() = new_layout;
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -603,7 +605,7 @@ ENTRY %entry {
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(hlo_string));
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
@ -643,7 +645,7 @@ ENTRY %entry {
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(hlo_string));
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
RunHloRematerialization(
|
RunHloRematerialization(
|
||||||
|
@ -323,7 +323,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(hlo_string));
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -440,7 +440,7 @@ static const char* const kAddWithLayoutChangeHlo = R"(
|
|||||||
|
|
||||||
TEST_F(HloVerifierTest, AddWithLayoutChange) {
|
TEST_F(HloVerifierTest, AddWithLayoutChange) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));
|
auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
}
|
}
|
||||||
@ -462,7 +462,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
|
|||||||
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
|
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
|
||||||
config.set_debug_options(debug_options);
|
config.set_debug_options(debug_options);
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
kScalarIndexDynamicSlice, config));
|
kScalarIndexDynamicSlice, config));
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -488,7 +488,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
|
|||||||
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
|
debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
|
||||||
config.set_debug_options(debug_options);
|
config.set_debug_options(debug_options);
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
kScalarIndexDynamicSlice, config));
|
kScalarIndexDynamicSlice, config));
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -590,7 +590,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(hlo_string));
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -627,7 +627,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
ParseAndReturnUnverifiedModule(hlo_string));
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -794,7 +794,7 @@ TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
|
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
||||||
kMapOperandComputationMismatchHlo));
|
kMapOperandComputationMismatchHlo));
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
@ -827,7 +827,7 @@ TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
|
|||||||
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
|
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
auto module,
|
auto module,
|
||||||
ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));
|
ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_TRUE(status.ok());
|
ASSERT_TRUE(status.ok());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user