From 388eb3753ba309f80201e9230341f931c7dae2b6 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 11 Nov 2018 16:18:06 -0800 Subject: [PATCH] [XLA] Merge HloTestBase and HloVerifiedTestBase. No need to have this distinction any longer; you can simply call CreateNewUnverifiedModule or CreateNewVerifiedModule as you please. PiperOrigin-RevId: 221018719 --- tensorflow/compiler/xla/service/BUILD | 49 ++------ .../xla/service/algebraic_simplifier_test.cc | 13 +- .../service/batch_dot_simplification_test.cc | 3 +- .../xla/service/batchnorm_expander_test.cc | 4 +- .../bfloat16_conversion_folding_test.cc | 8 +- .../service/bfloat16_normalization_test.cc | 8 +- .../xla/service/bfloat16_propagation_test.cc | 8 +- .../xla/service/buffer_assignment_test.cc | 6 +- .../compiler/xla/service/call_graph_test.cc | 4 +- .../compiler/xla/service/call_inliner_test.cc | 4 +- .../service/conditional_simplifier_test.cc | 4 +- tensorflow/compiler/xla/service/cpu/BUILD | 5 - .../service/cpu/conv_canonicalization_test.cc | 4 +- .../service/cpu/cpu_copy_insertion_test.cc | 4 +- .../cpu/cpu_hlo_support_checker_test.cc | 6 +- .../cpu/parallel_task_assignment_test.cc | 6 +- .../xla/service/cpu/shape_partition_test.cc | 8 +- .../compiler/xla/service/cpu/tests/BUILD | 1 - .../xla/service/cpu/tests/cpu_fusion_test.cc | 4 +- .../compiler/xla/service/defuser_test.cc | 4 +- .../xla/service/flatten_call_graph_test.cc | 4 +- tensorflow/compiler/xla/service/gpu/BUILD | 9 +- .../cudnn_conv_pad_for_tensor_cores_test.cc | 4 +- .../service/gpu/cudnn_conv_rewriter_test.cc | 8 +- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 4 +- .../gpu/gpu_hlo_support_checker_test.cc | 6 +- .../xla/service/gpu/stream_assignment_test.cc | 4 +- .../service/gpu/variadic_op_splitter_test.cc | 4 +- .../xla/service/heap_simulator_test.cc | 6 +- .../xla/service/hlo_alias_analysis_test.cc | 6 +- .../xla/service/hlo_constant_folding_test.cc | 4 +- .../xla/service/hlo_creation_utils_test.cc | 4 +- .../compiler/xla/service/hlo_cse_test.cc | 4 +- .../compiler/xla/service/hlo_domain_test.cc | 3 +- .../xla/service/hlo_evaluator_test.cc | 10 +- .../xla/service/hlo_instruction_test.cc | 4 +- .../xla/service/hlo_pass_pipeline_test.cc | 4 +- .../xla/service/hlo_reachability_test.cc | 4 +- .../xla/service/hlo_rematerialization_test.cc | 4 +- .../xla/service/hlo_tfgraph_builder_test.cc | 4 +- .../compiler/xla/service/hlo_verifier_test.cc | 2 +- .../implicit_broadcast_remover_test.cc | 4 +- .../service/indexed_array_analysis_test.cc | 4 +- .../xla/service/layout_assignment_test.cc | 4 +- .../compiler/xla/service/map_inliner_test.cc | 4 +- .../xla/service/reshape_mover_test.cc | 4 +- .../xla/service/tuple_simplifier_test.cc | 4 +- .../xla/service/while_loop_analysis_test.cc | 4 +- .../while_loop_invariant_code_motion_test.cc | 4 +- .../xla/service/while_loop_simplifier_test.cc | 4 +- tensorflow/compiler/xla/tests/BUILD | 38 ------ .../compiler/xla/tests/hlo_test_base.cc | 49 +++++++- tensorflow/compiler/xla/tests/hlo_test_base.h | 46 ++++++- .../xla/tests/hlo_verified_test_base.cc | 68 ----------- .../xla/tests/hlo_verified_test_base.h | 86 ------------- .../xla/tests/hlo_verified_test_base_test.cc | 115 ------------------ 56 files changed, 219 insertions(+), 474 deletions(-) delete mode 100644 tensorflow/compiler/xla/tests/hlo_verified_test_base.cc delete mode 100644 tensorflow/compiler/xla/tests/hlo_verified_test_base.h delete mode 100644 tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 16d79da2de5..19b5c1ca25d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -124,7 +123,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -164,7 +162,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -282,7 +279,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -365,7 +362,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -421,7 +417,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -467,7 +462,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -520,7 +514,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -569,7 +562,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -592,7 +584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1089,7 +1080,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1172,7 +1162,6 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1433,7 +1422,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1509,7 +1497,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1561,7 +1548,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1598,7 +1584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1648,7 +1633,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1707,7 +1692,7 @@ tf_cc_test( ":while_loop_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1740,7 +1725,7 @@ tf_cc_test( ":hlo_matchers", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -1771,7 +1756,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1799,7 +1784,7 @@ tf_cc_test( ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1844,7 +1829,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1878,7 +1862,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -2284,7 +2268,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2347,7 +2330,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2444,7 +2426,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2562,7 +2543,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2629,7 +2609,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2691,7 +2670,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2732,7 +2711,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2771,7 +2749,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2846,7 +2823,6 @@ tf_cc_test( "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -3034,7 +3010,6 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -3361,7 +3336,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3391,7 +3366,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3452,7 +3427,7 @@ tf_cc_test( ":indexed_array_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", ], @@ -3549,7 +3524,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 18eaa4d3f0e..e4c4da1b0e7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -54,7 +53,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -2906,7 +2905,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewUnverifiedModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -3084,7 +3083,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewUnverifiedModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -3166,7 +3165,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewUnverifiedModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -3846,7 +3845,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // Test that we transform @@ -4022,7 +4021,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index dd6e56a7207..52ec1a794c5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -17,14 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 0b0000f7c05..08cf8026177 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloVerifiedTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index d5da43db269..4ce351acc2c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { +class BFloat16ConversionFoldingTest : public HloTestBase { protected: BFloat16ConversionFoldingTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 677a50aaa14..9f97d18c565 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloVerifiedTestBase { +class BFloat16NormalizationTest : public HloTestBase { protected: BFloat16NormalizationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index d0843c97011..5be7141aae4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloVerifiedTestBase { +class BFloat16PropagationTest : public HloTestBase { protected: BFloat16PropagationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index c8ac446073e..b1fc50cb188 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloVerifiedTestBase { +class BufferAssignmentTest : public HloTestBase { protected: ~BufferAssignmentTest() override {} @@ -1818,7 +1818,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloVerifiedTestBase { +class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( const string& name) { diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index c7ae995198a..a3ac2568b0f 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloVerifiedTestBase { +class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 7bf4cf7e15b..0b6e323f75c 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloVerifiedTestBase; +using CallInlinerTest = HloTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 4c3a1877a2e..289eb6d9023 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -37,7 +37,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalSimplifierTest : public HloVerifiedTestBase { +class ConditionalSimplifierTest : public HloTestBase { public: // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index c369b3af16f..2763d18121a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -824,7 +824,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -846,7 +845,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -887,7 +885,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -971,7 +968,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -997,7 +993,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d95a0239eba..c58175428fe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloVerifiedTestBase { +class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 2062d8a0d5b..c085f85fb73 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloVerifiedTestBase { +class CpuCopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index b8c2bc7e91c..9cbfb88834b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloVerifiedTestBase { +class CpuHloSupportCheckerTest : public HloTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewUnverifiedModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 31df0a08931..f0b65046c14 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class ParallelTaskAssignmentTest : public HloVerifiedTestBase { +class ParallelTaskAssignmentTest : public HloTestBase { protected: const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = cpu::CpuExecutable::ShapeSizeBytes; @@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { + : HloTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 1a3d82de954..7d8e51f909e 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloVerifiedTestBase { +class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloVerifiedTestBase { +class ShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { +class RandomShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4b129c95d46..382dfd0d99d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,7 +48,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 707178c0224..04a81dfd35f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloVerifiedTestBase { +class CpuFusionTest : public HloTestBase { protected: CpuFusionTest() {} diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index bb25c48f2b3..64fb5031839 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -18,14 +18,14 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloVerifiedTestBase { +class DefuserTest : public HloTestBase { protected: // Returns the number of fusion instructions in the module. int FusionCount(const HloModule* m) { diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index beecaf61771..8eeb930b481 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloVerifiedTestBase { +class FlattenCallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1e8435fe542..b1629616acd 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -111,7 +111,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -463,7 +462,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -627,7 +626,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -849,7 +848,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -909,7 +907,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -1036,6 +1033,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index fa3afa6a5d3..af9303a5b76 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; +class CudnnConvPadForTensorCoresTest : public HloTestBase {}; TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index a88132a7444..443883a89f6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -34,11 +34,11 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvRewriterTest : public HloVerifiedTestBase { +class CudnnConvRewriterTest : public HloTestBase { public: CudnnConvRewriterTest() - : HloVerifiedTestBase(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false) { + : HloTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index c8b3770e2aa..6d3aed15ebe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloVerifiedTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: using HloVec = std::vector; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 15a8dd85974..b511155f85f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloVerifiedTestBase { +class GpuHloSupportCheckerTest : public HloTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewUnverifiedModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 52ee7a960f6..f2ef11e1e6a 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,14 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloVerifiedTestBase { +class StreamAssignmentTest : public HloTestBase { protected: std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc index 5fa9e91050a..3d00ac4dc7b 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -32,7 +32,7 @@ namespace gpu { namespace { using match::Concatenate; -class VariadicOpSplitterTest : public HloVerifiedTestBase {}; +class VariadicOpSplitterTest : public HloTestBase {}; TEST_F(VariadicOpSplitterTest, DontSplit) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index ac650cf442f..fad3215fc81 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,13 +30,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; +class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewVerifiedModule(); @@ -351,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloVerifiedTestBase { +class HeapSimulatorTest : public HloTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 2d0244b1cf4..7e6150e9415 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,9 +39,9 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloVerifiedTestBase { +class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : HloVerifiedTestBase() { + HloAliasAnalysisTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 30227bf001d..d12f920722e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloVerifiedTestBase; +using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 55356fd080e..aaa9ec60eb3 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,13 +19,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -class HloCreationUtilsTest : public HloVerifiedTestBase { +class HloCreationUtilsTest : public HloTestBase { protected: std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, absl::Span input_shape_dims, diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 0952515312e..1eb0260468c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloVerifiedTestBase { +class HloCseTest : public HloTestBase { protected: HloCseTest() {} }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 483550f1fdb..acdb42128e3 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -22,13 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloVerifiedTestBase { +class HloDomainTest : public HloTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 0c7921e845c..d95b6ad04f2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,9 +50,9 @@ namespace { static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloVerifiedTestBase { + public HloTestBase { protected: - HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -67,7 +67,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } // Evaluate function that takes in a local module instead of using m_ - // that is in HloVerifiedTestBase. Once m_ in HloVerifiedTestBase is + // that is in HloTestBase. Once m_ in HloTestBase is // removed, this should be the default Evaluate function. Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { @@ -1298,7 +1298,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index f3b7f8f09a0..8048e332cb5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,7 +39,7 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloVerifiedTestBase { +class HloInstructionTest : public HloTestBase { protected: Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index ee8cb12b231..20384b9da6b 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloPassPipelineTest : public HloVerifiedTestBase { +class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index eec7fd65f81..59517670980 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloVerifiedTestBase {}; +class HloReachabilityTest : public HloTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 062b8a0bf46..22c3c40a93a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloVerifiedTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 6fd734a2b9e..1e2b31a1f2b 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloVerifiedTestBase { +class HloTfGraphBuilderTest : public HloTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4f6cd2f1e2d..5ddfe0a944f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,7 +35,7 @@ namespace { using ::testing::HasSubstr; -// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. class HloVerifierTest : public HloTestBase { public: diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index 0d67f23a81d..cf6cf897fe1 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -18,14 +18,14 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { +class ImplicitBroadcastRemoverTest : public HloTestBase { protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 900fbfaa975..20cc18f9815 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloVerifiedTestBase { +class IndexedArrayAnalysisTest : public HloTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 3fe324005da..2400b7bb7c4 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloVerifiedTestBase { +class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 320e0d1acb4..fd18bfdc3e7 100644 --- a/tensorflow/compiler/xla/service/map_inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using MapInlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(MapInlinerTest, MapMax) { diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index e100460f7d4..341659b15c4 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,7 +34,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase {}; +class ReshapeMoverTest : public HloTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { auto m = CreateNewVerifiedModule(); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 8add6b189a8..65b0f8c8044 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloVerifiedTestBase { +class TupleSimplifierTest : public HloTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc index f845cd3c065..1da0fbeac89 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class WhileLoopAnalysisTest : public HloVerifiedTestBase {}; +class WhileLoopAnalysisTest : public HloTestBase {}; TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { const char* const kHloModule = R"( diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index a6d0db0f6f4..046ccb2d3f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -26,7 +26,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { +class WhileLoopInvariantCodeMotionTest : public HloTestBase { public: // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index e9ad188efbc..05005e0b262 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopSimplifierTest : public HloVerifiedTestBase { +class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. TF_MUST_USE_RESULT std::unique_ptr diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e11aeffaddd..db34d34f969 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -141,44 +141,6 @@ cc_library( ], ) -cc_library( - name = "hlo_verified_test_base", - testonly = True, - srcs = ["hlo_verified_test_base.cc"], - hdrs = ["hlo_verified_test_base.h"], - deps = [ - ":hlo_test_base", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "hlo_verified_test_base_test", - srcs = ["hlo_verified_test_base_test.cc"], - deps = [ - ":hlo_test_base", - ":hlo_verified_test_base", - ":test_macros_cpu", - ":test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index afd2073d751..d8fa00272f8 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -85,6 +85,23 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; + } +} + HloTestBase::HloTestBase(bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, std::function @@ -100,7 +117,11 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool allow_mixed_precision_in_hlo_verifier, std::function instruction_can_change_layout_func) - : test_runner_(test_platform), reference_runner_(reference_platform) { + : test_runner_(test_platform), + reference_runner_(reference_platform), + verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, @@ -112,6 +133,32 @@ std::unique_ptr HloTestBase::CreateNewUnverifiedModule( return absl::make_unique(name, GetModuleConfigForTest()); } +std::unique_ptr HloTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + +StatusOr> +HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique(TestName(), config); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + return std::move(module); +} + +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + /* static */ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, HloModule* module) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 75a79e52cab..366726d90b4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -38,6 +38,31 @@ limitations under the License. namespace xla { +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and @@ -74,11 +99,26 @@ class HloTestBase : public ::testing::Test { // tests. // // This returns a vanilla HloModule that doesn't run the HLO verifier on - // destruction. If you want to run the verifier, you want - // HloVerifiedTestBase::CreateNewVerifiedModule. + // destruction. std::unique_ptr CreateNewUnverifiedModule( const string& name = TestName()); + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + // Parses the given string and returns module as a vanilla, unverified + // HloModule. + StatusOr> ParseAndReturnUnverifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass // returns false as the StatusOr value. @@ -252,6 +292,8 @@ class HloTestBase : public ::testing::Test { HloRunner test_runner_; HloRunner reference_runner_; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc deleted file mode 100644 index 6a63c69c2b5..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. - return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - } -} - -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), - verifier_layout_sensitive_(layout_sensitive), - allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} - -StatusOr> -HloVerifiedTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { - auto module = CreateNewVerifiedModule(TestName()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); - return std::move(module); -} - -std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( - const string& name) { - return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h deleted file mode 100644 index 65fdf9a9194..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ - -#include -#include -#include - -#include "absl/base/macros.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. - Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - -// A base class for HLO tests that stores a default VerifiedHloModule. -class HloVerifiedTestBase : public HloTestBase { - protected: - HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - - // Constructs a default shape verifier. - std::unique_ptr MakeShapeVerifier(); - - // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Creates and returns a verified HLO module with the given name. - std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); - - // CreateNewUnverifiedModule creates an *unverified* module, which presumably - // isn't what you want if you're using HloVerifiedTestBase, so we delete this - // function to keep you from accidentally calling it. If you really want it, - // you can get it by calling HloTestBase::CreateNewUnverifiedModule(). - std::unique_ptr CreateNewUnverifiedModule( - const string& name = TestName()) = delete; - - private: - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc deleted file mode 100644 index 6df4c1aad2e..00000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// This class includes unit tests which are expected to fail because invalid HLO -// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to -// include the necessary gunit parts to test this test machinery (needs the -// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the -// disabled tests enabled and failures can be manually compared against -// expectations. -class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; - -XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { - // Test shouldn't fail if no module is created at all. -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewUnverifiedModule) { - // Call CreateNewUnverifiedModule and build up a valid module. - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewUnverifiedModule) { - // Call CreateNewUnverifiedModule and build up a invalid module. - auto module = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); - - *module->entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_EQ(module->entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} - -RANDOM GARBAGE -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleBad - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[1234] add(x,y) -} -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -} // namespace -} // namespace xla