diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e5de53c1c82..5ef3bd09296 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -532,10 +532,10 @@ tf_cc_test( srcs = ["pattern_matcher_test.cc"], deps = [ ":hlo", - ":hlo_parser", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -588,7 +588,6 @@ tf_cc_test( srcs = ["hlo_reachability_test.cc"], deps = [ ":hlo", - ":hlo_parser", ":hlo_reachability", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1424,7 +1423,6 @@ tf_cc_test( ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -2672,7 +2670,6 @@ tf_cc_test( srcs = ["hlo_replication_analysis_test.cc"], deps = [ ":hlo", - ":hlo_parser", ":hlo_replication_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3913,10 +3910,14 @@ tf_cc_test( srcs = ["tuple_util_test.cc"], deps = [ ":hlo_matchers", + ":hlo_module_config", ":hlo_parser", ":tuple_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:verified_hlo_module", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -3946,6 +3947,7 @@ tf_cc_test( ":while_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", ], @@ -4004,7 +4006,6 @@ tf_cc_test( srcs = ["while_loop_constant_sinking_test.cc"], deps = [ ":hlo_matchers", - ":hlo_parser", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 7a97740ed8b..4c9fa9a2432 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -213,11 +213,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text)); ASSERT_FALSE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_copy, + auto module_copy, HloModule::CreateFromProto(module->ToProto(), module->config())); ASSERT_FALSE(module_copy->has_schedule()); } @@ -235,11 +234,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text)); ASSERT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_copy, + auto module_copy, HloModule::CreateFromProto(module->ToProto(), module->config())); ASSERT_TRUE(module_copy->has_schedule()); TF_ASSERT_OK(module_copy->schedule().Verify()); @@ -272,8 +270,7 @@ ENTRY ReduceR3ToR2.v3 { ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text)); // Perform various transformations on the graph: // @@ -306,7 +303,7 @@ ENTRY ReduceR3ToR2.v3 { // Serialize and deserialize and verify that the instruction and computations // unique ids are the same. TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module_copy, + auto module_copy, HloModule::CreateFromProto(module->ToProto(), module->config())); // The module IDs should *not* be the same because module ids must be globally @@ -366,8 +363,7 @@ TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text)); // Create a replacement computation HloComputation* new_comp; diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index b49f22fe721..e973c35f9cf 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -208,7 +207,7 @@ TEST_F(HloReachabilityTest, ChannelReachability) { } TEST_F(HloReachabilityTest, ReplaceInstructions) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test ENTRY entry { diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index 3fabeab203b..958e99dedb8 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -57,8 +56,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{false, true}); @@ -106,8 +105,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false}); @@ -158,8 +157,8 @@ ENTRY SimpleWhileLoop { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); @@ -207,8 +206,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); @@ -253,8 +252,8 @@ ENTRY WhileLoopDifferentCondition { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true}); @@ -302,8 +301,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false, true, true}); @@ -366,8 +365,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, true, true}); @@ -404,8 +403,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, false, true, true, true}); @@ -430,8 +429,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_str)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{true, true, true, true, false}); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index e69b3e14ebc..30be21756bf 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -53,7 +52,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { @@ -87,7 +86,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { @@ -136,7 +135,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { @@ -180,7 +179,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { @@ -241,7 +240,7 @@ ENTRY %WhileLoop () -> s32[] { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { @@ -310,7 +309,7 @@ ENTRY %WhileLoop () -> s32[] { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, ScheduleModule(module.get(), [](const BufferValue& buffer) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 941d0260bc7..1879f5b869c 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion { }; TEST_F(InstructionFusionTest, FuseInstructions) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) { } TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module fused_computation { p1 = f32[4,3] parameter(0) @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { } TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -196,7 +196,7 @@ static int Count(const HloModule& module, HloOpcode op) { } TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -433,7 +433,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) { } TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY Test { p0 = f32[100] parameter(0) @@ -458,7 +458,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { } TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY Test { p0 = f32[100] parameter(0) @@ -484,7 +484,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY Test { p0 = f16[100] parameter(0) diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 7f08ba49a71..b923117318a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -14,19 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/pattern_matcher.h" + #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { namespace m = match; +using PatternMatcherTest = HloTestBase; -TEST(PatternMatcherTest, AddOp) { +TEST_F(PatternMatcherTest, AddOp) { constexpr char kModuleStr[] = R"(HloModule two_plus_two_module ENTRY %two_plus_two_computation () -> f32[] { %two = f32[] constant(2) @@ -34,7 +36,7 @@ TEST(PatternMatcherTest, AddOp) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -66,7 +68,7 @@ TEST(PatternMatcherTest, AddOp) { match::Multiply(&matched_inst, match::Op(), match::Op()))); } -TEST(PatternMatcherTest, ScalarShape) { +TEST_F(PatternMatcherTest, ScalarShape) { auto scalar_shape = ShapeUtil::MakeShape(F32, {}); Shape* matched_shape; EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar())); @@ -81,7 +83,7 @@ TEST(PatternMatcherTest, ScalarShape) { match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32))); } -TEST(PatternMatcherTest, DenseArrayShape) { +TEST_F(PatternMatcherTest, DenseArrayShape) { auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4}); Shape* matched_shape; EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); @@ -104,7 +106,7 @@ TEST(PatternMatcherTest, DenseArrayShape) { EXPECT_EQ(matched_layout, &array_shape.layout()); } -TEST(PatternMatcherTest, SparseArrayShape) { +TEST_F(PatternMatcherTest, SparseArrayShape) { auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10); Shape* matched_shape; EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); @@ -127,7 +129,7 @@ TEST(PatternMatcherTest, SparseArrayShape) { EXPECT_EQ(matched_layout, &array_shape.layout()); } -TEST(PatternMatcherTest, TupleShape) { +TEST_F(PatternMatcherTest, TupleShape) { auto tuple_shape = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(S32, {4, 5}), @@ -175,7 +177,7 @@ TEST(PatternMatcherTest, TupleShape) { Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } -TEST(PatternMatcherTest, FusionKind) { +TEST_F(PatternMatcherTest, FusionKind) { constexpr char kModuleStr[] = R"( HloModule test_module @@ -188,7 +190,7 @@ TEST(PatternMatcherTest, FusionKind) { ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match( @@ -199,7 +201,7 @@ TEST(PatternMatcherTest, FusionKind) { HloInstruction::FusionKind::kLoop))); } -TEST(PatternMatcherTest, GetTupleElement) { +TEST_F(PatternMatcherTest, GetTupleElement) { constexpr char kModuleStr[] = R"( HloModule test_module @@ -208,7 +210,7 @@ TEST(PatternMatcherTest, GetTupleElement) { ROOT gte = f32[] get-tuple-element(p0), index=1 })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0))); @@ -218,11 +220,11 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } -TEST(PatternMatcherTest, AnyOf) { +TEST_F(PatternMatcherTest, AnyOf) { constexpr char kModuleStr[] = R"( HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE( @@ -236,7 +238,7 @@ TEST(PatternMatcherTest, AnyOf) { match::ConstantScalar(2)))); } -TEST(PatternMatcherTest, ConstantScalar) { +TEST_F(PatternMatcherTest, ConstantScalar) { using match::ConstantEffectiveScalar; using match::ConstantScalar; using match::Op; @@ -253,7 +255,7 @@ TEST(PatternMatcherTest, ConstantScalar) { ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); const HloInstruction* a = root->operand(0); @@ -308,7 +310,7 @@ TEST(PatternMatcherTest, ConstantScalar) { EXPECT_EQ(instr, a); } -TEST(PatternMatcherTest, MultiplyAnyOrder) { +TEST_F(PatternMatcherTest, MultiplyAnyOrder) { using match::ConstantScalar; using match::MultiplyAnyOrder; @@ -320,7 +322,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) { ROOT multiply = f16[] multiply(lhs, rhs) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); const HloInstruction* instr; @@ -339,7 +341,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) { .IsNonConstant())); } -TEST(PatternMatcherTest, AnyOfShortCircuit) { +TEST_F(PatternMatcherTest, AnyOfShortCircuit) { using match::AnyOf; using match::Multiply; using match::Op; @@ -352,7 +354,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) { ROOT multiply = f16[] multiply(lhs, rhs) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); { @@ -375,7 +377,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) { } } -TEST(PatternMatcherTest, AllOf) { +TEST_F(PatternMatcherTest, AllOf) { using match::AllOf; using match::Broadcast; using match::Constant; @@ -384,7 +386,7 @@ TEST(PatternMatcherTest, AllOf) { constexpr char kModuleStr[] = R"( HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); auto f16_scalar = ShapeUtil::MakeShape(F16, {}); @@ -407,7 +409,7 @@ TEST(PatternMatcherTest, AllOf) { Match(root, AllOf(Broadcast(Op()), scalar_pattern))); } -TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { +TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { using match::AllOf; using match::Broadcast; using match::Constant; @@ -419,7 +421,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { ROOT v = f16[] constant(42) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); const HloInstruction* constant = nullptr; @@ -430,7 +432,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { EXPECT_NE(nullptr, constant); } -TEST(PatternMatcherTest, TestNoCapture) { +TEST_F(PatternMatcherTest, TestNoCapture) { using match::Constant; constexpr char kModuleStr[] = R"( @@ -439,7 +441,7 @@ TEST(PatternMatcherTest, TestNoCapture) { ROOT v = f16[] constant(42) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); const HloInstruction* constant = nullptr; @@ -447,7 +449,7 @@ TEST(PatternMatcherTest, TestNoCapture) { EXPECT_EQ(nullptr, constant); } -TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { +TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { using match::Add; using match::AddAnyOrder; using match::AnyOf; @@ -461,7 +463,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { ROOT add = f16[] add(u, v) })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); const HloInstruction* addend0 = nullptr; @@ -477,7 +479,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { EXPECT_EQ(nullptr, addend2); } -TEST(PatternMatcherTest, TestConcat) { +TEST_F(PatternMatcherTest, TestConcat) { using match::Concatenate; using match::ConstantScalar; using match::Op; @@ -497,7 +499,7 @@ TEST(PatternMatcherTest, TestConcat) { ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0} })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); ASSERT_TRUE(Match( root, @@ -557,7 +559,7 @@ string Explanation(const Elem& elem, const Pattern& pattern) { EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \ } while (0) -TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { +TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) { auto layout = LayoutUtil::MakeLayout({1, 2}); auto layout2 = LayoutUtil::MakeLayout({2, 2}); @@ -577,7 +579,7 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { "Layout has format DENSE but expected SPARSE"); } -TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { +TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { constexpr char kModuleStr[] = R"( HloModule test_module @@ -587,7 +589,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { )"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnVerifiedModule(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target"))); @@ -600,7 +602,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { "out = f32[] custom-call(), custom_call_target=\"test_target\""); } -TEST(PatternMatcherTest, ShapeDescribeToAndExplain) { +TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) { auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1}); auto layout = shape.layout(); @@ -715,7 +717,7 @@ std::unique_ptr SetName(absl::string_view name, return instr; } -TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { +TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) { std::unique_ptr iota = SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), /*iota_dimension=*/0)); @@ -810,7 +812,7 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "in c = s32[] constant(0)")); } -TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { +TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_DESC_AND_EXPLANATION( SetName("a", HloInstruction::CreateBinary( @@ -866,7 +868,7 @@ TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { "in a = s32[] add(s32[] p, s32[] c)"); } -TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { +TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { EXPECT_DESC_AND_EXPLANATION( SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), m::AnyOf(m::Op().WithName("foo"), @@ -887,7 +889,7 @@ TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { " in c = s32[] constant(0)"); } -TEST(PatternMatcherTest, Parameter) { +TEST_F(PatternMatcherTest, Parameter) { auto param = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); auto non_param = @@ -911,7 +913,7 @@ TEST(PatternMatcherTest, Parameter) { "in p0 = f32[] parameter(0)"); } -TEST(PatternMatcherTest, OneUseAndOneUser) { +TEST_F(PatternMatcherTest, OneUseAndOneUser) { auto param = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); @@ -966,7 +968,7 @@ TEST(PatternMatcherTest, OneUseAndOneUser) { "in p0 = f32[] parameter(0)"); } -TEST(HloMatchersTest, Comparison) { +TEST_F(PatternMatcherTest, Comparison) { auto shape = ShapeUtil::MakeShape(F32, {1}); auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c6666ce842b..345bb987b0d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -68,8 +68,8 @@ ENTRY entry_computation { ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); FoldTranspose(module.get()); @@ -90,8 +90,8 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -118,7 +118,7 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); TransposeFolding transpose_folding( @@ -146,8 +146,8 @@ ENTRY entry_computation { ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); FoldTranspose(module.get()); @@ -204,8 +204,8 @@ ENTRY entry_computation { ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); FoldTranspose(module.get()); const HloComputation* callee = module->GetComputationWithName("callee"); diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 807089c9d05..85d78af0b09 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -15,16 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" namespace xla { namespace { namespace op = ::xla::testing::opcode_matchers; -StatusOr> GetParsedModule( +StatusOr> GetParsedModule( HloComputation** entry_computation, HloInstruction** param0, HloInstruction** param1) { const char* const hlo_string = R"( @@ -36,8 +42,12 @@ ENTRY entry { } )"; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + auto module = absl::make_unique( + "TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + ShapeUtil::ByteSizeOfElements); + TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -51,8 +61,7 @@ TEST(TupleUtilTest, ExtractPrefix) { HloComputation* entry_computation; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - GetParsedModule(&entry_computation, ¶m0, ¶m1)); + auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1)); HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2); @@ -65,8 +74,7 @@ TEST(TupleUtilTest, AppendSuffix) { HloComputation* entry_computation; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - GetParsedModule(&entry_computation, ¶m0, ¶m1)); + auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1)); HloInstruction* with_suffix = TupleUtil::AppendSuffix(param0, {param1, param1}); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 6c6fd387d5f..088529444ee 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -25,8 +25,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; - -class WhileLoopConstantSinkingTest : public ::testing::Test {}; +using WhileLoopConstantSinkingTest = HloTestBase; TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) { const char* const hlo_string = R"( @@ -54,8 +53,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -94,8 +93,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -135,8 +134,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -183,8 +182,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -225,8 +224,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -271,8 +270,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -311,8 +310,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -354,8 +353,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -405,8 +404,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index eea72e808a1..aa5da855e33 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,10 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" +#include + #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -26,10 +29,12 @@ namespace { namespace op = ::xla::testing::opcode_matchers; -StatusOr> GetParsedModule( - HloComputation** entry_computation, HloInstruction** param0, - HloInstruction** param1, HloInstruction** param2) { - const char* const hlo_string = R"( +class WhileUtilTest : public HloTestBase { + protected: + StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1, HloInstruction** param2) { + const char* const hlo_string = R"( HloModule ModuleWithWhile while_body { @@ -50,23 +55,25 @@ ENTRY entry { } )"; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + // TODO(b/80488902): Use VerifiedHloModule here. + TF_ASSIGN_OR_RETURN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); - *entry_computation = module->entry_computation(); - *param0 = (*entry_computation)->parameter_instruction(0); - *param1 = (*entry_computation)->parameter_instruction(1); - *param2 = (*entry_computation)->parameter_instruction(2); + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + *param2 = (*entry_computation)->parameter_instruction(2); - return std::move(module); -} + return std::move(module); + } +}; -TEST(WhileUtil, MakeZeroInstructionsLiveOp) { +TEST_F(WhileUtilTest, MakeZeroInstructionsLiveOp) { HloInstruction *param0, *param1, *param2; HloComputation* entry_computation; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); HloInstruction* while_instr = entry_computation->root_instruction(); @@ -92,12 +99,12 @@ TEST(WhileUtil, MakeZeroInstructionsLiveOp) { op::GetTupleElement(param_reconstructed, 1))); } -TEST(WhileUtilTest, MakeTwoInstructionsLive) { +TEST_F(WhileUtilTest, MakeTwoInstructionsLive) { HloInstruction *param0, *param1, *param2; HloComputation* entry_computation; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); HloInstruction* while_instr = entry_computation->root_instruction(); @@ -128,7 +135,7 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) { op::GetTupleElement(op::Parameter(0), 3))); } -TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) { +TEST_F(WhileUtilTest, GetInvariantGTEsForWhileBody) { const char* const hlo_string = R"( HloModule ModuleWithWhile @@ -151,8 +158,8 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloComputation* while_body = module->GetComputationWithName("body"); @@ -166,7 +173,7 @@ ENTRY main { EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); } -TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { +TEST_F(WhileUtilTest, AlwaysRemovePreviousWhileBody) { const char* const hlo_string = R"( HloModule WhileWithSideEffects @@ -192,8 +199,8 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloComputation* main = module->GetComputationWithName("main"); HloInstruction* while_instr = main->root_instruction();