diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0364f3d25fe..7566e59bed7 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -107,6 +107,25 @@ cc_library( ], ) +cc_library( + name = "verified_hlo_module", + testonly = True, + srcs = ["verified_hlo_module.cc"], + hdrs = ["verified_hlo_module.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:test", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "hlo_test_base", testonly = True, @@ -115,6 +134,7 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", + ":verified_hlo_module", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -249,6 +269,8 @@ cc_library( srcs = ["local_client_test_base.cc"], hdrs = ["local_client_test_base.h"], deps = [ + ":client_library_test_base", + ":verified_hlo_module", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -259,17 +281,19 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -1851,7 +1875,11 @@ xla_test( "interpreter", ], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1862,10 +1890,6 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2080,15 +2104,15 @@ xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], deps = [ + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", ], @@ -2433,10 +2457,10 @@ xla_test( ":local_client_test_base", ":test_macros_header", ":test_utils", + ":xla_internal_test_main", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/base", diff --git a/tensorflow/compiler/xla/tests/all_reduce_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc index 32a1509910b..41941a313d9 100644 --- a/tensorflow/compiler/xla/tests/all_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -44,7 +44,7 @@ XLA_TEST_F(TrivialAllReduceTest, OneOperand) { ROOT crs = f32[3] all-reduce(p), to_apply=add })"; auto module = - ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()) .ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 2, 3}); EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); @@ -66,7 +66,7 @@ XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) { ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = - ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()) .ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); @@ -93,7 +93,7 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) { ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = - ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()) .ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 9930bfc95c2..37f8e32aeee 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 64f5440b99f..e1b180d8359 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -85,25 +85,6 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. - return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - LOG(ERROR) << "Contents of bad module:"; - XLA_LOG_LINES(tensorflow::ERROR, ToString()); - } -} - HloTestBase::HloTestBase(bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, std::function diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index d4a1788c928..848b334cfec 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,33 +40,6 @@ limitations under the License. namespace xla { -// An HLO module derived class which verifies itself on destruction. This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier, - std::function shape_size_function) - : HloModule(name, config), - verifier_( - verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, - /*instruction_can_change_layout_func=*/{}, shape_size_function) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. - Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 5c93ca5f2d1..1532f1b5d8d 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -16,16 +16,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" @@ -205,4 +211,21 @@ StatusOr LocalClientTestBase::ExecuteLocally( return std::move(ret); } +StatusOr> +LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { + return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig()); +} + +StatusOr> +LocalClientTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true, + local_client_->backend().compiler()->ShapeSizeBytesFunction()); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 877a658fa19..8908a855847 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -20,16 +20,19 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -109,6 +112,12 @@ class LocalClientTestBase : public ::testing::Test { const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text); + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config); + // Returns a default set of execute options. ExecutableBuildOptions DefaultExecutableBuildOptions() const; diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a3ceebc2516..9db08a5b72f 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -75,7 +75,7 @@ XLA_TEST_F(TestUtilsTest, Token) { } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { - auto module = ParseAndReturnUnverifiedModule( + auto module = ParseAndReturnVerifiedModule( R"(HloModule index_space_module ENTRY IndexSpace { @@ -103,7 +103,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { - auto module = ParseAndReturnUnverifiedModule( + auto module = ParseAndReturnVerifiedModule( R"(HloModule index_space_module ENTRY IndexSpace { @@ -135,7 +135,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { // Inputs which are sort keys in key/value sorts should have no duplicates. - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule sort.148.1589 compare { @@ -166,7 +166,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { // Inputs which are sort keys in key/value sorts should have no duplicates. - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule sort.148.1589 compare { @@ -197,7 +197,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { // Inputs which are sort keys in key/value sorts should have no duplicates. - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule sort, is_scheduled=true compare { @@ -227,7 +227,7 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14 } XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule Test ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { @@ -255,7 +255,7 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { } XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule Test ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) -> @@ -289,7 +289,7 @@ ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) -> } XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { - auto module = ParseAndReturnUnverifiedModule(R"( + auto module = ParseAndReturnVerifiedModule(R"( HloModule Test scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { diff --git a/tensorflow/compiler/xla/tests/verified_hlo_module.cc b/tensorflow/compiler/xla/tests/verified_hlo_module.cc new file mode 100644 index 00000000000..cd0c4073a26 --- /dev/null +++ b/tensorflow/compiler/xla/tests/verified_hlo_module.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; + LOG(ERROR) << "Contents of bad module:"; + XLA_LOG_LINES(tensorflow::ERROR, ToString()); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/verified_hlo_module.h b/tensorflow/compiler/xla/tests/verified_hlo_module.h new file mode 100644 index 00000000000..1c13773acd4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/verified_hlo_module.h @@ -0,0 +1,59 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function) + : HloModule(name, config), + verifier_( + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, + /*instruction_can_change_layout_func=*/{}, shape_size_function) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_