Open source run_hlo_module.

PiperOrigin-RevId: 284166663
Change-Id: I395f6a0a8efeb60784bdcca4e5227f0ef470f6f7
This commit is contained in:
Adrian Kuegel 2019-12-06 05:30:25 -08:00 committed by TensorFlower Gardener
parent ba509fda80
commit 3d21ad0e16
6 changed files with 464 additions and 3 deletions

View File

@ -292,6 +292,53 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor/lib",
],
)
# Core library for reading an HLO module from a file, executing it on a test
# platform and (optionally) comparing the result against a reference platform.
# testonly: True — it depends on XLA test-only utilities
# (literal_test_util, test_utils).
cc_library(
    name = "run_hlo_module_lib",
    testonly = True,
    srcs = ["run_hlo_module.cc"],
    hdrs = ["run_hlo_module.h"],
    deps = [
        ":hlo_module_loader",
        ":prepare_reference_module",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:error_spec",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client/lib:testing",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_runner",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:test",
        "//tensorflow/stream_executor:platform",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)
# Command-line driver around :run_hlo_module_lib. Links the CPU, GPU and
# interpreter plugins so the module can be executed on any of those platforms.
tf_cc_binary(
    name = "run_hlo_module",
    testonly = True,
    srcs = ["run_hlo_module_main.cc"],
    deps = [
        ":run_hlo_module_lib",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla/service:cpu_plugin",
        "//tensorflow/compiler/xla/service:gpu_plugin",
        "//tensorflow/compiler/xla/service:interpreter_plugin",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:platform_port",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:test",
        "@com_google_absl//absl/strings",
    ],
)

View File

@ -26,13 +26,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
StatusOr<std::unique_ptr<HloModule>> PrepareReferenceModule(
const HloModule& test_module,
const ::stream_executor::Platform::Id& test_platform_id,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
const std::function<Status(HloModule*)>& module_modifier_hook) {
const std::function<Status(const HloModule&,
const ::stream_executor::Platform::Id&,
HloModule*)>& module_modifier_hook) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
// The combination of fast math and optimizations leads to unsound code
// transformations (see third_party/tensorflow/compiler/xla/xla.proto for
@ -47,7 +51,8 @@ StatusOr<std::unique_ptr<HloModule>> PrepareReferenceModule(
std::unique_ptr<HloModule> reference_module =
test_module.Clone(reference_config, "reference");
if (module_modifier_hook) {
TF_RETURN_IF_ERROR(module_modifier_hook(reference_module.get()));
TF_RETURN_IF_ERROR(module_modifier_hook(test_module, test_platform_id,
reference_module.get()));
} else {
TF_RETURN_IF_ERROR(Despecializer().Run(reference_module.get()).status());
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -33,8 +34,11 @@ namespace xla {
// platforms.
StatusOr<std::unique_ptr<HloModule>> PrepareReferenceModule(
const HloModule& test_module,
const ::stream_executor::Platform::Id& test_platform_id,
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
const std::function<Status(HloModule*)>& module_modifier_hook = {});
const std::function<Status(const HloModule&,
const ::stream_executor::Platform::Id&,
HloModule*)>& module_modifier_hook = {});
} // namespace xla

View File

@ -0,0 +1,145 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/run_hlo_module.h"
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/compiler/xla/tools/prepare_reference_module.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
namespace se = ::stream_executor;
namespace xla {
namespace {
// Verifies, compiles (optionally running HLO passes) and executes `module`
// with `args` on `platform`, reporting wall-clock time on stderr. QFATALs
// when verification or execution fails.
Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
                          absl::Span<const Literal> args,
                          se::Platform* platform, bool run_hlo_passes) {
  HloRunner hlo_runner(platform);

  // Sanity-check the module before spending time compiling it.
  TF_QCHECK_OK(VerifyHloModule(module.get(), /*layout_sensitive=*/false,
                               /*allow_mixed_precision=*/true))
      << " (on " << platform->Name() << ")";

  std::cerr << "Running HLO module on platform " << platform->Name() << "...\n";
  XLA_VLOG_LINES(1, module->ToString());

  const auto t_begin = std::chrono::high_resolution_clock::now();
  auto execution = hlo_runner.Execute(std::move(module), args, run_hlo_passes);
  const auto t_end = std::chrono::high_resolution_clock::now();
  const std::chrono::duration<double> elapsed = t_end - t_begin;
  std::cerr << "... compiled and ran in " << elapsed.count() << "s.\n";

  TF_QCHECK_OK(execution.status())
      << "Failed to execute on " << platform->Name() << "\n";
  return execution.ConsumeValueOrDie();
}
} // namespace
// Loads the module from `hlo_filename`, runs it on the test platform and,
// when `reference_platform_name` is non-empty, also on the reference platform,
// then compares the two results within the error bounds from `options`.
//
// `engine` seeds the fake-argument generation, so runs starting from the same
// engine state are reproducible. `reference_module_modifier_hook` (optional)
// transforms the module before the reference run.
//
// Returns AssertionSuccess when no reference platform was requested, or the
// result of the near-comparison otherwise.
::testing::AssertionResult RunAndCompare(
    const std::string& hlo_filename, const std::string& test_platform_name,
    const std::string& reference_platform_name, std::minstd_rand0* engine,
    const RunHloModuleOptions& options,
    std::function<Status(const HloModule&,
                         const ::stream_executor::Platform::Id&, HloModule*)>
        reference_module_modifier_hook) {
  se::Platform* test_platform =
      xla::PlatformUtil::GetPlatform(test_platform_name).ValueOrDie();
  // An empty reference platform name means "skip the reference run".
  se::Platform* reference_platform =
      reference_platform_name.empty()
          ? nullptr
          : xla::PlatformUtil::GetPlatform(reference_platform_name)
                .ValueOrDie();
  // Pin the seed so the test and reference modules get identical configs.
  auto config_modifier = [](HloModuleConfig* config) { config->set_seed(42); };

  std::unique_ptr<HloModule> test_module =
      LoadModuleFromFile(hlo_filename, hlo_module_loader_details::Config(),
                         options.input_format, config_modifier)
          .ValueOrDie();
  // NOTE(review): this proto appears unused in this function — confirm whether
  // it can be dropped or is retained for a planned follow-up.
  const HloModuleProto test_module_proto = test_module->ToProto();

  std::vector<Literal> args = MakeFakeArguments(test_module.get(), engine,
                                                options.use_large_float_range)
                                  .ConsumeValueOrDie();
  if (options.print_literals) {
    // size_t index avoids a signed/unsigned comparison with args.size().
    for (size_t i = 0; i < args.size(); ++i) {
      std::cout << "\n** Argument " << i << " **\n"
                << args[i].ToString() << "\n";
    }
  }

  std::unique_ptr<HloModule> reference_module;
  if (reference_platform != nullptr) {
    // PrepareReferenceModule needs to know the *test* platform, in order to
    // properly match the test platform's numerics.
    reference_module =
        PrepareReferenceModule(*test_module, test_platform->id(),
                               config_modifier, reference_module_modifier_hook)
            .ConsumeValueOrDie();
  }

  Literal test_result = ExecuteOnPlatform(
      std::move(test_module), args, test_platform, options.run_test_hlo_passes);
  if (options.print_literals) {
    std::cout << "\n** Result on test platform " << test_platform->Name()
              << " **\n"
              << test_result.ToString() << "\n";
  }

  if (reference_module == nullptr) {
    std::cerr << "Skipping reference platform\n";
    return ::testing::AssertionSuccess();
  }

  Literal reference_result =
      ExecuteOnPlatform(std::move(reference_module), args, reference_platform,
                        options.run_reference_hlo_passes);
  if (options.print_literals) {
    std::cout << "\n** Result on reference platform "
              << reference_platform->Name() << " **\n"
              << reference_result.ToString() << "\n";
  }

  ErrorSpec error_spec(static_cast<float>(options.abs_error_bound),
                       static_cast<float>(options.rel_error_bound));
  return LiteralTestUtil::Near(/*expected=*/reference_result,
                               /*actual=*/test_result,
                               /*error_spec=*/error_spec,
                               /*detailed_message=*/true);
}
} // namespace xla

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_
#include <functional>
#include <random>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
// Command-line options to this tool. See main() in run_hlo_module_main.cc for
// descriptions of these fields.
struct RunHloModuleOptions {
  // In-class member initializers keep each default next to its field and make
  // the hand-written constructor unnecessary (Rule of Zero).
  std::string platform;                       // Test platform; must be set.
  std::string reference_platform{"default"};  // "" disables the reference run.
  bool print_literals{false};
  bool run_test_hlo_passes{true};
  bool run_reference_hlo_passes{true};
  bool use_large_float_range{true};
  // TODO(b/68721786): These tolerances are set to match the values in the
  // isolation test. The goal is to lower these to 0.001.
  float abs_error_bound{0.1f};
  float rel_error_bound{0.1f};
  std::string input_format{"hlo"};
  std::string input_module;
  int iterations{1};
};
// Reads a HloModule from 'hlo_filename', runs it on the platform with the name
// 'test_platform_name', and if 'reference_platform_name' is non-empty, it also
// runs it on the platform with the name 'reference_platform_name' and compares
// the results. 'reference_module_modifier_hook' can be used to transform the
// HloModule before it is run on the reference platform. This may be necessary
// to match the numerics of the test platform.
// 'engine' seeds the generation of the module's fake arguments; reusing the
// same engine state makes runs reproducible.
// Returns a successful AssertionResult when the results match, or when no
// reference platform was given.
::testing::AssertionResult RunAndCompare(
    const std::string& hlo_filename, const std::string& test_platform_name,
    const std::string& reference_platform_name, std::minstd_rand0* engine,
    const RunHloModuleOptions& options,
    std::function<Status(const HloModule&,
                         const ::stream_executor::Platform::Id&, HloModule*)>
        reference_module_modifier_hook = {});
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_

View File

@ -0,0 +1,184 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A tool for reading a HloModule from a HloProto file and execute the module on
// given platform(s). See kUsage for details.
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/tools/run_hlo_module.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace {
const char* const kUsage = R"(
This tool lets you read a HloModule from a file and execute the module on given
platform.
The file can be one of the followings:
1) a binary or text proto file, the proto should be in xla.HloProto type.
2) a hlo text dump, the string should be in HloModule::ToString() format.
By default, the module is run on a reference platform such as the interpreter
and the reference result is compared against the test result.
You can also pass in debug option flags for the HloModule.
Usage:
bazel run run_hlo_module -- \
--input_format=[hlo|pb|pbtxt] \
--platform=[CPU|CUDA|Interpreter] \
path/to/hlo_module
)";
const char kInterpreterPlatformName[] = "Interpreter";

