Extract VerifiedHloModule to its own file.

Use it in local_client_test_base to also allow to create verified HloModules.
Change a few more tests to use verified HloModules.

PiperOrigin-RevId: 275026854
Change-Id: I82cdfeba08b1037d22171204b0281d29f9bc11b1
This commit is contained in:
Adrian Kuegel 2019-10-16 07:23:46 -07:00 committed by TensorFlower Gardener
parent 77b175ecad
commit 27715f81b7
10 changed files with 194 additions and 78 deletions

View File

@ -107,6 +107,25 @@ cc_library(
],
)
cc_library(
name = "verified_hlo_module",
testonly = True,
srcs = ["verified_hlo_module.cc"],
hdrs = ["verified_hlo_module.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "hlo_test_base",
testonly = True,
@ -115,6 +134,7 @@ cc_library(
deps = [
":literal_test_util",
":test_utils",
":verified_hlo_module",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@ -249,6 +269,8 @@ cc_library(
srcs = ["local_client_test_base.cc"],
hdrs = ["local_client_test_base.h"],
deps = [
":client_library_test_base",
":verified_hlo_module",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -259,17 +281,19 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -1851,7 +1875,11 @@ xla_test(
"interpreter",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -1862,10 +1890,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2080,15 +2104,15 @@ xla_test(
name = "broadcast_test",
srcs = ["broadcast_test.cc"],
deps = [
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
@ -2433,10 +2457,10 @@ xla_test(
":local_client_test_base",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/base",

View File

@ -44,7 +44,7 @@ XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
ROOT crs = f32[3] all-reduce(p), to_apply=add
})";
auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
@ -66,7 +66,7 @@ XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
})";
auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
@ -93,7 +93,7 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
})";
auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});

View File

@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
ShapeUtil::MakeShape(F32, {}), input, {}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
LOG(INFO) << hlo_module->ToString();
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
// Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule();
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});

View File

@ -85,25 +85,6 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
LOG(ERROR) << "Contents of bad module:";
XLA_LOG_LINES(tensorflow::ERROR, ToString());
}
}
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<bool(const HloInstruction*)>

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -39,33 +40,6 @@ limitations under the License.
namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<int64(const Shape&)> shape_size_function)
: HloModule(name, config),
verifier_(
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and compare the results.
// This is a lower level of abstraction than using the client interface and

View File

@ -16,16 +16,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
@ -205,4 +211,21 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
return std::move(ret);
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
LocalClientTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) {
auto module = absl::make_unique<VerifiedHloModule>(
TestName(), config, /*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
local_client_->backend().compiler()->ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
return std::move(module);
}
} // namespace xla

View File

@ -20,16 +20,19 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -109,6 +112,12 @@ class LocalClientTestBase : public ::testing::Test {
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text);
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config);
// Returns a default set of execute options.
ExecutableBuildOptions DefaultExecutableBuildOptions() const;

View File

@ -75,7 +75,7 @@ XLA_TEST_F(TestUtilsTest, Token) {
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
auto module = ParseAndReturnUnverifiedModule(
auto module = ParseAndReturnVerifiedModule(
R"(HloModule index_space_module
ENTRY IndexSpace {
@ -103,7 +103,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
}
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
auto module = ParseAndReturnUnverifiedModule(
auto module = ParseAndReturnVerifiedModule(
R"(HloModule index_space_module
ENTRY IndexSpace {
@ -135,7 +135,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
// Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort.148.1589
compare {
@ -166,7 +166,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
// Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort.148.1589
compare {
@ -197,7 +197,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
// Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort, is_scheduled=true
compare {
@ -227,7 +227,7 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14
}
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test
ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
@ -255,7 +255,7 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
}
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test
ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
@ -289,7 +289,7 @@ ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
}
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test
scatter_update (lhs: f32[], rhs: f32[]) -> f32[] {

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
LOG(ERROR) << "Contents of bad module:";
XLA_LOG_LINES(tensorflow::ERROR, ToString());
}
}
} // namespace xla

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
#include <functional>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<int64(const Shape&)> shape_size_function)
: HloModule(name, config),
verifier_(
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_