[XLA] Refactor run_hlo_module to avoid depending on TensorFlow test utilities.

PiperOrigin-RevId: 323371427
Change-Id: If0f7693752975722a511fa97c43cd3d60e22cdff
This commit is contained in:
Peter Hawkins 2020-07-27 09:04:58 -07:00 committed by TensorFlower Gardener
parent e29e1f4e57
commit 7eb5039543
4 changed files with 60 additions and 17 deletions

View File

@ -308,17 +308,18 @@ cc_library(
":prepare_reference_module",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:error_spec",
"//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:test",
"//tensorflow/stream_executor:platform",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -339,6 +340,7 @@ tf_cc_binary(
"//tensorflow/core:framework_internal",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:test",
] + if_cuda_or_rocm([

View File

@ -27,24 +27,66 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/compiler/xla/tools/prepare_reference_module.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
namespace se = ::stream_executor;
namespace xla {
namespace {
// Writes the given literal to a file in the test temporary directory.
void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
// Bazel likes for tests to write "debugging outputs" like these to
// TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test
// results, especially when they're run on remote machines.
auto* env = tensorflow::Env::Default();
string binary_filename;
string text_filename;
string outdir;
if (tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
string filename = tensorflow::io::JoinPath(
outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name));
binary_filename = absl::StrCat(filename, ".pb");
text_filename = absl::StrCat(filename, ".txt");
} else {
binary_filename =
tensorflow::io::GetTempFilename(absl::StrCat(name, ".pb"));
text_filename = tensorflow::io::GetTempFilename(absl::StrCat(name, ".txt"));
}
TF_CHECK_OK(
tensorflow::WriteBinaryProto(env, binary_filename, literal.ToProto()));
TF_CHECK_OK(
tensorflow::WriteStringToFile(env, text_filename, literal.ToString()));
LOG(ERROR) << "wrote Literal to " << name << " binary: " << binary_filename
<< " text: " << text_filename;
}
// Callback helper that dumps literals to temporary files in the event of a
// miscomparison.
void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
const LiteralSlice& mismatches,
const ShapeIndex& /*shape_index*/) {
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
<< literal_comparison::ToStringTruncated(expected);
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "
<< literal_comparison::ToStringTruncated(actual);
LOG(INFO) << "Dumping literals to temp files...";
WriteLiteralToTempFile(expected, "expected");
WriteLiteralToTempFile(actual, "actual");
WriteLiteralToTempFile(mismatches, "mismatches");
}
Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
absl::Span<const Literal> args,
se::Platform* platform, bool run_hlo_passes) {
@ -69,7 +111,7 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
}
} // namespace
::testing::AssertionResult RunAndCompare(
Status RunAndCompare(
const std::string& hlo_filename, const std::string& test_platform_name,
const std::string& reference_platform_name, std::minstd_rand0* engine,
const RunHloModuleOptions& options,
@ -122,7 +164,7 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
if (reference_module == nullptr) {
std::cerr << "Skipping reference platform\n";
return ::testing::AssertionSuccess();
return Status::OK();
}
Literal reference_result =
@ -136,10 +178,10 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
}
ErrorSpec error_spec(static_cast<float>(options.abs_error_bound),
static_cast<float>(options.rel_error_bound));
return LiteralTestUtil::Near(/*expected=*/reference_result,
/*actual=*/test_result,
/*error_spec=*/error_spec,
/*detailed_message=*/true);
return literal_comparison::Near(/*expected=*/reference_result,
/*actual=*/test_result,
/*error=*/error_spec,
/*detailed_message=*/true, &OnMiscompare);
}
} // namespace xla

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -63,7 +62,7 @@ struct RunHloModuleOptions {
// the results. 'reference_module_modifier_hook' can be used to transform the
// HloModule before it is run on the reference platform. This may be necessary
// to match the numerics of the test platform.
::testing::AssertionResult RunAndCompare(
Status RunAndCompare(
const std::string& hlo_filename, const std::string& test_platform_name,
const std::string& reference_platform_name, std::minstd_rand0* engine,
const RunHloModuleOptions& options,

View File

@ -156,7 +156,7 @@ int main(int argc, char** argv) {
if (iteration_count != 1) {
std::cerr << "\n=== Iteration " << i << "\n";
}
::testing::AssertionResult matched =
xla::Status matched =
xla::RunAndCompare(hlo_filename, test_platform_name,
reference_platform_name, &engine, opts);
@ -164,13 +164,13 @@ int main(int argc, char** argv) {
// used. Without a reference, the test just verifies that nothing blew up
// when running the module.
if (!reference_platform_name.empty()) {
if (matched) {
if (matched.ok()) {
// Success.
std::cerr << "\n** Results on " << test_platform_name << " and "
<< reference_platform_name << " are close enough. **\n";
} else {
failure_count++;
std::cerr << matched.message() << "\n";
std::cerr << matched << "\n";
}
}
}