Extract VerifiedHloModule to its own file.
Use it in local_client_test_base to also allow to create verified HloModules. Change a few more tests to use verified HloModules. PiperOrigin-RevId: 275026854 Change-Id: I82cdfeba08b1037d22171204b0281d29f9bc11b1
This commit is contained in:
parent
77b175ecad
commit
27715f81b7
@ -107,6 +107,25 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "verified_hlo_module",
|
||||||
|
testonly = True,
|
||||||
|
srcs = ["verified_hlo_module.cc"],
|
||||||
|
hdrs = ["verified_hlo_module.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
|
"//tensorflow/core/platform:test",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "hlo_test_base",
|
name = "hlo_test_base",
|
||||||
testonly = True,
|
testonly = True,
|
||||||
@ -115,6 +134,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":literal_test_util",
|
":literal_test_util",
|
||||||
":test_utils",
|
":test_utils",
|
||||||
|
":verified_hlo_module",
|
||||||
"//tensorflow/compiler/xla:debug_options_flags",
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
"//tensorflow/compiler/xla:shape_layout",
|
"//tensorflow/compiler/xla:shape_layout",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
@ -249,6 +269,8 @@ cc_library(
|
|||||||
srcs = ["local_client_test_base.cc"],
|
srcs = ["local_client_test_base.cc"],
|
||||||
hdrs = ["local_client_test_base.h"],
|
hdrs = ["local_client_test_base.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
|
":verified_hlo_module",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -259,17 +281,19 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:local_client",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||||
"//tensorflow/compiler/xla/service:local_service",
|
"//tensorflow/compiler/xla/service:local_service",
|
||||||
"//tensorflow/compiler/xla/service:platform_util",
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/stream_executor:device_memory_allocator",
|
"//tensorflow/stream_executor:device_memory_allocator",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1851,7 +1875,11 @@ xla_test(
|
|||||||
"interpreter",
|
"interpreter",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":client_library_test_base",
|
||||||
|
":hlo_test_base",
|
||||||
|
":literal_test_util",
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
|
":xla_internal_test_main",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
@ -1862,10 +1890,6 @@ xla_test(
|
|||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
@ -2080,15 +2104,15 @@ xla_test(
|
|||||||
name = "broadcast_test",
|
name = "broadcast_test",
|
||||||
srcs = ["broadcast_test.cc"],
|
srcs = ["broadcast_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hlo_test_base",
|
||||||
|
":literal_test_util",
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
|
":xla_internal_test_main",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
@ -2433,10 +2457,10 @@ xla_test(
|
|||||||
":local_client_test_base",
|
":local_client_test_base",
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
":test_utils",
|
":test_utils",
|
||||||
|
":xla_internal_test_main",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
|
@ -44,7 +44,7 @@ XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
|
|||||||
ROOT crs = f32[3] all-reduce(p), to_apply=add
|
ROOT crs = f32[3] all-reduce(p), to_apply=add
|
||||||
})";
|
})";
|
||||||
auto module =
|
auto module =
|
||||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
|
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||||
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
|
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
|
||||||
@ -66,7 +66,7 @@ XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
|
|||||||
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
||||||
})";
|
})";
|
||||||
auto module =
|
auto module =
|
||||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||||
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
||||||
@ -93,7 +93,7 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
|
|||||||
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
||||||
})";
|
})";
|
||||||
auto module =
|
auto module =
|
||||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||||
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
||||||
|
@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
|
|||||||
ShapeUtil::MakeShape(F32, {}), input, {}));
|
ShapeUtil::MakeShape(F32, {}), input, {}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
|
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
|
|||||||
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
|
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
|
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
|
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
|
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
|
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
|
|||||||
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
|
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
|
|||||||
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
|
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
|
|||||||
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
|
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
LOG(INFO) << hlo_module->ToString();
|
LOG(INFO) << hlo_module->ToString();
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
|
|||||||
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
|
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
|
|||||||
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
|
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
|
||||||
|
|
||||||
// Create HLO module, compile, and execute.
|
// Create HLO module, compile, and execute.
|
||||||
auto hlo_module = CreateNewUnverifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||||
|
|
||||||
|
@ -85,25 +85,6 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status VerifiedHloModule::Verify() {
|
|
||||||
if (computation_count() == 0) {
|
|
||||||
// The computation was never built. Nothing to verify.
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return verifier_.Run(this).status();
|
|
||||||
}
|
|
||||||
|
|
||||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
|
||||||
Status status = Verify();
|
|
||||||
if (!status.ok()) {
|
|
||||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
|
||||||
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
|
||||||
<< ": " << status;
|
|
||||||
LOG(ERROR) << "Contents of bad module:";
|
|
||||||
XLA_LOG_LINES(tensorflow::ERROR, ToString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
|
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
|
||||||
bool allow_mixed_precision_in_hlo_verifier,
|
bool allow_mixed_precision_in_hlo_verifier,
|
||||||
std::function<bool(const HloInstruction*)>
|
std::function<bool(const HloInstruction*)>
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
@ -39,33 +40,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// An HLO module derived class which verifies itself on destruction. This class
|
|
||||||
// is intended to be used in unit tests. Any verification errors are raised via
|
|
||||||
// ADD_FAILURE.
|
|
||||||
class VerifiedHloModule : public HloModule {
|
|
||||||
public:
|
|
||||||
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
|
||||||
bool verifier_layout_sensitive,
|
|
||||||
bool allow_mixed_precision_in_hlo_verifier,
|
|
||||||
std::function<int64(const Shape&)> shape_size_function)
|
|
||||||
: HloModule(name, config),
|
|
||||||
verifier_(
|
|
||||||
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
|
|
||||||
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
|
|
||||||
|
|
||||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
|
||||||
|
|
||||||
// Verifies the module using HloVerifier and returns the status.
|
|
||||||
Status Verify();
|
|
||||||
|
|
||||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
|
||||||
// included in the failure message.
|
|
||||||
void VerifyOrAddFailure(const string& message);
|
|
||||||
|
|
||||||
private:
|
|
||||||
HloVerifier verifier_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// A base class for tests which build and/or run HLO code. The class includes
|
// A base class for tests which build and/or run HLO code. The class includes
|
||||||
// support for running an HLO module on two platforms and compare the results.
|
// support for running an HLO module on two platforms and compare the results.
|
||||||
// This is a lower level of abstraction than using the client interface and
|
// This is a lower level of abstraction than using the client interface and
|
||||||
|
@ -16,16 +16,22 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
|
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/platform/byte_order.h"
|
#include "tensorflow/core/platform/byte_order.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
@ -205,4 +211,21 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
|
|||||||
return std::move(ret);
|
return std::move(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
|
LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
|
||||||
|
return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
|
LocalClientTestBase::ParseAndReturnVerifiedModule(
|
||||||
|
absl::string_view hlo_text, const HloModuleConfig& config) {
|
||||||
|
auto module = absl::make_unique<VerifiedHloModule>(
|
||||||
|
TestName(), config, /*verifier_layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||||
|
local_client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||||
|
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||||
|
TF_RETURN_IF_ERROR(module->Verify());
|
||||||
|
return std::move(module);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -20,16 +20,19 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/local_service.h"
|
#include "tensorflow/compiler/xla/service/local_service.h"
|
||||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
@ -109,6 +112,12 @@ class LocalClientTestBase : public ::testing::Test {
|
|||||||
const ExecutableBuildOptions& build_options,
|
const ExecutableBuildOptions& build_options,
|
||||||
const ExecutableRunOptions& run_options);
|
const ExecutableRunOptions& run_options);
|
||||||
|
|
||||||
|
// Parses the given string and returns module as a VerifiedHloModule.
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
|
absl::string_view hlo_text);
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
|
absl::string_view hlo_text, const HloModuleConfig& config);
|
||||||
|
|
||||||
// Returns a default set of execute options.
|
// Returns a default set of execute options.
|
||||||
ExecutableBuildOptions DefaultExecutableBuildOptions() const;
|
ExecutableBuildOptions DefaultExecutableBuildOptions() const;
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ XLA_TEST_F(TestUtilsTest, Token) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(
|
auto module = ParseAndReturnVerifiedModule(
|
||||||
R"(HloModule index_space_module
|
R"(HloModule index_space_module
|
||||||
|
|
||||||
ENTRY IndexSpace {
|
ENTRY IndexSpace {
|
||||||
@ -103,7 +103,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(
|
auto module = ParseAndReturnVerifiedModule(
|
||||||
R"(HloModule index_space_module
|
R"(HloModule index_space_module
|
||||||
|
|
||||||
ENTRY IndexSpace {
|
ENTRY IndexSpace {
|
||||||
@ -135,7 +135,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
|||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
|
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
|
||||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule sort.148.1589
|
HloModule sort.148.1589
|
||||||
|
|
||||||
compare {
|
compare {
|
||||||
@ -166,7 +166,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
|
|||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
|
XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
|
||||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule sort.148.1589
|
HloModule sort.148.1589
|
||||||
|
|
||||||
compare {
|
compare {
|
||||||
@ -197,7 +197,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
|
|||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
|
XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
|
||||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule sort, is_scheduled=true
|
HloModule sort, is_scheduled=true
|
||||||
|
|
||||||
compare {
|
compare {
|
||||||
@ -227,7 +227,7 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) {
|
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule Test
|
HloModule Test
|
||||||
|
|
||||||
ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
|
ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
|
||||||
@ -255,7 +255,7 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) {
|
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule Test
|
HloModule Test
|
||||||
|
|
||||||
ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
|
ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
|
||||||
@ -289,7 +289,7 @@ ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
|
|||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) {
|
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule Test
|
HloModule Test
|
||||||
|
|
||||||
scatter_update (lhs: f32[], rhs: f32[]) -> f32[] {
|
scatter_update (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||||
|
46
tensorflow/compiler/xla/tests/verified_hlo_module.cc
Normal file
46
tensorflow/compiler/xla/tests/verified_hlo_module.cc
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
Status VerifiedHloModule::Verify() {
|
||||||
|
if (computation_count() == 0) {
|
||||||
|
// The computation was never built. Nothing to verify.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return verifier_.Run(this).status();
|
||||||
|
}
|
||||||
|
|
||||||
|
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||||
|
Status status = Verify();
|
||||||
|
if (!status.ok()) {
|
||||||
|
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||||
|
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
||||||
|
<< ": " << status;
|
||||||
|
LOG(ERROR) << "Contents of bad module:";
|
||||||
|
XLA_LOG_LINES(tensorflow::ERROR, ToString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
59
tensorflow/compiler/xla/tests/verified_hlo_module.h
Normal file
59
tensorflow/compiler/xla/tests/verified_hlo_module.h
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// An HLO module derived class which verifies itself on destruction. This class
|
||||||
|
// is intended to be used in unit tests. Any verification errors are raised via
|
||||||
|
// ADD_FAILURE.
|
||||||
|
class VerifiedHloModule : public HloModule {
|
||||||
|
public:
|
||||||
|
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
||||||
|
bool verifier_layout_sensitive,
|
||||||
|
bool allow_mixed_precision_in_hlo_verifier,
|
||||||
|
std::function<int64(const Shape&)> shape_size_function)
|
||||||
|
: HloModule(name, config),
|
||||||
|
verifier_(
|
||||||
|
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
|
||||||
|
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
|
||||||
|
|
||||||
|
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||||
|
|
||||||
|
// Verifies the module using HloVerifier and returns the status.
|
||||||
|
Status Verify();
|
||||||
|
|
||||||
|
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||||
|
// included in the failure message.
|
||||||
|
void VerifyOrAddFailure(const string& message);
|
||||||
|
|
||||||
|
private:
|
||||||
|
HloVerifier verifier_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
Loading…
Reference in New Issue
Block a user