diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 64c14e38e5a..16f6de1bcde 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -227,7 +227,6 @@ cc_library( deps = [ ":codegen_test_base", ":filecheck", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:test", @@ -333,7 +332,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:bfloat16_normalization", "//tensorflow/compiler/xla/service:despecializer", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -359,7 +357,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:bfloat16_normalization", "//tensorflow/compiler/xla/service:despecializer", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -372,8 +369,6 @@ xla_test( timeout = "long", srcs = ["grouped_convolution_test.cc"], blacklisted_backends = [ - # disabled because of a break b/119590850. - "gpu", # disabled because it times out. "cpu", ], @@ -386,7 +381,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:bfloat16_normalization", "//tensorflow/compiler/xla/service:despecializer", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -749,7 +743,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:bfloat16_normalization", "//tensorflow/compiler/xla/service:despecializer", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -995,7 +988,6 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -1009,7 +1001,6 @@ xla_test( ":test_macros_header", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -1549,7 +1540,6 @@ xla_test( srcs = ["reduce_hlo_test.cc"], deps = [ ":test_macros_header", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1563,7 +1553,6 @@ xla_test( srcs = ["token_hlo_test.cc"], deps = [ ":test_macros_header", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1889,7 +1878,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2123,6 +2111,7 @@ tf_cc_test( srcs = ["llvm_compiler_test.cc"], tags = tf_cuda_tests_tags(), deps = [ + ":verified_hlo_module", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:backend", @@ -2175,7 +2164,6 @@ xla_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/tests/all_reduce_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc index 41941a313d9..33a8db8de32 100644 --- a/tensorflow/compiler/xla/tests/all_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index 1be982e37c3..ff7e7955876 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/despecializer.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.cc b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc index e11ec33e730..f21f965c69e 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_common.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/despecializer.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.h b/tensorflow/compiler/xla/tests/conv_depthwise_common.h index 0c00f8d0abe..18c92f21862 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_common.h +++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/despecializer.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index 6d8ddc199e2..d99aaef22db 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/despecializer.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index df005a67097..2c281377974 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {&literal}); @@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, @@ -145,7 +145,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -177,7 +177,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -211,7 +211,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, std::unique_ptr computation = builder.Build(); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); Literal result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc index 7719e89f9e8..83ed3c93df1 100644 --- a/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 63c3b4b5b02..a7a6c7bd2be 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -79,7 +79,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -94,7 +94,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { } XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -115,7 +115,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { } XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ -139,7 +139,7 @@ XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { } XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = @@ -164,7 +164,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index fb0eb666a89..47d3546fc41 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -38,7 +37,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc index bfabfe44aa0..4b06fe2678f 100644 --- a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/despecializer.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index a78ccacec11..ccccda1cab4 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" + #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -68,7 +70,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); compiler->SetPreOptimizationHook(pre_opt_hook); @@ -90,7 +92,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - std::unique_ptr hlo_module = CreateNewUnverifiedModule(); + std::unique_ptr hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto module_group = absl::make_unique("test_module_group"); @@ -124,10 +126,13 @@ class LLVMCompilerTest : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - static std::unique_ptr CreateNewUnverifiedModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; config.set_debug_options(GetDebugOptionsFromFlags()); - return absl::make_unique(TestName(), config); + return absl::make_unique( + TestName(), config, /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true, + backend_->compiler()->ShapeSizeBytesFunction()); } }; diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 8df4a57afcd..8b95c17d199 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -66,7 +65,7 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir); } diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 7578094e07f..0dcc0c278ae 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); const Shape elem_shape2 = @@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); const Shape elem_shape_F32 = ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 05e56ab1870..d0d6a91e84b 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -50,10 +49,10 @@ void PrintTo(const ReduceLayout& reduce_layout, ::std::ostream* os) { class ReduceWithLayoutTest : public HloTestBase, - public ::testing::WithParamInterface {}; - -StatusOr> GetParsedModule() { - const char* const hlo_string = R"( + public ::testing::WithParamInterface { + public: + StatusOr> GetParsedModule() { + const char* const hlo_string = R"( HloModule BadReduce Sum { @@ -70,12 +69,11 @@ ENTRY reduce.1 { } )"; - return ParseAndReturnUnverifiedModule(hlo_string); -} + return ParseAndReturnVerifiedModule(hlo_string); + } +}; -// TODO(b/72454718): XLA:GPU does not support executing code compiled without -// optimizations. -XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { +XLA_TEST_P(ReduceWithLayoutTest, Reduce) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetParsedModule()); HloInstruction* reduce_instruction = module->entry_computation()->root_instruction()->mutable_operand(0); diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 0fdd176e8ef..c7b95f389de 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -36,7 +35,7 @@ class ScatterTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; @@ -158,7 +157,7 @@ ENTRY main { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); auto actual = ExecuteAndTransfer(std::move(module), {&permutation}); Literal expected = LiteralUtil::CreateR2({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}}); diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 4a4fd29f091..d631f02bd09 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -29,7 +28,7 @@ namespace { class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateToken()); @@ -40,7 +39,7 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { } XLA_TEST_F(TokenHloTest, TokenInTuple) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction(HloInstruction::CreateTuple({token})); @@ -54,7 +53,7 @@ XLA_TEST_F(TokenHloTest, TokenInTuple) { } XLA_TEST_F(TokenHloTest, TokenTree) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); @@ -222,7 +221,7 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { )"; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest())); + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); auto p0 = LiteralUtil::CreateR0(10.0); auto p1 = LiteralUtil::CreateR0(3.0); auto expected = LiteralUtil::CreateR0(-156.0); @@ -243,7 +242,7 @@ ENTRY %AddDependency (p0: f32[]) -> f32[] { )"; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest())); + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); auto p0 = LiteralUtil::CreateR0(10.0); auto expected = LiteralUtil::CreateR0(420.0); EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0})); @@ -261,7 +260,7 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] { )"; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest())); + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); auto input = LiteralUtil::CreateR1({1.0, 3.0, 7.0}); auto expected = LiteralUtil::CreateR1({-1.0, -3.0, -7.0}); EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input})); @@ -284,7 +283,7 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { )"; TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, - ParseAndReturnUnverifiedModule(module_string, GetModuleConfigForTest())); + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); auto p0 = LiteralUtil::CreateR1({3.0, 3.0, 47.0}); auto p1 = LiteralUtil::CreateR1({1.0, -2.0, 2.0}); auto expected = LiteralUtil::CreateR1({2.0, 5.0, 45.0});