[XLA] Make hlo_test_base support running and comparing an hlo module.

Also,
- Add examples that show how to use hlo_test_base to create text/file-based test cases.
- Change the behavior of GetDefaultPlatform: when only one platform is present, it returns that one; when exactly two platforms are present and one of them is the interpreter, it returns the other one. This is needed because some tests include both hlo_test_base and client_library_test_base, and hlo_test_base is now linked with the interpreter by default, which made client_library_test_base fail to get the default platform.
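
Concretely, the new GetDefaultPlatform selection rule behaves as sketched below (the platform sets are illustrative):

// GetDefaultPlatform() resolution, by example:
//   {host}                    -> host         (sole platform)
//   {interpreter}             -> interpreter  (sole platform)
//   {interpreter, cuda}       -> cuda         (skip the interpreter)
//   {host, cuda}              -> error: must specify platform
//   {interpreter, host, cuda} -> error: must specify platform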

PiperOrigin-RevId: 178309022
A. Unique TensorFlower 2017-12-07 15:53:53 -08:00 committed by TensorFlower Gardener
parent 7b0458a789
commit 029109b4e1
12 changed files with 697 additions and 44 deletions


@@ -39,6 +39,14 @@ namespace se = ::perftools::gputools;
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
return tools::Parse(hlo_string, config);
}
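
As a usage sketch (the helper name and module text below are hypothetical, not part of this change), a caller can parse a small module directly from textual IR:

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
// Hypothetical helper: parses a two-input add module with default options.
StatusOr<std::unique_ptr<HloModule>> ParseAddModule() {
  const char* const kHloText = R"(
HloModule add_module:

ENTRY %add.v1 (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2]{0} parameter(0)
  %y = f32[2]{0} parameter(1)
  ROOT %add = f32[2]{0} add(f32[2]{0} %x, f32[2]{0} %y)
}
)";
  DebugOptions debug_options;  // default options are fine for parsing
  return HloRunner::CreateModuleFromString(kHloText, debug_options);
}
}  // namespace xla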
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
const DebugOptions& debug_options) {


@@ -35,7 +35,8 @@ namespace xla {
// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
// explicitly built, or loaded from a serialization file (e.g., hlo proto
// file), or parsed from a hlo textual IR string.
class HloRunner {
public:
HloRunner();
@@ -44,6 +45,12 @@ class HloRunner {
~HloRunner();
// Creates an HloModule from the given hlo textual IR string (in
// HloModule::ToString format).
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
const tensorflow::StringPiece hlo_string,
const DebugOptions& debug_options);
// Reads the proto file in xla.HloProto format, then creates and returns the
// HloModule. Tries to parse the file as a binary proto first, then as a text
// proto if that fails.
@@ -65,7 +72,8 @@ class HloRunner {
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
// If run_hlo_passes is true, the module will be executed without Hlo
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Execute(


@@ -33,10 +33,32 @@ namespace se = ::perftools::gputools;
namespace xla {
using tensorflow::str_util::Lowercase;
// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;
// The name of the interpreter platform.
constexpr char kInterpreter[] = "interpreter";
namespace {
string CanonicalPlatformName(const string& name) {
string platform_str = Lowercase(name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
}
// "gpu" and "cuda" mean the same thing.
if (platform_str == "gpu") {
platform_str = "cuda";
}
return platform_str;
}
} // namespace
/* static */ StatusOr<std::vector<se::Platform*>>
PlatformUtil::GetSupportedPlatforms() {
se::MultiPlatformManager::PlatformMap platform_map;
@@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() {
return platforms;
}
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
if (platforms.empty()) {
return NotFound("no platforms found");
@@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() {
}
// Multiple platforms present and we can't pick a reasonable default.
auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); };
string platforms_string = tensorflow::str_util::Join(platforms, ", ", l);
string platforms_string = tensorflow::str_util::Join(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform found: %s",
platforms_string.c_str());
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
const string& platform_name) {
using tensorflow::str_util::Lowercase;
string platform_str = Lowercase(platform_name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
}
// "gpu" and "cuda" mean the same thing.
if (platform_str == "gpu") {
platform_str = "cuda";
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
return platforms[0];
} else if (platforms.size() == 2) {
for (int i = 0; i < 2; i++) {
if (Lowercase(platforms[i]->Name()) == kInterpreter &&
Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
return platforms[1 - i];
}
}
}
// Multiple platforms present and we can't pick a reasonable default.
string platforms_string = tensorflow::str_util::Join(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform (except for the "
"interpreter platform) found: %s",
platforms_string.c_str());
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
const string& platform_name) {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
if (Lowercase(platform->Name()) == platform_str) {
@@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() {
return InvalidArgument("platform %s not found", platform_name.c_str());
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
const string& platform_name) {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
std::vector<se::Platform*> matched;
for (se::Platform* platform : platforms) {
if (Lowercase(platform->Name()) != platform_str) {
matched.push_back(platform);
}
}
if (matched.empty()) {
return InvalidArgument("unable to find platform that is not %s",
platform_name.c_str());
}
if (matched.size() == 1) {
return matched[0];
}
string matched_string = tensorflow::str_util::Join(
matched, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"found multiple platforms %s, but expected one platform except for %s",
matched_string.c_str(), platform_name.c_str());
}
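
A minimal sketch of how these helpers compose in a test (assuming only the interpreter plus one real backend are linked in; the function name is hypothetical):

#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace se = ::perftools::gputools;

// Hypothetical: pick the test platform by excluding the interpreter, which
// hlo_test_base links in as the reference backend. Note that "cpu"/"host"
// and "gpu"/"cuda" are interchangeable names via CanonicalPlatformName.
xla::StatusOr<se::Platform*> PickTestPlatform() {
  return xla::PlatformUtil::GetPlatformExceptFor("interpreter");
}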
// Returns whether the device underlying the given StreamExecutor is supported
// by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) {


@@ -37,16 +37,28 @@ class PlatformUtil {
static StatusOr<std::vector<perftools::gputools::Platform*>>
GetSupportedPlatforms();
// Convenience function which returns the default supported platform. If
// Convenience function which returns the default supported platform for
// tests. If exactly one supported platform is present, then this platform is
// the default platform. If exactly two platforms are present and one of them
// is the interpreter platform, then the other platform is the default
// platform. Otherwise returns an error.
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
// Convenience function which returns the sole supported platform. If
// exactly one supported platform is present, then this platform is the
// default platform. Otherwise returns an error.
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
static StatusOr<perftools::gputools::Platform*> GetSolePlatform();
// Returns the platform according to the given name. Returns error if there is
// no such platform.
static StatusOr<perftools::gputools::Platform*> GetPlatform(
const string& platform_name);
// Returns exactly one platform that does not have the given name. Returns an
// error if there is no such platform, or if there are multiple such platforms.
static StatusOr<perftools::gputools::Platform*> GetPlatformExceptFor(
const string& platform_name);
// Returns a vector of StreamExecutors for the given platform. The vector is
// indexed by device ordinal (device numbering used by StreamExecutor). If an
// element is nullptr, then the device is present but not supported by XLA.


@@ -105,7 +105,9 @@ cc_library(
hdrs = ["hlo_test_base.h"],
deps = [
":literal_test_util",
":test_utils",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -115,6 +117,9 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@@ -1678,6 +1683,45 @@ xla_test(
],
)
# A demo of a textual-IR-based test.
xla_test(
name = "sample_text_test",
srcs = ["sample_text_test.cc"],
# You can leave this empty if you want to test all supported backends.
backends = [
"cpu",
"gpu",
],
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
# A demo of a test that loads an hlo module from a file and compares results on gpu and cpu.
tf_cc_test(
name = "sample_file_test",
srcs = ["sample_file_test.cc"],
data = ["isolated_convolution.hlo"],
tags = ["requires-gpu-sm35"],
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:cpu_plugin", # reference backend
"//tensorflow/compiler/xla/service:gpu_plugin", # test backend
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
# -----------------------------------------------------------------------------
filegroup(


@@ -15,13 +15,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -30,18 +39,72 @@ namespace se = ::perftools::gputools;
namespace xla {
namespace {
using tensorflow::StringPiece;
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::optional;
constexpr char kInterpreter[] = "interpreter";
// Helper functions to get test and reference platforms.
se::Platform* GetReferencePlatform() {
auto result = PlatformUtil::GetPlatform(kInterpreter);
TF_CHECK_OK(result.status()) << "could not get interpreter platform";
return result.ValueOrDie();
}
se::Platform* GetTestPlatform() {
auto result = PlatformUtil::GetDefaultPlatform();
TF_CHECK_OK(result.status()) << "could not get test platform";
return result.ValueOrDie();
}
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
if (lhs.parameters_size() != rhs.parameters_size()) {
return false;
}
for (int i = 0; i < lhs.parameters_size(); i++) {
if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
return false;
}
}
return ShapeUtil::Equal(lhs.result(), rhs.result());
}
ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
ProgramShape program_shape;
const auto* entry = module.entry_computation();
for (const auto* param : entry->parameter_instructions()) {
*program_shape.add_parameters() = param->shape();
*program_shape.add_parameter_names() = param->name();
}
*program_shape.mutable_result() = entry->root_instruction()->shape();
return program_shape;
}
} // namespace
HloTestBase::HloTestBase()
: HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform)
: test_runner_(test_platform), reference_runner_(reference_platform) {}
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
config);
}
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
debug_options.add_xla_disable_hlo_passes("constant_folding");
config.set_debug_options(debug_options);
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
config);
return debug_options;
}
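
Within a test body, the adjusted options are meant to be threaded through module creation; a sketch, assuming an hlo_text string defined by the test:

// Parse with the test debug options so the constant-folding pass stays
// disabled, then run on the test backend and compare with the reference.
auto module = HloRunner::CreateModuleFromString(hlo_text,
                                                GetDebugOptionsForTest())
                  .ConsumeValueOrDie();
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.0001}));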
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
@@ -49,25 +112,168 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape) {
return runner_.Execute(std::move(module), arguments, result_shape);
return test_runner_.Execute(std::move(module), arguments, result_shape);
}
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
return runner_.TransferToDevice(literal).ValueOrDie();
return test_runner_.TransferToDevice(literal).ValueOrDie();
}
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
return runner_.TransferFromDevice(shape, device_base).ValueOrDie();
return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie();
}
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie();
return test_runner_.ExecuteAndTransfer(std::move(module), arguments)
.ValueOrDie();
}
Backend& HloTestBase::backend() { return runner_.backend(); }
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
const HloModule& test_module,
const std::function<void(HloModule*)>& reference_preprocessor) {
std::unique_ptr<HloModule> reference_module = test_module.Clone();
const auto& program_shape = GetProgramShapeWithLayout(test_module);
if (reference_preprocessor != nullptr) {
reference_preprocessor(reference_module.get());
if (!ProgramShapesEqual(program_shape,
GetProgramShapeWithLayout(*reference_module))) {
return InvalidArgument(
"reference preprocessor must not modify the program shape");
}
}
TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
reference_module.get()));
return std::move(reference_module);
}
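
For example (a sketch; `module` is assumed already built and the lambda body is left hypothetical), a test can pass a preprocessor that tailors the reference module, as long as it leaves the program shape alone:

EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.0001},
                          [](HloModule* reference) {
                            // Mutate `reference` for the reference backend
                            // here; parameter and result shapes must stay
                            // identical, or MakeReferenceModule fails with
                            // InvalidArgument.
                          }));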
template <typename LiteralPtr>
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
const optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor) {
static_assert(
std::is_same<Literal*, LiteralPtr>::value ||
std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value,
"The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>.");
TF_RETURN_IF_ERROR(
VerifyHloModule(*test_runner_.backend().platform(), module.get()));
TF_ASSIGN_OR_RETURN(auto reference_module,
MakeReferenceModule(*module, reference_preprocessor));
// Execute on two backends.
TF_ASSIGN_OR_RETURN(
auto test,
test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
TF_ASSIGN_OR_RETURN(auto reference,
reference_runner_.Execute(std::move(reference_module),
arguments, run_hlo_passes));
return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
error);
}
template <typename LiteralPtr>
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
RunAndCompareInternal(std::move(module), arguments, error,
/*run_hlo_passes=*/true, reference_preprocessor);
if (!result.ok()) {
return ::testing::AssertionFailure() << result.status();
}
return result.ValueOrDie();
}
template <typename LiteralPtr>
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
RunAndCompareInternal(std::move(module), arguments, error,
/*run_hlo_passes=*/false, reference_preprocessor);
if (!result.ok()) {
return ::testing::AssertionFailure() << result.status();
}
return result.ValueOrDie();
}
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
const auto& fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
return RunAndCompare<std::unique_ptr<Literal>>(
std::move(module), fake_arguments, error, reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
const auto& fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>(
std::move(module), fake_arguments, error, reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompare(
const StringPiece hlo_string,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
}
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
return ::testing::AssertionFailure()
<< "failed reading hlo module from file";
}
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
const StringPiece hlo_string,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
}
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
return ::testing::AssertionFailure()
<< "failed reading hlo module from file";
}
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
reference_preprocessor);
}
Backend& HloTestBase::backend() { return test_runner_.backend(); }
/* static */
string HloTestBase::TestName() {


@@ -24,31 +24,74 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
// A base class for tests which build and run HLO code. This is a lower level of
// abstraction than using the client interface and enables, for one, explicitly
// building a graph of HLO instructions to run.
// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the results.
// This is a lower level of abstraction than using the client interface and
// enables, for one, explicitly building a graph of HLO instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the test
// target is responsible for linking the needed backends. A convenient way to
// do this is to make it an xla_test: it generates one test target per listed
// backend, each linking in the respective backend as the test backend; the
// interpreter backend is already linked with hlo_test_base so it will be the
// default reference backend. For example, if you want to compare both cpu vs.
// interpreter, and gpu vs. interpreter, you can:
//
// xla_test (
// name = "sample_text_test",
// srcs = ["sample_text_test.cc"],
// backends = [
// "cpu",
// "gpu",
// ],
// deps = [
// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
// ...
// ],
// )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
protected:
HloTestBase() {}
// This uses the interpreter backend as the reference backend and
// automatically finds another supported backend as the test backend. If the
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
HloTestBase();
// If your test doesn't use the interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(::perftools::gputools::Platform* test_platform,
::perftools::gputools::Platform* reference_platform);
~HloTestBase() override {}
// Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug
// options from command-line flags. It's recommended to use this method to
// create all HloModules for tests.
// options from command-line flags. If you want a fresh HloModule object and
// then add HloComputations to it, it's recommended to use this method in your
// tests.
static std::unique_ptr<HloModule> CreateNewModule();
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
static DebugOptions GetDebugOptionsForTest();
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
@@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test {
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
// Executes the given hlo module on two backends and compares results.
//
// 'arguments': the input of the hlo module. The LiteralPtr type accepts
// Literal* or std::unique_ptr<Literal>.
//
// 'error': if it has a value, expects the results to be near (within the error
// bound). Otherwise, expects the results to be equal.
//
// 'reference_preprocessor': the module should be ready to run on the test
// backend, but it might need to be tailored so that it is able to run on the
// reference backend. Note that the program shape of the module must not be
// modified.
template <typename LiteralPtr>
::testing::AssertionResult RunAndCompare(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Same as above, except that the module will be executed without Hlo
// optimization.
template <typename LiteralPtr>
::testing::AssertionResult RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Executes an hlo module with fake inputs and compares the results.
::testing::AssertionResult RunAndCompare(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Same as above, except that the module will be executed without Hlo
// optimization.
::testing::AssertionResult RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Convenient wrappers for executing and comparing an hlo module with fake
// input. The module can be passed in directly, parsed from an hlo_string, or
// loaded from a file.
::testing::AssertionResult RunAndCompare(
const tensorflow::StringPiece hlo_string,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareFromFile(
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareNoHloPasses(
const tensorflow::StringPiece hlo_string,
const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Convenience method to force the layout of a given parameter in a module.
// The layout of parameter number 'param_no' in the 'module' is set to
// 'layout'.
@@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test {
static string TestName();
// Returns the backend owned by the HloRunner.
// Returns the backend owned by the test runner.
Backend& backend();
HloRunner runner_;
HloRunner test_runner_;
HloRunner reference_runner_;
ErrorSpec error_spec_{0.0001};
private:
// Given the test module, makes a reference module that is ready to run on the
// reference platform. This assumes that the given module is ready to run on
// the test platform.
StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
const HloModule& test_module,
const std::function<void(HloModule*)>& reference_preprocessor);
// Runs the module on two platforms with or without running hlo passes and
// compares the results. Returns whether the results are near or equal. If any
// error happens before the results are computed, returns the error status.
template <typename LiteralPtr>
StatusOr<::testing::AssertionResult> RunAndCompareInternal(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor);
};
} // namespace xla


@@ -0,0 +1,8 @@
HloModule convolution.167:
ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] {
%parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0)
%parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1)
ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f
}
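
For reference, window={size=3x3 pad=1_1x1_1} denotes a 3x3 window with one element of padding on each side of both spatial dimensions, so the 28x28 spatial shape is preserved:

output_dim = (input_dim + pad_low + pad_high - window_size) / stride + 1
           = (28 + 1 + 1 - 3) / 1 + 1
           = 28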


@@ -333,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
return result;
}
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
const Literal& actual) {
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple(
const Literal& expected, const Literal& actual) {
VLOG(1) << "expected: " << expected.ToString();
VLOG(1) << "actual: " << actual.ToString();
ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
if (!ShapeUtil::IsTuple(expected.shape()) ||
!ShapeUtil::IsTuple(actual.shape())) {
return ::testing::AssertionFailure()
<< "tuples expected shape = " << expected.shape().ShortDebugString()
<< " actual shape = " << actual.shape().ShortDebugString();
}
AssertEqualShapes(expected.shape(), actual.shape());
for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
const auto& expected_element = expected.tuple_literals(i);
const auto& actual_element = actual.tuple_literals(i);
if (ShapeUtil::IsTuple(expected_element.shape())) {
ExpectEqualTuple(expected_element, actual_element);
auto ret = EqualTuple(expected_element, actual_element);
if (!ret) {
return ret;
}
} else {
ExpectEqual(expected_element, actual_element);
return Equal(expected_element, actual_element);
}
}
return ::testing::AssertionSuccess();
}
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
const Literal& actual) {
EXPECT_TRUE(EqualTuple(expected, actual));
}
namespace {
@@ -615,8 +629,7 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
if (!ShapeUtil::IsTuple(expected.shape()) ||
!ShapeUtil::IsTuple(actual.shape())) {
return ::testing::AssertionFailure()
<< "tuples expected expected shape = "
<< expected.shape().ShortDebugString()
<< "tuples expected shape = " << expected.shape().ShortDebugString()
<< " actual shape = " << actual.shape().ShortDebugString();
}
AssertEqualShapes(expected.shape(), actual.shape());
@@ -650,6 +663,32 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
EXPECT_TRUE(NearTuple(expected, actual, error));
}
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
const Literal& expected, const Literal& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
bool is_tuple = ShapeUtil::IsTuple(expected.shape());
if (error.has_value()) {
if (is_tuple) {
VLOG(1) << "Expects near tuple";
return NearTuple(expected, actual, *error);
}
VLOG(1) << "Expects near";
return Near(expected, actual, *error);
}
if (is_tuple) {
VLOG(1) << "Expects equal tuple";
return EqualTuple(expected, actual);
}
VLOG(1) << "Expects equal";
return Equal(expected, actual);
}
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
const Literal& expected, const Literal& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
EXPECT_TRUE(NearOrEqual(expected, actual, error));
}
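
A brief usage sketch (the expected/actual literals are hypothetical): passing nullopt requests exact equality, while an ErrorSpec requests approximate comparison; tuples recurse either way:

// Exact comparison:
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(*expected, *actual,
                                         tensorflow::gtl::nullopt));
// Approximate comparison within an absolute error of 1e-3:
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(*expected, *actual,
                                         ErrorSpec{0.001}));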
/* static */ string LiteralTestUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
return tensorflow::strings::StrCat(


@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -110,6 +111,10 @@ class LiteralTestUtil {
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
const Literal& actual);
// Returns whether the two tuples are equal.
static ::testing::AssertionResult EqualTuple(
const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
// Expects that the values of the elements in the expected and actual tuples
// are equal. Tuples are matched recursively.
static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
@@ -177,6 +182,19 @@ class LiteralTestUtil {
static void ExpectNearTuple(const Literal& expected, const Literal& actual,
const ErrorSpec& error);
// If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively.
static ::testing::AssertionResult NearOrEqual(
const Literal& expected, const Literal& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
// If the error spec is given, expects the expected and the actual to be near;
// otherwise, expects them to be equal. Tuples will be compared recursively.
static void ExpectNearOrEqual(
const Literal& expected, const Literal& actual,
const tensorflow::gtl::optional<ErrorSpec>& error);
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.


@@ -0,0 +1,51 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This demonstrates how to use hlo_test_base to create a file-based test case
// and compare results on gpu and cpu.
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
class SampleFileTest : public HloTestBase {
protected:
SampleFileTest()
: HloTestBase(
/*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(),
/*reference_platform=*/PlatformUtil::GetPlatform("cpu")
.ValueOrDie()) {}
};
TEST_F(SampleFileTest, Convolution) {
const string& filename = "compiler/xla/tests/isolated_convolution.hlo";
string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
EXPECT_TRUE(RunAndCompareFromFile(
tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01}));
}
} // namespace
} // namespace xla


@@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This demonstrates how to use hlo_test_base to create textual-IR-based
// test cases.
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
using tensorflow::gtl::nullopt;
class SampleTextTest : public HloTestBase {};
TEST_F(SampleTextTest, Axpy) {
const string& hlo_string = R"(
HloModule axpy_module:
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001}));
}
TEST_F(SampleTextTest, Tuple) {
const string& hlo_string = R"(
HloModule TupleCreate_module:
ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
%v1 = f32[] parameter(0)
%v2 = f32[3]{0} parameter(1)
%v3 = f32[2,3]{1,0} parameter(2)
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, nullopt));
}
} // namespace
} // namespace xla