diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d2668c5a2bf..e03f34a4667 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -466,7 +466,6 @@ tf_cc_test( ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -618,9 +617,9 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", - ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -631,7 +630,6 @@ tf_cc_test( deps = [ ":hlo", ":hlo_casting_utils", - ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", @@ -1287,7 +1285,6 @@ tf_cc_test( ":hlo_dataflow_analysis", ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1364,7 +1361,6 @@ tf_cc_test( ":hlo_matchers", ":hlo_module_group", ":hlo_module_group_metadata", - ":hlo_parser", ":hlo_proto", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1447,7 +1443,6 @@ tf_cc_test( ":hlo_dce", ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1848,6 +1843,7 @@ tf_cc_test( ":hlo_pass", ":pattern_matcher", ":pattern_matcher_gmock", + ":shape_inference", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1941,8 +1937,8 @@ tf_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", - ":hlo_parser", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -1976,7 +1972,6 @@ tf_cc_test( ":conditional_simplifier", ":hlo", ":hlo_matchers", - ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2017,7 +2012,6 @@ tf_cc_test( ":convolution_group_converter", ":hlo", ":hlo_matchers", - ":hlo_parser", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2308,7 +2302,6 @@ xla_test( ":hlo_get_dimension_size_rewriter", ":hlo_matchers", ":hlo_parser", - ":hlo_runner", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2491,7 +2484,6 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", - ":hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2505,7 +2497,6 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", - ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", @@ -2526,7 +2517,6 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_memory_scheduler", - ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2720,7 +2710,6 @@ tf_cc_test( deps = [ ":hlo", ":hlo_liveness_analysis", - ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -3176,7 +3165,6 @@ tf_cc_test( deps = [ ":hlo", ":hlo_module_dce", - ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 20df2637606..33d49392fe1 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -3641,9 +3642,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { } } - auto out_dims = in_dims; - out_dims[in_channel_idx] = options.f_output_channels; - auto make_shape = [](absl::Span dims, bool minor_to_major_layout) { if (minor_to_major_layout) { @@ -3654,20 +3652,26 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { }; auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout); auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout); - auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout); HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input")); HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); + Shape out_shape = ShapeInference::InferConvolveShape( + in_shape, f_shape, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) + .ValueOrDie(); + if (options.output_minor_to_major_layout) { + out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(), + {0, 1, 2, 3}); + } b.AddInstruction(HloInstruction::CreateConvolve( out_shape, input, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - // TODO(b/80488902): verify this module. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifierOptions simplifier_options; @@ -3841,8 +3845,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - // TODO(b/80488902): verify this module. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -3883,9 +3886,10 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { dim->set_padding_high(100); dim->set_window_dilation(1); dim->set_base_dilation(1); + dim->set_stride(1); } const Shape reduce_window_shape = - ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + ShapeUtil::MakeShape(F32, {111, 113, 113, 116}); HloInstruction* reduce_init_value = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = @@ -3923,8 +3927,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - // TODO(b/80488902): verify this module. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -3969,9 +3972,10 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { dim->set_padding_high(100); dim->set_window_dilation(1); dim->set_base_dilation(1); + dim->set_stride(1); } const Shape reduce_window_shape = - ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + ShapeUtil::MakeShape(F32, {111, 113, 113, 116}); HloInstruction* reduce_init_value = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 1c985485d43..8ea38aa5a1e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -2605,8 +2605,7 @@ ENTRY entry_computation { } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); RunCopyInsertion(module.get()); diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 6dba78fe4fe..8a7fba6a48f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -202,7 +201,7 @@ ENTRY main { ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true } )"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); @@ -268,7 +267,7 @@ TEST_F(ConditionalSimplifierTest, param.1), true_computation=computation.1, false_computation=computation.2 })"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); std::unique_ptr module = status.ConsumeValueOrDie(); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); @@ -315,7 +314,7 @@ ENTRY main { ROOT result = () tuple() } )"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); @@ -359,7 +358,7 @@ ENTRY main { ROOT result = (f32[10,10]{1,0}) tuple(get-first-index) } )"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); @@ -405,7 +404,7 @@ ENTRY main { ROOT result = (f32[10,10]{1,0}) tuple(get-second-index) } )"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); @@ -471,7 +470,7 @@ ENTRY main { ROOT add = f32[] add(gte.0, gte.1) } )"; - auto status = ParseAndReturnUnverifiedModule(hlo_string); + auto status = ParseAndReturnVerifiedModule(hlo_string); TF_ASSERT_OK(status.status()); HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc index 85c54d31582..a3c26ad59b5 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -44,7 +43,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); @@ -76,7 +75,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); @@ -106,7 +105,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16 ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512 })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index f0ac579a387..cde75d0c16c 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -404,7 +404,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { class WhileCopyInsertionTest : public CopyInsertionTest { protected: - WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} + WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {} // Builds a While condition computation which reads the induction variable // from the tuple parameter, and returns a predicate indicating whether this @@ -451,8 +451,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); // Use 'induction_variable' in computation with no path to output tuple. + Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto convert = builder.AddInstruction( + HloInstruction::CreateConvert(f32_scalar_shape, induction_variable)); auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + HloInstruction::CreateBroadcast(data_shape_, convert, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -521,8 +524,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); // Use 'induction_variable' in computation with no path to output tuple. + Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto convert = builder.AddInstruction( + HloInstruction::CreateConvert(f32_scalar_shape, induction_variable)); auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + HloInstruction::CreateBroadcast(data_shape_, convert, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -685,11 +691,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2})); auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); @@ -709,7 +715,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto one_vec = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto data_init = builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec})); @@ -722,7 +728,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto data_init = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto one_vec = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1( {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); @@ -840,10 +846,10 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { ASSERT_EQ(add->opcode(), HloOpcode::kAdd); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(op::Copy(), op::Constant()), - op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); + EXPECT_THAT(while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(op::Copy(), op::Constant()), + op::Add(op::GetTupleElement(), + op::Broadcast(op::Convert(op::Copy()))))); // Both init indices need copies as they are constants. EXPECT_THAT(while_hlo->operand(0), @@ -953,12 +959,17 @@ TEST_F(WhileCopyInsertionTest, auto data_param = builder.AddInstruction( HloInstruction::CreateParameter(1, data_shape_, "data")); // Add dummy ops to ensure loop_init elements aren't entry parameters. - auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary( - iter_param->shape(), HloOpcode::kExp, iter_param)); + Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto convert = builder.AddInstruction( + HloInstruction::CreateConvert(f32_scalar_shape, iter_param)); + auto iter_value = builder.AddInstruction( + HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert)); + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(induction_variable_shape_, iter_value)); auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( data_param->shape(), HloOpcode::kExp, data_param)); auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({iter_value, data_value})); + HloInstruction::CreateTuple({convert2, data_value})); auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition1, body1, loop_init)); @@ -983,9 +994,9 @@ TEST_F(WhileCopyInsertionTest, EXPECT_EQ(CountCopies(*entry), 2); EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Exp(), op::Copy(op::Exp()))); + op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp()))); EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Exp(), op::Copy(op::Exp()))); + op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp()))); } // Tests while body computation with nested tuple elements: diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 2707a0ffc05..31810feaec2 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -565,7 +565,7 @@ ENTRY main { operand = s32[20,10]{1,0} parameter(0) indices = s32[32,20] parameter(1) dynamic_size = s32[] parameter(2) - ROOT gather = f32[32,10,10]{2,1,0} gather(%operand, %indices), + ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, @@ -574,7 +574,7 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text)); TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc index b4010a9af09..d8803b62eed 100644 --- a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc @@ -45,8 +45,8 @@ TEST_F(DynamicIndexSplitterTest, DynamicSlice) { debug_options.set_xla_allow_scalar_index_dynamic_ops(true); config.set_debug_options(debug_options); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kDynamicSlice, config)); TF_ASSERT_OK_AND_ASSIGN(bool changed, DynamicIndexSplitter().Run(module.get())); EXPECT_TRUE(changed); @@ -84,7 +84,7 @@ TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) { config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnUnverifiedModule(kDynamicUpdateSlice, config)); + auto module, ParseAndReturnVerifiedModule(kDynamicUpdateSlice, config)); TF_ASSERT_OK_AND_ASSIGN(bool changed, DynamicIndexSplitter().Run(module.get())); EXPECT_TRUE(changed); @@ -122,8 +122,8 @@ TEST_F(DynamicIndexSplitterTest, AlreadyScalar) { debug_options.set_xla_allow_scalar_index_dynamic_ops(true); config.set_debug_options(debug_options); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnUnverifiedModule(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kDynamicSlice, config)); TF_ASSERT_OK_AND_ASSIGN(bool changed, DynamicIndexSplitter().Run(module.get())); EXPECT_FALSE(changed); diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index e09d8235a63..7b377edc43f 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -220,10 +219,8 @@ TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) { class ExecutionTest : public HloTestBase { protected: std::unique_ptr GetHloModule(const string& hlo_text) { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); std::unique_ptr module = - ParseAndReturnUnverifiedModule(hlo_text, config).ValueOrDie(); + ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); return module; } Literal PadAndExecute(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc index d7bc3ab95a6..f0f07b56a2b 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -56,7 +55,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); DynamicParameterBinding binding; @@ -94,7 +93,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); DynamicParameterBinding binding; @@ -133,7 +132,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); DynamicParameterBinding binding; diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index ac81e4e52e7..706327091d9 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,13 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" + #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" namespace xla { namespace { -TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + +using GatherExpanderTest = HloTestBase; + +TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) { const string hlo_text = R"( HloModule TensorFlowGatherMultipleBatchDims @@ -36,7 +40,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -47,7 +51,7 @@ ENTRY main { "indices are not supported.")); } -TEST(GatherExpanderTest, AvoidDegenerateDims) { +TEST_F(GatherExpanderTest, AvoidDegenerateDims) { const string hlo_text = R"( HloModule TensorFlowGatherV2 @@ -63,7 +67,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); @@ -105,7 +109,7 @@ ENTRY main { ShapeUtil::GetTupleElementShape(while_shape, 3))); } -TEST(GatherExpanderTest, CheckOpMetadata) { +TEST_F(GatherExpanderTest, CheckOpMetadata) { const string hlo_text = R"( HloModule TensorFlowGatherV2 @@ -121,7 +125,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); OpMetadata metadata; metadata.set_op_name("Gather"); module->entry_computation()->root_instruction()->set_metadata(metadata); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 6b3e1307422..ea31d3fdb88 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -158,7 +157,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -697,7 +696,7 @@ ENTRY entry { ROOT t = (f32[128], f32[128]) tuple(add, crs1) })"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(), ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(), op::Add(), op::Tuple())); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e5221b4c738..d6617dea1c4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1639,7 +1639,7 @@ ENTRY root { p1 = s32[1000] copy(param) ROOT t = (s32[1000], s32[1000]) tuple(p0, p1) })"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text)); auto entry = module_->entry_computation(); entry->GetInstructionWithName("t"); auto& dataflow_analysis = RunAnalysis(GetParam()); @@ -1990,7 +1990,7 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] { )"; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest())); + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, HloDataflowAnalysis::Run(*module)); @@ -2010,7 +2010,7 @@ INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation, class HloDataflowAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewUnverifiedModule(); + module_ = CreateNewVerifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -2228,7 +2228,7 @@ TEST_F(CanShareOperandBufferWithUserTest, auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto neg = builder.AddInstruction( HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); @@ -2256,7 +2256,7 @@ TEST_F(CanShareOperandBufferWithUserTest, auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - slice_shape, param, {zero, zero}, {1, 2, 2})); + slice_shape, param, {zero, zero}, {1, 2})); auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, param, ds, {zero, zero})); @@ -2448,33 +2448,29 @@ TEST_F(CanShareOperandBufferWithUserTest, TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {2}); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); - auto start0 = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "start0")); - auto start1 = builder.AddInstruction( - HloInstruction::CreateParameter(3, starts_shape, "start1")); + auto start = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "start")); auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, {start0, start1})); + data_shape, data, update, {start})); BuildModuleAndRunAnalysis(builder.Build()); // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. + // with update or start. EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); EXPECT_FALSE( dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {})); - EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {})); + dataflow_analysis_->CanShareOperandBufferWithUser(start, {}, dus, {})); } TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { @@ -2498,7 +2494,7 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { index_vector_dim=1 } )"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text)); computation_ = module_->entry_computation(); RunAnalysis(); @@ -2526,7 +2522,7 @@ TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) { transpose_a=NO_TRANSPOSE } )"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text)); computation_ = module_->entry_computation(); RunAnalysis(); @@ -2611,7 +2607,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kAdd, dot, add_operand)); @@ -2633,7 +2629,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto reverse = builder.AddInstruction( HloInstruction::CreateReverse(data_shape, operand, {0, 1})); @@ -2661,7 +2657,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kMultiply, operand, operand)); auto two = builder.AddInstruction(HloInstruction::CreateConstant( @@ -2683,14 +2679,30 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { } TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + module_ = CreateNewVerifiedModule(); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {}); - auto make_cond = [&data_shape]() { + auto b = HloComputation::Builder(TestName() + ".And"); + auto p0 = b.AddInstruction( + HloInstruction::CreateParameter(0, pred_scalar_shape, "p0")); + auto p1 = b.AddInstruction( + HloInstruction::CreateParameter(1, pred_scalar_shape, "p1")); + b.AddInstruction( + HloInstruction::CreateBinary(pred_scalar_shape, HloOpcode::kAnd, p0, p1)); + auto and_computation = module_->AddEmbeddedComputation(b.Build()); + + auto make_cond = [&data_shape, &and_computation]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateCompare( - ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); + auto compare = builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {8}), data, data, ComparisonDirection::kEq)); + auto true_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateReduce(ShapeUtil::MakeShape(PRED, {}), compare, + true_value, {0}, and_computation)); return builder.Build(); }; @@ -2703,7 +2715,6 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -2734,11 +2745,11 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto one = sub_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); + HloInstruction::CreateBroadcast(shape, one, {})); auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewUnverifiedModule(); + module_ = CreateNewVerifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index eef437d41ed..1808c456048 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -76,19 +76,20 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - builder.AddInstruction( + auto send = builder.AddInstruction( HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateSendDone(send)); builder.AddInstruction(HloInstruction::CreateTuple({})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(4, computation->instruction_count()); + EXPECT_EQ(5, computation->instruction_count()); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(4, computation->instruction_count()); + EXPECT_EQ(5, computation->instruction_count()); } TEST_F(HloDceTest, CustomCallInstructionsWithSideEffect) { @@ -250,7 +251,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { // Tests that a while instruction with an infeed (effectul instruction) in its // body is not removed, even its user count is 0. TEST_F(HloDceTest, CalledComputationWithSideEffect) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Condition computation of a while instruction. @@ -274,8 +275,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { auto token = body_builder.AddInstruction(HloInstruction::CreateToken()); auto infeed = body_builder.AddInstruction( HloInstruction::CreateInfeed(shape, token, "")); - body_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed)); + auto infeed_data = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, infeed, 0)); + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param, infeed_data)); } auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); @@ -306,7 +309,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { // Tests that a nested call instruction with a side effect is not removed. TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Nested called computation with a side effect. @@ -328,8 +331,8 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - callee_builder.AddInstruction( - HloInstruction::CreateCall(shape, {param}, nested_called_computation)); + callee_builder.AddInstruction(HloInstruction::CreateCall( + ShapeUtil::MakeTokenShape(), {param}, nested_called_computation)); } auto called_computation = module->AddEmbeddedComputation(callee_builder.Build()); @@ -338,22 +341,20 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - auto live_call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {param}, called_computation)); - builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param)); + auto live_call = builder.AddInstruction(HloInstruction::CreateCall( + ShapeUtil::MakeTokenShape(), {param}, called_computation)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(2, param->user_count()); + EXPECT_EQ(2, computation->instruction_count()); + EXPECT_EQ(1, param->user_count()); EXPECT_EQ(0, live_call->user_count()); EXPECT_TRUE(HasInstruction(*computation, live_call)); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(2, param->user_count()); + EXPECT_EQ(2, computation->instruction_count()); + EXPECT_EQ(1, param->user_count()); EXPECT_EQ(0, live_call->user_count()); EXPECT_TRUE(HasInstruction(*computation, live_call)); } @@ -400,7 +401,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { } TEST_F(HloDceTest, KeepUsedSubcomputation) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); @@ -418,7 +419,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { // Create a dead reduce instruction. builder.AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeShape(F32, {}), builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( @@ -428,7 +429,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { // Add another instruction as the root of the computation that also uses // reduce_subcomp. builder.AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeShape(F32, {}), builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 1b13619afa9..ce4239ff927 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { @@ -29,7 +28,7 @@ using ::testing::ContainsRegex; class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - auto hlo_module = ParseAndReturnUnverifiedModule(R"( + auto hlo_module = ParseAndReturnVerifiedModule(R"( HloModule test_module ENTRY entry_computation { lhs = f32[30,30]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc index a0a06d53ea2..d96f2db3c26 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -40,7 +40,7 @@ class HloGetDimensionSizeRewriterTest : public HloTestBase { }; TEST_F(HloGetDimensionSizeRewriterTest, Ok) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule _ ENTRY gds { p = s32[3,4] parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index cd3abcdf1d6..8293d495878 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -80,7 +79,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); @@ -112,7 +111,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); @@ -151,7 +150,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); @@ -182,7 +181,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); @@ -208,7 +207,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 68d0575cd8e..6bf0e912bd0 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2895,7 +2895,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { // These opcodes are not handled here. case HloOpcode::kTrace: - break; + return Status::OK(); } return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 633ce875de0..a9d9eb9cfa4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -951,7 +950,7 @@ ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); auto* root = module->entry_computation()->root_instruction(); auto* t1 = root->operand(0); @@ -1187,11 +1186,12 @@ TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) { p2 = f32[32,32]{1,0} parameter(0) p3 = f32[32,32]{1,0} parameter(1) c1 = f32[] constant(1) + broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={} mul = f32[32,32]{1,0} multiply(p2, p3) - ROOT add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add + ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add })"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kHloString)); + ParseAndReturnVerifiedModule(kHloString)); HloInstruction* fused_add = module->entry_computation()->root_instruction(); HloInstruction* mul = fused_add->mutable_operand(0); EXPECT_EQ(1, mul->user_count()); @@ -1215,11 +1215,12 @@ TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) { p3 = f32[32,32]{1,0} parameter(1) c1 = f32[] constant(1) mul = f32[32,32]{1,0} multiply(p2, p3) - add = f32[32,32]{1,0} fusion(mul, c1), kind=kLoop, calls=fused_add + broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={} + add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add) })"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kHloString)); + ParseAndReturnVerifiedModule(kHloString)); HloInstruction* root = module->entry_computation()->root_instruction(); HloInstruction* mul = root->mutable_operand(0); HloInstruction* fused_add = root->mutable_operand(1); @@ -1740,7 +1741,7 @@ ENTRY entry (param: s32[]) -> s32[] { // Check that deep clones really deep clones every instruction and // computations, without leaving dangling pointers to the old module. TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); std::unique_ptr clone = module->Clone(); for (HloComputation* computation : clone->computations()) { EXPECT_EQ(computation->parent(), clone.get()); @@ -1860,7 +1861,7 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { dim_labels=b0f_0io->b0f, operand_precision={high,default} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kHloString)); + ParseAndReturnVerifiedModule(kHloString)); auto* conv = module->entry_computation()->root_instruction(); auto clone = conv->Clone(); @@ -1873,10 +1874,10 @@ TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) { constexpr char kHloString[] = R"( HloModule test_module ENTRY test { - ROOT iota = f32[100] iota(), iota_dimension=1, outer_dimension_partitions={0, 50} + ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kHloString)); + ParseAndReturnVerifiedModule(kHloString)); auto* iota = module->entry_computation()->root_instruction(); auto clone = iota->Clone(); diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index a2026818cb2..35db6aa0635 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -59,7 +58,7 @@ class HloLivenessAnalysisTest : public HloTestBase { // Test that add instruction at entry root is live at all output shape indices. TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -75,7 +74,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { // Test that a dead add instruction is marked as dead by analysis. TEST_F(HloLivenessAnalysisTest, DeadAdd) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -94,7 +93,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) { // Test that all output shape indices of entry root tuple (and defining // instruction in its output) are marked live. TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -113,7 +112,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { // Tests that all outputs of nested tuple and entry root (and defining // instruction values appearing in its output) are marked live. TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(1) @@ -140,7 +139,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { // Tests that GTE at entry root of Tuple instruction only propgates liveness // to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfTuple) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -162,7 +161,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) { // Tests that GTE at entry root of nested Tuple instruction only propgates // liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -199,7 +198,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { // Tests that GTE of GTE (at entry root) of nested Tuple instruction only // propgates liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -240,7 +239,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { // Test that live/dead while tuple elements are marked live/dead correctly. TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -291,8 +290,13 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { // Tests that a tuple element live in while.cond computation, propagates // liveness to while.body.root/while.result/while.operand (where it is unused). TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop + add_S32 { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 @@ -305,8 +309,10 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { SimpleLoop.condition { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 - add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + get-tuple-element.4 = s32[3]{0} get-tuple-element(loop_var.2), index=1 + zero = s32[] constant(0) + reduce = s32[] reduce(get-tuple-element.4, zero), dimensions={0}, to_apply=add_S32 + add.1 = s32[] add(get-tuple-element.3, reduce) constant.2 = s32[] constant(5) ROOT less-than = pred[] compare(add.1, constant.2), direction=LT } @@ -338,14 +344,14 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); - EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); } // Tests that a use of while.result{0} propagates liveness to // while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[], s32[]) parameter(0) @@ -399,7 +405,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { } TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) @@ -432,7 +438,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { } TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 0c6c632f5c8..9c63638d492 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" + #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -25,6 +26,8 @@ using ::testing::Eq; namespace xla { namespace { +using HloMatchersTest = HloTestBase; + string DescribeHloMatcher(const ::testing::Matcher& m) { std::stringstream ss; m.DescribeTo(&ss); @@ -39,7 +42,7 @@ string Explain(const T& t, const M& m) { return listener.str(); } -TEST(HloMatchersTest, Test) { +TEST_F(HloMatchersTest, Test) { auto shape = ShapeUtil::MakeShape(F32, {1}); auto param = HloInstruction::CreateParameter(0, shape, "param"); auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, @@ -85,7 +88,7 @@ TEST(HloMatchersTest, Test) { "add, (%param = f32[1]{0} parameter(0))")); } -TEST(HloMatchersTest, CustomCallMatcher) { +TEST_F(HloMatchersTest, CustomCallMatcher) { auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); auto c2 = @@ -116,7 +119,7 @@ TEST(HloMatchersTest, CustomCallMatcher) { R"(custom-call with call target that is equal to "foo_target")"); } -TEST(HloMatchersTest, ShapeMatcher) { +TEST_F(HloMatchersTest, ShapeMatcher) { auto p0 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param"); @@ -154,7 +157,7 @@ TEST(HloMatchersTest, ShapeMatcher) { "(expected: f32[7,5]{1,0})"); } -TEST(HloMatchersTest, ShardingMatcher) { +TEST_F(HloMatchersTest, ShardingMatcher) { auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}), "param.0"); p0->clear_sharding(); @@ -196,7 +199,7 @@ TEST(HloMatchersTest, ShardingMatcher) { "has incorrect sharding (expected: {maximal device=0})"); } -TEST(HloMatchersTest, DotMatcher) { +TEST_F(HloMatchersTest, DotMatcher) { string hlo_string = R"( HloModule DotOperationFusion_TransposeFusion @@ -208,7 +211,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), @@ -232,7 +235,7 @@ ENTRY DotOperationFusion_TransposeFusion { "rhs_contracting_dimensions (got {0} want {1})"); } -TEST(HloMatchersTest, ComparisonMatcher) { +TEST_F(HloMatchersTest, ComparisonMatcher) { auto shape = ShapeUtil::MakeShape(F32, {1}); auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); @@ -264,7 +267,7 @@ TEST(HloMatchersTest, ComparisonMatcher) { "has wrong comparison direction (got EQ, want NE)")); } -TEST(HloMatchersTest, AsyncCopyMatcher) { +TEST_F(HloMatchersTest, AsyncCopyMatcher) { Shape shape_memspace1 = ShapeUtil::MakeShapeWithLayout( F32, {16}, /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0, /*memory_space=*/1); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index bf10817b3f5..a422e03b26e 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -137,7 +136,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); @@ -190,7 +189,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); @@ -334,7 +333,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string)); + ParseAndReturnVerifiedModule(hlo_string)); EXPECT_FALSE(module->has_schedule()); TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status()); ASSERT_TRUE(module->has_schedule()); diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 2584c730aea..dba699dd8c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -71,7 +70,7 @@ class HloModuleDceTest : public HloTestBase { // Tests that a while with all outputs live is unmodified. TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -108,7 +107,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { // Tests a while loop with one unused output (which is used in the while loop // body by an instruction with side-effects: rng) is unmodified. TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], f32[]) parameter(0) @@ -118,7 +117,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 constant.2 = f32[] constant(1.0) rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform - add.1 = s32[] add(get-tuple-element.2, constant.2) + add.1 = f32[] add(get-tuple-element.2, constant.2) ROOT tuple = (s32[], f32[]) tuple(add, add.1) } SimpleLoop.condition { @@ -148,7 +147,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { // Tests that a while loop with one dead tuple element at {1} has its while // loop body modified to make that tuple element pass-through the while body. TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -191,7 +190,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { // dead in while.body{1} and at while.result{1}) propgates liveness of this // tuple element to while.body{1} and at while.result{1}. TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[]) parameter(0) @@ -233,7 +232,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { // Tests that HloModuleDCE can remove a dead tuple element at index {1} between // two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -301,7 +300,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { // Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and // while.2{1}, between two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[3]{0}, s32[]) parameter(0) @@ -367,7 +366,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { // Tests that a while whose body has outfeed operations is not DCE-ed. TEST_F(HloModuleDceTest, WhileWithOutfeed) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) @@ -404,7 +403,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { // variable changes are not elided within the loop body, if the condition // computation uses them. TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule InfiniteLoop WhileBody { body_param = (s32[], s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc index 1e6f4db1287..1b26451e6e4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -45,7 +44,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(text)); + ParseAndReturnVerifiedModule(text)); HloModuleGroup group(std::move(module)); EXPECT_EQ(group.modules().size(), 1); @@ -84,9 +83,9 @@ ENTRY %entry (a: f32[]) -> f32[] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, - ParseAndReturnUnverifiedModule(text_0)); + ParseAndReturnVerifiedModule(text_0)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, - ParseAndReturnUnverifiedModule(text_1)); + ParseAndReturnVerifiedModule(text_1)); std::vector> modules; modules.push_back(std::move(module_0)); modules.push_back(std::move(module_1)); @@ -123,9 +122,9 @@ ENTRY %entry (a: f32[]) -> f32[] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, - ParseAndReturnUnverifiedModule(text_0)); + ParseAndReturnVerifiedModule(text_0)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, - ParseAndReturnUnverifiedModule(text_1)); + ParseAndReturnVerifiedModule(text_1)); HloModuleGroup group(TestName()); group.push_back(std::move(module_0)); group.push_back(std::move(module_1)); @@ -179,7 +178,7 @@ ENTRY entry { const int64 send_channel = i; const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(absl::StrFormat( + ParseAndReturnVerifiedModule(absl::StrFormat( text, i, send_channel, send_channel, recv_channel, recv_channel))); group.push_back(std::move(module)); diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index d94cef84c39..7a97740ed8b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -215,7 +214,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(text)); + ParseAndReturnVerifiedModule(text)); ASSERT_FALSE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module_copy, @@ -237,7 +236,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(text)); + ParseAndReturnVerifiedModule(text)); ASSERT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module_copy, @@ -274,7 +273,7 @@ ENTRY ReduceR3ToR2.v3 { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(text)); + ParseAndReturnVerifiedModule(text)); // Perform various transformations on the graph: // diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 6d5ce6b2849..2b77619f89b 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -338,7 +337,7 @@ ENTRY while.v11 { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } @@ -375,7 +374,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_str)); + ParseAndReturnVerifiedModule(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); @@ -507,23 +506,22 @@ TEST_F(HloOrderingTest, InterferenceWithOuterRoot) { absl::string_view hlo_string = R"( HloModule InterferenceWithOuterRoot, is_scheduled=true -Emmbedded (embedded_param: f32[42]) -> f32[42] { - embedded_param = f32[42]{0} parameter(0) - multiply = f32[42]{0} multiply(embedded_param, embedded_param) - ROOT log = f32[42]{0} log(multiply) +Emmbedded (embedded_param: f32[4096,4096]) -> f32[4096,4096] { + embedded_param = f32[4096,4096]{1,0} parameter(0) + multiply = f32[4096,4096]{1,0} multiply(embedded_param, embedded_param) + ROOT log = f32[4096,4096]{1,0} log(multiply) } ENTRY InterferenceWithOuterRoot { param = f32[4096,4096]{1,0} parameter(0) ROOT add = f32[4096,4096]{1,0} add(param, param) - call = f32[42]{0} call(param), to_apply=Emmbedded + call = f32[4096,4096]{1,0} call(param), to_apply=Emmbedded } )"; HloModuleConfig hlo_config; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_string, hlo_config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string, hlo_config)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index dabd9d20f64..996c05f8460 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -602,9 +602,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( @@ -643,9 +642,8 @@ ENTRY %entry { } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(