Extract VerifiedHloModule to its own file.
Use it in local_client_test_base to also allow to create verified HloModules. Change a few more tests to use verified HloModules. PiperOrigin-RevId: 275026854 Change-Id: I82cdfeba08b1037d22171204b0281d29f9bc11b1
This commit is contained in:
parent
77b175ecad
commit
27715f81b7
@ -107,6 +107,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verified_hlo_module",
|
||||
testonly = True,
|
||||
srcs = ["verified_hlo_module.cc"],
|
||||
hdrs = ["verified_hlo_module.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_test_base",
|
||||
testonly = True,
|
||||
@ -115,6 +134,7 @@ cc_library(
|
||||
deps = [
|
||||
":literal_test_util",
|
||||
":test_utils",
|
||||
":verified_hlo_module",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -249,6 +269,8 @@ cc_library(
|
||||
srcs = ["local_client_test_base.cc"],
|
||||
hdrs = ["local_client_test_base.h"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":verified_hlo_module",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -259,17 +281,19 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:local_service",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
@ -1851,7 +1875,11 @@ xla_test(
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -1862,10 +1890,6 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:test",
|
||||
@ -2080,15 +2104,15 @@ xla_test(
|
||||
name = "broadcast_test",
|
||||
srcs = ["broadcast_test.cc"],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
@ -2433,10 +2457,10 @@ xla_test(
|
||||
":local_client_test_base",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/base",
|
||||
|
@ -44,7 +44,7 @@ XLA_TEST_F(TrivialAllReduceTest, OneOperand) {
|
||||
ROOT crs = f32[3] all-reduce(p), to_apply=add
|
||||
})";
|
||||
auto module =
|
||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
||||
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||
.ValueOrDie();
|
||||
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||
EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
|
||||
@ -66,7 +66,7 @@ XLA_TEST_F(TrivialAllReduceTest, MultipleOperands) {
|
||||
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
||||
})";
|
||||
auto module =
|
||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
||||
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||
.ValueOrDie();
|
||||
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
||||
@ -93,7 +93,7 @@ XLA_TEST_F(TrivialAllReduceTest, ConstantOperand) {
|
||||
ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add
|
||||
})";
|
||||
auto module =
|
||||
ParseAndReturnUnverifiedModule(module_str, GetModuleConfigForTest())
|
||||
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())
|
||||
.ValueOrDie();
|
||||
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
|
||||
|
@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
|
||||
ShapeUtil::MakeShape(F32, {}), input, {}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
|
||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
|
||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
|
||||
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
|
||||
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
|
||||
ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
|
||||
ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
|
||||
ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
|
||||
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
LOG(INFO) << hlo_module->ToString();
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
|
||||
ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
|
||||
ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));
|
||||
|
||||
// Create HLO module, compile, and execute.
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
hlo_module->AddEntryComputation(builder.Build());
|
||||
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
|
||||
|
||||
|
@ -85,25 +85,6 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
}
|
||||
|
||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
Status status = Verify();
|
||||
if (!status.ok()) {
|
||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
||||
<< ": " << status;
|
||||
LOG(ERROR) << "Contents of bad module:";
|
||||
XLA_LOG_LINES(tensorflow::ERROR, ToString());
|
||||
}
|
||||
}
|
||||
|
||||
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
@ -39,33 +40,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// An HLO module derived class which verifies itself on destruction. This class
|
||||
// is intended to be used in unit tests. Any verification errors are raised via
|
||||
// ADD_FAILURE.
|
||||
class VerifiedHloModule : public HloModule {
|
||||
public:
|
||||
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
||||
bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<int64(const Shape&)> shape_size_function)
|
||||
: HloModule(name, config),
|
||||
verifier_(
|
||||
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
|
||||
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
|
||||
|
||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
|
||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||
// included in the failure message.
|
||||
void VerifyOrAddFailure(const string& message);
|
||||
|
||||
private:
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
// A base class for tests which build and/or run HLO code. The class includes
|
||||
// support for running an HLO module on two platforms and compare the results.
|
||||
// This is a lower level of abstraction than using the client interface and
|
||||
|
@ -16,16 +16,22 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/byte_order.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
@ -205,4 +211,21 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
|
||||
return std::move(ret);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
|
||||
return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||
LocalClientTestBase::ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text, const HloModuleConfig& config) {
|
||||
auto module = absl::make_unique<VerifiedHloModule>(
|
||||
TestName(), config, /*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
local_client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -20,16 +20,19 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/local_service.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
@ -109,6 +112,12 @@ class LocalClientTestBase : public ::testing::Test {
|
||||
const ExecutableBuildOptions& build_options,
|
||||
const ExecutableRunOptions& run_options);
|
||||
|
||||
// Parses the given string and returns module as a VerifiedHloModule.
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text);
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||
absl::string_view hlo_text, const HloModuleConfig& config);
|
||||
|
||||
// Returns a default set of execute options.
|
||||
ExecutableBuildOptions DefaultExecutableBuildOptions() const;
|
||||
|
||||
|
@ -75,7 +75,7 @@ XLA_TEST_F(TestUtilsTest, Token) {
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
||||
auto module = ParseAndReturnUnverifiedModule(
|
||||
auto module = ParseAndReturnVerifiedModule(
|
||||
R"(HloModule index_space_module
|
||||
|
||||
ENTRY IndexSpace {
|
||||
@ -103,7 +103,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
||||
auto module = ParseAndReturnUnverifiedModule(
|
||||
auto module = ParseAndReturnVerifiedModule(
|
||||
R"(HloModule index_space_module
|
||||
|
||||
ENTRY IndexSpace {
|
||||
@ -135,7 +135,7 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
|
||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule sort.148.1589
|
||||
|
||||
compare {
|
||||
@ -166,7 +166,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
|
||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule sort.148.1589
|
||||
|
||||
compare {
|
||||
@ -197,7 +197,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
|
||||
// Inputs which are sort keys in key/value sorts should have no duplicates.
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule sort, is_scheduled=true
|
||||
|
||||
compare {
|
||||
@ -227,7 +227,7 @@ ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule Test
|
||||
|
||||
ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
|
||||
@ -255,7 +255,7 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] {
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule Test
|
||||
|
||||
ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
|
||||
@ -289,7 +289,7 @@ ENTRY %module(parameter.0: f32[200,100,300], parameter.1: s32[10,2]) ->
|
||||
}
|
||||
|
||||
XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule Test
|
||||
|
||||
scatter_update (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
|
46
tensorflow/compiler/xla/tests/verified_hlo_module.cc
Normal file
46
tensorflow/compiler/xla/tests/verified_hlo_module.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
}
|
||||
|
||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
Status status = Verify();
|
||||
if (!status.ok()) {
|
||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||
<< (message.empty() ? "" : absl::StrCat(" (", message, ")"))
|
||||
<< ": " << status;
|
||||
LOG(ERROR) << "Contents of bad module:";
|
||||
XLA_LOG_LINES(tensorflow::ERROR, ToString());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
59
tensorflow/compiler/xla/tests/verified_hlo_module.h
Normal file
59
tensorflow/compiler/xla/tests/verified_hlo_module.h
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// An HLO module derived class which verifies itself on destruction. This class
|
||||
// is intended to be used in unit tests. Any verification errors are raised via
|
||||
// ADD_FAILURE.
|
||||
class VerifiedHloModule : public HloModule {
|
||||
public:
|
||||
VerifiedHloModule(const string& name, const HloModuleConfig& config,
|
||||
bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<int64(const Shape&)> shape_size_function)
|
||||
: HloModule(name, config),
|
||||
verifier_(
|
||||
verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
|
||||
/*instruction_can_change_layout_func=*/{}, shape_size_function) {}
|
||||
|
||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
|
||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||
// included in the failure message.
|
||||
void VerifyOrAddFailure(const string& message);
|
||||
|
||||
private:
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
Loading…
Reference in New Issue
Block a user