// Returns the name of the test platform; QFATALs when --platform was not
// provided.
std::string GetTestPlatformName(std::string platform_name) {
  QCHECK(!platform_name.empty()) << "Must pass --platform flag.";
  return platform_name;
}
// Returns the name of the reference platform; the sentinel value "default"
// maps to the interpreter platform.
std::string GetReferencePlatformName(std::string reference_platform) {
  return reference_platform == "default"
             ? std::string(kInterpreterPlatformName)
             : reference_platform;
}
} // namespace
// Entry point: parses flags, then runs the module `--iterations` times,
// comparing against the reference platform when one is configured.
// Returns 0 when every compared iteration matched, -1 otherwise.
int main(int argc, char** argv) {
  xla::RunHloModuleOptions opts;
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag(
          "platform", &opts.platform,
          "The test platform that the HLO module will be executed on "
          "(gpu, cpu, etc)."),
      tensorflow::Flag(
          "reference_platform", &opts.reference_platform,
          "The reference platform that HLO module will be "
          "executed on. The result produced on the reference platform will "
          "be compared against the result produced on the test platform. A "
          "value of 'default' will use the TPU_Interpreter as a reference if "
          "the test platform is a TPU, and 'interpreter' otherwise. If the "
          "flag value is the empty string, then the module will not be run "
          "on a reference platform at all."),
      tensorflow::Flag("print_literals", &opts.print_literals,
                       "Print the input and result literals to stdout."),
      tensorflow::Flag(
          "run_test_hlo_passes", &opts.run_test_hlo_passes,
          "Run HLO pass pipeline for the test platform on the HLO module "
          "before running the module on the test platform. This should be "
          "set to true if the HLO module is unoptimized and set to false if "
          "the HLO module already has been optimized."),
      tensorflow::Flag(
          "run_reference_hlo_passes", &opts.run_reference_hlo_passes,
          "Run HLO pass pipeline for the reference platform on the HLO module "
          "before running the module on the reference platform. "
          "In general, if the given HLO module was optimized for a platform "
          "other "
          "than the reference this is necessary because some HLO passes are "
          "legalization passes which must be run prior to code generation."),
      tensorflow::Flag(
          "use_large_float_range", &opts.use_large_float_range,
          // Fixed help-text typo: "distribtion" -> "distribution".
          "Generate floating point values using a large uniform-log "
          "distribution as opposed to a small uniform distribution."),
      tensorflow::Flag(
          "abs_error_bound", &opts.abs_error_bound,
          "The absolute error bound used when comparing the test and "
          "reference results."),
      tensorflow::Flag(
          "rel_error_bound", &opts.rel_error_bound,
          "The relative error bound used when comparing the test and "
          "reference results."),
      tensorflow::Flag("input_format", &opts.input_format,
                       "The format of the input file. Valid values:\n"
                       " hlo : HLO textual format\n"
                       " pb : xla::HloProto in binary proto format\n"
                       " pbtxt : xla::HloProto in text proto format"),
      tensorflow::Flag(
          "input_module", &opts.input_module,
          // Fixed help-text grammar: "pass a this" -> "pass this".
          "A path to a file containing the HLO module. Can also pass "
          "this as argv[1], but this flag is more explicit."),
      tensorflow::Flag(
          "iterations", &opts.iterations,
          "The number of times to run the module. Each iteration will be run "
          "with different input data.")};
  xla::AppendDebugOptionsFlags(&flag_list);
  // The usage string includes the message at the top of the file, the
  // DebugOptions flags and the flags defined above.
  const std::string kUsageString = absl::StrCat(
      kUsage, "\n\n", tensorflow::Flags::Usage(argv[0], flag_list));
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(kUsageString.c_str(), &argc, &argv);
  if (!parse_ok) {
    LOG(QFATAL) << kUsageString;
  }

  const std::string test_platform_name = GetTestPlatformName(opts.platform);
  const std::string reference_platform_name =
      GetReferencePlatformName(opts.reference_platform);

  // The module path comes either from --input_module or the single
  // positional argument.
  std::string hlo_filename;
  if (!opts.input_module.empty()) {
    hlo_filename = opts.input_module;
  } else {
    QCHECK(argc == 2) << "Must specify a single input file";
    hlo_filename = argv[1];
  }

  // A default-constructed engine makes runs reproducible; each iteration
  // advances its state, so iterations see different deterministic inputs.
  std::minstd_rand0 engine;
  int failure_count = 0;
  const int iteration_count = opts.iterations;
  for (int i = 1; i <= iteration_count; ++i) {
    if (iteration_count != 1) {
      std::cerr << "\n=== Iteration " << i << "\n";
    }
    ::testing::AssertionResult matched =
        xla::RunAndCompare(hlo_filename, test_platform_name,
                           reference_platform_name, &engine, opts);

    // The AssertionResult is only meaningful when the reference is
    // used. Without a reference, the test just verifies that nothing blew up
    // when running the module.
    if (!reference_platform_name.empty()) {
      if (matched) {
        // Success.
        std::cerr << "\n** Results on " << test_platform_name << " and "
                  << reference_platform_name << " are close enough. **\n";
      } else {
        failure_count++;
        std::cerr << matched.message() << "\n";
      }
    }
  }

  if (!reference_platform_name.empty()) {
    std::cerr << failure_count << "/" << iteration_count
              << " runs miscompared.\n";
  }

  return failure_count == 0 ? 0 : -1;
}