[XLA] Refactor run_hlo_module to avoid depending on TensorFlow test utilities.
PiperOrigin-RevId: 323371427 Change-Id: If0f7693752975722a511fa97c43cd3d60e22cdff
This commit is contained in:
parent
e29e1f4e57
commit
7eb5039543
@ -308,17 +308,18 @@ cc_library(
|
||||
":prepare_reference_module",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:error_spec",
|
||||
"//tensorflow/compiler/xla:literal_comparison",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client/lib:testing",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_runner",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:path",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:test",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -339,6 +340,7 @@ tf_cc_binary(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:platform_port",
|
||||
"//tensorflow/core/platform:path",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:test",
|
||||
] + if_cuda_or_rocm([
|
||||
|
@ -27,24 +27,66 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/testing.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/error_spec.h"
|
||||
#include "tensorflow/compiler/xla/literal_comparison.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
|
||||
#include "tensorflow/compiler/xla/tools/prepare_reference_module.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace se = ::stream_executor;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Writes the given literal to a file in the test temporary directory.
|
||||
void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
|
||||
// Bazel likes for tests to write "debugging outputs" like these to
|
||||
// TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test
|
||||
// results, especially when they're run on remote machines.
|
||||
auto* env = tensorflow::Env::Default();
|
||||
string binary_filename;
|
||||
string text_filename;
|
||||
string outdir;
|
||||
if (tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
|
||||
string filename = tensorflow::io::JoinPath(
|
||||
outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name));
|
||||
binary_filename = absl::StrCat(filename, ".pb");
|
||||
text_filename = absl::StrCat(filename, ".txt");
|
||||
} else {
|
||||
binary_filename =
|
||||
tensorflow::io::GetTempFilename(absl::StrCat(name, ".pb"));
|
||||
text_filename = tensorflow::io::GetTempFilename(absl::StrCat(name, ".txt"));
|
||||
}
|
||||
|
||||
TF_CHECK_OK(
|
||||
tensorflow::WriteBinaryProto(env, binary_filename, literal.ToProto()));
|
||||
TF_CHECK_OK(
|
||||
tensorflow::WriteStringToFile(env, text_filename, literal.ToString()));
|
||||
LOG(ERROR) << "wrote Literal to " << name << " binary: " << binary_filename
|
||||
<< " text: " << text_filename;
|
||||
}
|
||||
|
||||
// Callback helper that dumps literals to temporary files in the event of a
|
||||
// miscomparison.
|
||||
void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
const LiteralSlice& mismatches,
|
||||
const ShapeIndex& /*shape_index*/) {
|
||||
LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
|
||||
<< literal_comparison::ToStringTruncated(expected);
|
||||
LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "
|
||||
<< literal_comparison::ToStringTruncated(actual);
|
||||
LOG(INFO) << "Dumping literals to temp files...";
|
||||
WriteLiteralToTempFile(expected, "expected");
|
||||
WriteLiteralToTempFile(actual, "actual");
|
||||
WriteLiteralToTempFile(mismatches, "mismatches");
|
||||
}
|
||||
|
||||
Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal> args,
|
||||
se::Platform* platform, bool run_hlo_passes) {
|
||||
@ -69,7 +111,7 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
::testing::AssertionResult RunAndCompare(
|
||||
Status RunAndCompare(
|
||||
const std::string& hlo_filename, const std::string& test_platform_name,
|
||||
const std::string& reference_platform_name, std::minstd_rand0* engine,
|
||||
const RunHloModuleOptions& options,
|
||||
@ -122,7 +164,7 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
|
||||
|
||||
if (reference_module == nullptr) {
|
||||
std::cerr << "Skipping reference platform\n";
|
||||
return ::testing::AssertionSuccess();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Literal reference_result =
|
||||
@ -136,10 +178,10 @@ Literal ExecuteOnPlatform(std::unique_ptr<HloModule> module,
|
||||
}
|
||||
ErrorSpec error_spec(static_cast<float>(options.abs_error_bound),
|
||||
static_cast<float>(options.rel_error_bound));
|
||||
return LiteralTestUtil::Near(/*expected=*/reference_result,
|
||||
/*actual=*/test_result,
|
||||
/*error_spec=*/error_spec,
|
||||
/*detailed_message=*/true);
|
||||
return literal_comparison::Near(/*expected=*/reference_result,
|
||||
/*actual=*/test_result,
|
||||
/*error=*/error_spec,
|
||||
/*detailed_message=*/true, &OnMiscompare);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace xla {
|
||||
@ -63,7 +62,7 @@ struct RunHloModuleOptions {
|
||||
// the results. 'reference_module_modifier_hook' can be used to transform the
|
||||
// HloModule before it is run on the reference platform. This may be necessary
|
||||
// to match the numerics of the test platform.
|
||||
::testing::AssertionResult RunAndCompare(
|
||||
Status RunAndCompare(
|
||||
const std::string& hlo_filename, const std::string& test_platform_name,
|
||||
const std::string& reference_platform_name, std::minstd_rand0* engine,
|
||||
const RunHloModuleOptions& options,
|
||||
|
@ -156,7 +156,7 @@ int main(int argc, char** argv) {
|
||||
if (iteration_count != 1) {
|
||||
std::cerr << "\n=== Iteration " << i << "\n";
|
||||
}
|
||||
::testing::AssertionResult matched =
|
||||
xla::Status matched =
|
||||
xla::RunAndCompare(hlo_filename, test_platform_name,
|
||||
reference_platform_name, &engine, opts);
|
||||
|
||||
@ -164,13 +164,13 @@ int main(int argc, char** argv) {
|
||||
// used. Without a reference, the test just verifies that nothing blew up
|
||||
// when running the module.
|
||||
if (!reference_platform_name.empty()) {
|
||||
if (matched) {
|
||||
if (matched.ok()) {
|
||||
// Success.
|
||||
std::cerr << "\n** Results on " << test_platform_name << " and "
|
||||
<< reference_platform_name << " are close enough. **\n";
|
||||
} else {
|
||||
failure_count++;
|
||||
std::cerr << matched.message() << "\n";
|
||||
std::cerr << matched << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user