Extract VerifiedHloModule to its own file.

Use it in local_client_test_base to also allow to create verified HloModules.
Change a few more tests to use verified HloModules.

PiperOrigin-RevId: 275026854
Change-Id: I82cdfeba08b1037d22171204b0281d29f9bc11b1
This commit is contained in:
Adrian Kuegel 2019-10-16 07:23:46 -07:00 committed by TensorFlower Gardener
parent 77b175ecad
commit 27715f81b7
10 changed files with 194 additions and 78 deletions

View File

@ -107,6 +107,25 @@ cc_library(
], ],
) )
cc_library(
name = "verified_hlo_module",
testonly = True,
srcs = ["verified_hlo_module.cc"],
hdrs = ["verified_hlo_module.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "hlo_test_base", name = "hlo_test_base",
testonly = True, testonly = True,
@ -115,6 +134,7 @@ cc_library(
deps = [ deps = [
":literal_test_util", ":literal_test_util",
":test_utils", ":test_utils",
":verified_hlo_module",
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
@ -249,6 +269,8 @@ cc_library(
srcs = ["local_client_test_base.cc"], srcs = ["local_client_test_base.cc"],
hdrs = ["local_client_test_base.h"], hdrs = ["local_client_test_base.h"],
deps = [ deps = [
":client_library_test_base",
":verified_hlo_module",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
@ -259,17 +281,19 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"//third_party/eigen3", "//third_party/eigen3",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
@ -1851,7 +1875,11 @@ xla_test(
"interpreter", "interpreter",
], ],
deps = [ deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header", ":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
@ -1862,10 +1890,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test", "//tensorflow/core:test",
@ -2080,15 +2104,15 @@ xla_test(
name = "broadcast_test", name = "broadcast_test",
srcs = ["broadcast_test.cc"], srcs = ["broadcast_test.cc"],
deps = [ deps = [
":hlo_test_base",
":literal_test_util",
":test_macros_header", ":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test", "//tensorflow/core:test",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
@ -2433,10 +2457,10 @@ xla_test(
":local_client_test_base", ":local_client_test_base",
":test_macros_header", ":test_macros_header",
":test_utils", ":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",

View File

@ -44,7 +44,7 @@ XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
ROOT crs = f32[3] all-reduce(p), to_apply=add ROOT crs = f32[3] all-reduce(p), to_apply=add
})"; })";
auto module = auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie(); .ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3}); auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
@ -66,7 +66,7 @@ XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
})"; })";
auto module = auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie(); .ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3}); auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20}); auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
@ -93,7 +93,7 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
})"; })";
auto module = auto module =
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest()) ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
.ValueOrDie(); .ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3}); auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20}); auto literal1 = LiteralUtil::CreateR1<float>({10, 20});

View File

@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
ShapeUtil::MakeShape(F32, {}), input, {})); ShapeUtil::MakeShape(F32, {}), input, {}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {})); ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
LOG(INFO) << hlo_module->ToString(); LOG(INFO) << hlo_module->ToString();
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
// Create HLO module, compile, and execute. // Create HLO module, compile, and execute.
auto hlo_module = CreateNewUnverifiedModule(); auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build()); hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto result = ExecuteAndTransfer(std::move(hlo_module), {});

View File

@ -85,25 +85,6 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace } // namespace
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
LOG(ERROR) << "Contents of bad module:";
XLA_LOG_LINES(tensorflow::ERROR, ToString());
}
}
HloTestBase::HloTestBase(bool verifier_layout_sensitive, HloTestBase::HloTestBase(bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier, bool allow_mixed_precision_in_hlo_verifier,
std::function<bool(const HloInstruction*)> std::function<bool(const HloInstruction*)>

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -39,33 +40,6 @@ limitations under the License.
namespace xla { namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<int64(const Shape&)> shape_size_function)
: HloModule(name, config),
verifier_(
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
// A base class for tests which build and/or run HLO code. The class includes // A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and compare the results. // support for running an HLO module on two platforms and compare the results.
// This is a lower level of abstraction than using the client interface and // This is a lower level of abstraction than using the client interface and

View File

@ -16,16 +16,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include <memory>
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
@ -205,4 +211,21 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
return std::move(ret); return std::move(ret);
} }
StatusOr<std::unique_ptr<VerifiedHloModule>>
LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
}
StatusOr<std::unique_ptr<VerifiedHloModule>>
LocalClientTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) {
auto module = absl::make_unique<VerifiedHloModule>(
TestName(), config, /*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
local_client_->backend().compiler()->ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
return std::move(module);
}
} // namespace xla } // namespace xla

View File

@ -20,16 +20,19 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -109,6 +112,12 @@ class LocalClientTestBase : public ::testing::Test {
const ExecutableBuildOptions& build_options, const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options); const ExecutableRunOptions& run_options);
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text);
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config);
// Returns a default set of execute options. // Returns a default set of execute options.
ExecutableBuildOptions DefaultExecutableBuildOptions() const; ExecutableBuildOptions DefaultExecutableBuildOptions() const;

View File

@ -75,7 +75,7 @@ XLA_TEST_F(TestUtilsTest, Token) {
} }
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
auto module = ParseAndReturnUnverifiedModule( auto module = ParseAndReturnVerifiedModule(
R"(HloModule index_space_module R"(HloModule index_space_module
ENTRY IndexSpace { ENTRY IndexSpace {
@ -103,7 +103,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
} }
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
auto module = ParseAndReturnUnverifiedModule( auto module = ParseAndReturnVerifiedModule(
R"(HloModule index_space_module R"(HloModule index_space_module
ENTRY IndexSpace { ENTRY IndexSpace {
@ -135,7 +135,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
// Inputs which are sort keys in key/value sorts should have no duplicates. // Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort.148.1589 HloModule sort.148.1589
compare { compare {
@ -166,7 +166,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
// Inputs which are sort keys in key/value sorts should have no duplicates. // Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort.148.1589 HloModule sort.148.1589
compare { compare {
@ -197,7 +197,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
// Inputs which are sort keys in key/value sorts should have no duplicates. // Inputs which are sort keys in key/value sorts should have no duplicates.
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule sort, is_scheduled=true HloModule sort, is_scheduled=true
compare { compare {
@ -227,7 +227,7 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14
} }
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) {
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test HloModule Test
ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
@ -255,7 +255,7 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
} }
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) {
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test HloModule Test
ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) -> ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
@ -289,7 +289,7 @@ ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
} }
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) {
auto module = ParseAndReturnUnverifiedModule(R"( auto module = ParseAndReturnVerifiedModule(R"(
HloModule Test HloModule Test
scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { scatter_update (lhs: f32[], rhs: f32[]) -> f32[] {

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
<< ": " << status;
LOG(ERROR) << "Contents of bad module:";
XLA_LOG_LINES(tensorflow::ERROR, ToString());
}
}
} // namespace xla

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
#include <functional>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
public:
VerifiedHloModule(const string& name, const HloModuleConfig& config,
bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier,
std::function<int64(const Shape&)> shape_size_function)
: HloModule(name, config),
verifier_(
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
private:
HloVerifier verifier_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_