From 7eb5039543dab1d1186e5cc2ec875326d624f0d5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 27 Jul 2020 09:04:58 -0700 Subject: [PATCH] [XLA] Refactor run_hlo_module to avoid depending on TensorFlow test utilities. PiperOrigin-RevId: 323371427 Change-Id: If0f7693752975722a511fa97c43cd3d60e22cdff --- tensorflow/compiler/xla/tools/BUILD | 6 +- .../compiler/xla/tools/run_hlo_module.cc | 62 ++++++++++++++++--- .../compiler/xla/tools/run_hlo_module.h | 3 +- .../compiler/xla/tools/run_hlo_module_main.cc | 6 +- 4 files changed, 60 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index b113b498e22..fc1ca7d3105 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -308,17 +308,18 @@ cc_library( ":prepare_reference_module", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:status", - "//tensorflow/core/platform:test", "//tensorflow/stream_executor:platform", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -339,6 +340,7 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:status", "//tensorflow/core/platform:test", ] + if_cuda_or_rocm([ diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc index 39b545af393..be9b23efb12 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc @@ -27,24 +27,66 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tools/hlo_module_loader.h" #include "tensorflow/compiler/xla/tools/prepare_reference_module.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" - -namespace se = ::stream_executor; namespace xla { namespace { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + // Bazel likes for tests to write "debugging outputs" like these to + // TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test + // results, especially when they're run on remote machines. + auto* env = tensorflow::Env::Default(); + string binary_filename; + string text_filename; + string outdir; + if (tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) { + string filename = tensorflow::io::JoinPath( + outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name)); + binary_filename = absl::StrCat(filename, ".pb"); + text_filename = absl::StrCat(filename, ".txt"); + } else { + binary_filename = + tensorflow::io::GetTempFilename(absl::StrCat(name, ".pb")); + text_filename = tensorflow::io::GetTempFilename(absl::StrCat(name, ".txt")); + } + + TF_CHECK_OK( + tensorflow::WriteBinaryProto(env, binary_filename, literal.ToProto())); + TF_CHECK_OK( + tensorflow::WriteStringToFile(env, text_filename, literal.ToString())); + LOG(ERROR) << "wrote Literal to " << name << " binary: " << binary_filename + << " text: " << text_filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. +void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches, + const ShapeIndex& /*shape_index*/) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + Literal ExecuteOnPlatform(std::unique_ptr module, absl::Span args, se::Platform* platform, bool run_hlo_passes) { @@ -69,7 +111,7 @@ Literal ExecuteOnPlatform(std::unique_ptr module, } } // namespace -::testing::AssertionResult RunAndCompare( +Status RunAndCompare( const std::string& hlo_filename, const std::string& test_platform_name, const std::string& reference_platform_name, std::minstd_rand0* engine, const RunHloModuleOptions& options, @@ -122,7 +164,7 @@ Literal ExecuteOnPlatform(std::unique_ptr module, if (reference_module == nullptr) { std::cerr << "Skipping reference platform\n"; - return ::testing::AssertionSuccess(); + return Status::OK(); } Literal reference_result = @@ -136,10 +178,10 @@ Literal ExecuteOnPlatform(std::unique_ptr module, } ErrorSpec error_spec(static_cast(options.abs_error_bound), static_cast(options.rel_error_bound)); - return LiteralTestUtil::Near(/*expected=*/reference_result, - /*actual=*/test_result, - /*error_spec=*/error_spec, - /*detailed_message=*/true); + return literal_comparison::Near(/*expected=*/reference_result, + /*actual=*/test_result, + /*error=*/error_spec, + /*detailed_message=*/true, &OnMiscompare); } } // namespace xla diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.h b/tensorflow/compiler/xla/tools/run_hlo_module.h index 932cc22f4dd..57f81cc7c94 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.h +++ b/tensorflow/compiler/xla/tools/run_hlo_module.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/stream_executor/platform.h" namespace xla { @@ -63,7 +62,7 @@ struct RunHloModuleOptions { // the results. 'reference_module_modifier_hook' can be used to transform the // HloModule before it is run on the reference platform. This may be necessary // to match the numerics of the test platform. -::testing::AssertionResult RunAndCompare( +Status RunAndCompare( const std::string& hlo_filename, const std::string& test_platform_name, const std::string& reference_platform_name, std::minstd_rand0* engine, const RunHloModuleOptions& options, diff --git a/tensorflow/compiler/xla/tools/run_hlo_module_main.cc b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc index 39d7826e162..9d153491862 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module_main.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc @@ -156,7 +156,7 @@ int main(int argc, char** argv) { if (iteration_count != 1) { std::cerr << "\n=== Iteration " << i << "\n"; } - ::testing::AssertionResult matched = + xla::Status matched = xla::RunAndCompare(hlo_filename, test_platform_name, reference_platform_name, &engine, opts); @@ -164,13 +164,13 @@ int main(int argc, char** argv) { // used. Without a reference, the test just verifies that nothing blew up // when running the module. if (!reference_platform_name.empty()) { - if (matched) { + if (matched.ok()) { // Success. std::cerr << "\n** Results on " << test_platform_name << " and " << reference_platform_name << " are close enough. **\n"; } else { failure_count++; - std::cerr << matched.message() << "\n"; + std::cerr << matched << "\n"; } } }