diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 996c05f8460..e18521811c0 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -546,7 +546,7 @@ class CompressingRematerializationTest : public RematerializationTestBase { int64 size = ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type()); for (int64 i = 0; i < descending_shape.rank(); ++i) { - int64 dim = shape.dimensions(i); + int64 dim = descending_shape.dimensions(i); if (i == descending_shape.rank() - 1) { dim = RoundUpToNearest(dim, 64); } @@ -555,8 +555,8 @@ class CompressingRematerializationTest : public RematerializationTestBase { return size; } - // Swap the two most-minor dimensions if the second-minor dimension is bigger - // than the most-minor dimension. + // Swap the layout of the two most-minor dimensions if the second-minor + // dimension is bigger than the most-minor dimension. static StatusOr ChooseCompactLayoutForShape(const Shape& shape) { Shape result = shape; Layout layout = result.layout(); @@ -565,8 +565,10 @@ class CompressingRematerializationTest : public RematerializationTestBase { int64 most_minor = result.dimensions(most_minor_index); int64 second_minor = result.dimensions(second_minor_index); if (most_minor < second_minor) { - result.set_dimensions(most_minor_index, second_minor); - result.set_dimensions(second_minor_index, most_minor); + Layout new_layout = layout; + new_layout.set_minor_to_major(0, second_minor_index); + new_layout.set_minor_to_major(1, most_minor_index); + *result.mutable_layout() = new_layout; } return result; } @@ -603,7 +605,7 @@ ENTRY %entry { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( @@ -643,7 +645,7 @@ ENTRY %entry { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 93e6b9469d8..df603102157 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -323,7 +323,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -440,7 +440,7 @@ static const char* const kAddWithLayoutChangeHlo = R"( TEST_F(HloVerifierTest, AddWithLayoutChange) { TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo)); + auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); } @@ -462,7 +462,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) { debug_options.set_xla_allow_scalar_index_dynamic_ops(true); config.set_debug_options(debug_options); - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kScalarIndexDynamicSlice, config)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -488,7 +488,7 @@ TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { debug_options.set_xla_allow_scalar_index_dynamic_ops(true); config.set_debug_options(debug_options); - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kScalarIndexDynamicSlice, config)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -590,7 +590,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -627,7 +627,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -794,7 +794,7 @@ TEST_F(HloVerifierTest, MapOperandComputationMismatch) { } TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kMapOperandComputationMismatchHlo)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); @@ -827,7 +827,7 @@ TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) { TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) { TF_ASSERT_OK_AND_ASSIGN( auto module, - ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo)); + ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo)); auto status = verifier().Run(module.get()).status(); ASSERT_TRUE(status.ok()); }