[XLA] Make hlo_test_base support running and comparing an hlo module.
Also, - Add examples that shows how to use the hlo_test_base as to create text/file-based testcases. - Change the behavior of GetDefaultPlatform: when only one platform is present, returns that one; when two platforms are present and one of them is the interpreter, returns the other one. This is because some tests included both hlo_test_base and client_library_test_base, but now the hlo_test_base is linked with interpreter by default, which makes client_library_test_base fail getting the default platform. PiperOrigin-RevId: 178309022
This commit is contained in:
parent
7b0458a789
commit
029109b4e1
@ -39,6 +39,14 @@ namespace se = ::perftools::gputools;
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||||
|
HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
|
||||||
|
const DebugOptions& debug_options) {
|
||||||
|
HloModuleConfig config;
|
||||||
|
config.set_debug_options(debug_options);
|
||||||
|
return tools::Parse(hlo_string, config);
|
||||||
|
}
|
||||||
|
|
||||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||||
HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
|
HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
|
||||||
const DebugOptions& debug_options) {
|
const DebugOptions& debug_options) {
|
||||||
|
@ -35,7 +35,8 @@ namespace xla {
|
|||||||
|
|
||||||
// A base class for running an HloModule. This executes the given HloModule on a
|
// A base class for running an HloModule. This executes the given HloModule on a
|
||||||
// certain backend directly without using the client interface. HloModule can be
|
// certain backend directly without using the client interface. HloModule can be
|
||||||
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
|
// explicitly built, or loaded from a serialization file (e.g., hlo proto
|
||||||
|
// file), or parsed from a hlo textual IR string.
|
||||||
class HloRunner {
|
class HloRunner {
|
||||||
public:
|
public:
|
||||||
HloRunner();
|
HloRunner();
|
||||||
@ -44,6 +45,12 @@ class HloRunner {
|
|||||||
|
|
||||||
~HloRunner();
|
~HloRunner();
|
||||||
|
|
||||||
|
// Converts an HloModule from the given hlo textual IR string (in
|
||||||
|
// HloModule::ToString format).
|
||||||
|
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
|
||||||
|
const tensorflow::StringPiece hlo_string,
|
||||||
|
const DebugOptions& debug_options);
|
||||||
|
|
||||||
// Reads the proto file in xla.HloProto format, creates and returns the
|
// Reads the proto file in xla.HloProto format, creates and returns the
|
||||||
// HloModule. Will try to parse the filename as binary proto, then try as
|
// HloModule. Will try to parse the filename as binary proto, then try as
|
||||||
// text proto if that fails.
|
// text proto if that fails.
|
||||||
@ -65,7 +72,8 @@ class HloRunner {
|
|||||||
// Executes the given module with given literals as input and returns the
|
// Executes the given module with given literals as input and returns the
|
||||||
// result as a Literal. The LiteralPtr type accepts Literal* or
|
// result as a Literal. The LiteralPtr type accepts Literal* or
|
||||||
// std::unique_ptr<Literal>.
|
// std::unique_ptr<Literal>.
|
||||||
// If run_hlo_passes is true, the module will be executed without Hlo
|
//
|
||||||
|
// If run_hlo_passes is false, the module will be executed without Hlo
|
||||||
// optimization.
|
// optimization.
|
||||||
template <typename LiteralPtr>
|
template <typename LiteralPtr>
|
||||||
StatusOr<std::unique_ptr<Literal>> Execute(
|
StatusOr<std::unique_ptr<Literal>> Execute(
|
||||||
|
@ -33,10 +33,32 @@ namespace se = ::perftools::gputools;
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
using tensorflow::str_util::Lowercase;
|
||||||
|
|
||||||
// Minimum supported CUDA compute capability is 3.5.
|
// Minimum supported CUDA compute capability is 3.5.
|
||||||
constexpr int kMinCudaComputeCapabilityMajor = 3;
|
constexpr int kMinCudaComputeCapabilityMajor = 3;
|
||||||
constexpr int kMinCudaComputeCapabilityMinor = 5;
|
constexpr int kMinCudaComputeCapabilityMinor = 5;
|
||||||
|
|
||||||
|
// The name of the interpreter platform.
|
||||||
|
constexpr char kInterpreter[] = "interpreter";
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
string CanonicalPlatformName(const string& name) {
|
||||||
|
string platform_str = Lowercase(name);
|
||||||
|
// "cpu" and "host" mean the same thing.
|
||||||
|
if (platform_str == "cpu") {
|
||||||
|
platform_str = "host";
|
||||||
|
}
|
||||||
|
// "gpu" and "cuda" mean the same thing.
|
||||||
|
if (platform_str == "gpu") {
|
||||||
|
platform_str = "cuda";
|
||||||
|
}
|
||||||
|
return platform_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/* static */ StatusOr<std::vector<se::Platform*>>
|
/* static */ StatusOr<std::vector<se::Platform*>>
|
||||||
PlatformUtil::GetSupportedPlatforms() {
|
PlatformUtil::GetSupportedPlatforms() {
|
||||||
se::MultiPlatformManager::PlatformMap platform_map;
|
se::MultiPlatformManager::PlatformMap platform_map;
|
||||||
@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() {
|
|||||||
return platforms;
|
return platforms;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
|
||||||
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
||||||
if (platforms.empty()) {
|
if (platforms.empty()) {
|
||||||
return NotFound("no platforms found");
|
return NotFound("no platforms found");
|
||||||
@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Multiple platforms present and we can't pick a reasonable default.
|
// Multiple platforms present and we can't pick a reasonable default.
|
||||||
auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); };
|
string platforms_string = tensorflow::str_util::Join(
|
||||||
string platforms_string = tensorflow::str_util::Join(platforms, ", ", l);
|
platforms, ", ",
|
||||||
|
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"must specify platform because more than one platform found: %s",
|
"must specify platform because more than one platform found: %s",
|
||||||
platforms_string.c_str());
|
platforms_string.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
||||||
const string& platform_name) {
|
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
||||||
using tensorflow::str_util::Lowercase;
|
if (platforms.empty()) {
|
||||||
string platform_str = Lowercase(platform_name);
|
return NotFound("no platforms found");
|
||||||
// "cpu" and "host" mean the same thing.
|
} else if (platforms.size() == 1) {
|
||||||
if (platform_str == "cpu") {
|
return platforms[0];
|
||||||
platform_str = "host";
|
} else if (platforms.size() == 2) {
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
if (Lowercase(platforms[i]->Name()) == kInterpreter &&
|
||||||
|
Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
|
||||||
|
return platforms[1 - i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// "gpu" and "cuda" mean the same thing.
|
|
||||||
if (platform_str == "gpu") {
|
|
||||||
platform_str = "cuda";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Multiple platforms present and we can't pick a reasonable default.
|
||||||
|
string platforms_string = tensorflow::str_util::Join(
|
||||||
|
platforms, ", ",
|
||||||
|
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||||
|
return InvalidArgument(
|
||||||
|
"must specify platform because more than one platform (except for the "
|
||||||
|
"interpreter platform) found: %s",
|
||||||
|
platforms_string.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
||||||
|
const string& platform_name) {
|
||||||
|
string platform_str = CanonicalPlatformName(platform_name);
|
||||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||||
for (se::Platform* platform : platforms) {
|
for (se::Platform* platform : platforms) {
|
||||||
if (Lowercase(platform->Name()) == platform_str) {
|
if (Lowercase(platform->Name()) == platform_str) {
|
||||||
@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() {
|
|||||||
return InvalidArgument("platform %s not found", platform_name.c_str());
|
return InvalidArgument("platform %s not found", platform_name.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
|
||||||
|
const string& platform_name) {
|
||||||
|
string platform_str = CanonicalPlatformName(platform_name);
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||||
|
std::vector<se::Platform*> matched;
|
||||||
|
for (se::Platform* platform : platforms) {
|
||||||
|
if (Lowercase(platform->Name()) != platform_name) {
|
||||||
|
matched.push_back(platform);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (matched.empty()) {
|
||||||
|
return InvalidArgument("unable to find platform that is not %s",
|
||||||
|
platform_name.c_str());
|
||||||
|
}
|
||||||
|
if (matched.size() == 1) {
|
||||||
|
return matched[0];
|
||||||
|
}
|
||||||
|
string matched_string = tensorflow::str_util::Join(
|
||||||
|
matched, ", ",
|
||||||
|
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||||
|
return InvalidArgument(
|
||||||
|
"found multiple platforms %s, but expected one platform except for %s",
|
||||||
|
matched_string.c_str(), platform_name.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
// Returns whether the device underlying the given StreamExecutor is supported
|
// Returns whether the device underlying the given StreamExecutor is supported
|
||||||
// by XLA.
|
// by XLA.
|
||||||
static bool IsDeviceSupported(se::StreamExecutor* executor) {
|
static bool IsDeviceSupported(se::StreamExecutor* executor) {
|
||||||
|
@ -37,16 +37,28 @@ class PlatformUtil {
|
|||||||
static StatusOr<std::vector<perftools::gputools::Platform*>>
|
static StatusOr<std::vector<perftools::gputools::Platform*>>
|
||||||
GetSupportedPlatforms();
|
GetSupportedPlatforms();
|
||||||
|
|
||||||
// Convenience function which returns the default supported platform. If
|
// Convenience function which returns the default supported platform for
|
||||||
|
// tests. If exactly one supported platform is present, then this platform is
|
||||||
|
// the default platform. If exactly two platforms are present and one of them
|
||||||
|
// is the interpreter platform, then the other platform is the default
|
||||||
|
// platform. Otherwise returns an error.
|
||||||
|
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
|
||||||
|
|
||||||
|
// Convenience function which returns the sole supported platform. If
|
||||||
// exactly one supported platform is present, then this platform is the
|
// exactly one supported platform is present, then this platform is the
|
||||||
// default platform. Otherwise returns an error.
|
// default platform. Otherwise returns an error.
|
||||||
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
|
static StatusOr<perftools::gputools::Platform*> GetSolePlatform();
|
||||||
|
|
||||||
// Returns the platform according to the given name. Returns error if there is
|
// Returns the platform according to the given name. Returns error if there is
|
||||||
// no such platform.
|
// no such platform.
|
||||||
static StatusOr<perftools::gputools::Platform*> GetPlatform(
|
static StatusOr<perftools::gputools::Platform*> GetPlatform(
|
||||||
const string& platform_name);
|
const string& platform_name);
|
||||||
|
|
||||||
|
// Returns exactly one platform that does not have given name. Returns error
|
||||||
|
// if there is no such platform, or there are multiple such platforms.
|
||||||
|
static StatusOr<perftools::gputools::Platform*> GetPlatformExceptFor(
|
||||||
|
const string& platform_name);
|
||||||
|
|
||||||
// Returns a vector of StreamExecutors for the given platform. The vector is
|
// Returns a vector of StreamExecutors for the given platform. The vector is
|
||||||
// indexed by device ordinal (device numbering used by StreamExecutor). If an
|
// indexed by device ordinal (device numbering used by StreamExecutor). If an
|
||||||
// element is nullptr, then the device is present by not supported by XLA.
|
// element is nullptr, then the device is present by not supported by XLA.
|
||||||
|
@ -105,7 +105,9 @@ cc_library(
|
|||||||
hdrs = ["hlo_test_base.h"],
|
hdrs = ["hlo_test_base.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":literal_test_util",
|
":literal_test_util",
|
||||||
|
":test_utils",
|
||||||
"//tensorflow/compiler/xla:shape_layout",
|
"//tensorflow/compiler/xla:shape_layout",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
@ -115,6 +117,9 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:computation_layout",
|
"//tensorflow/compiler/xla/service:computation_layout",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_runner",
|
"//tensorflow/compiler/xla/service:hlo_runner",
|
||||||
|
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
|
||||||
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
|
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
@ -1678,6 +1683,45 @@ xla_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A demo of textual IR based test.
|
||||||
|
xla_test(
|
||||||
|
name = "sample_text_test",
|
||||||
|
srcs = ["sample_text_test.cc"],
|
||||||
|
# You can leave this empty if you want to test all supported backends.
|
||||||
|
backends = [
|
||||||
|
"cpu",
|
||||||
|
"gpu",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# A demo of test that loads an hlo module from a file and compares results on gpu and cpu.
|
||||||
|
tf_cc_test(
|
||||||
|
name = "sample_file_test",
|
||||||
|
srcs = ["sample_file_test.cc"],
|
||||||
|
data = ["isolated_convolution.hlo"],
|
||||||
|
tags = ["requires-gpu-sm35"],
|
||||||
|
deps = [
|
||||||
|
":hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla/service:cpu_plugin", # reference backend
|
||||||
|
"//tensorflow/compiler/xla/service:gpu_plugin", # test backend
|
||||||
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
@ -15,13 +15,22 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -30,18 +39,72 @@ namespace se = ::perftools::gputools;
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::StringPiece;
|
||||||
|
using tensorflow::gtl::ArraySlice;
|
||||||
|
using tensorflow::gtl::optional;
|
||||||
|
|
||||||
|
constexpr char kInterpreter[] = "interpreter";
|
||||||
|
|
||||||
|
// Helper functions to get test and reference platforms.
|
||||||
|
se::Platform* GetReferencePlatform() {
|
||||||
|
auto result = PlatformUtil::GetPlatform(kInterpreter);
|
||||||
|
TF_CHECK_OK(result.status()) << "could not get interpreter platform";
|
||||||
|
return result.ValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
se::Platform* GetTestPlatform() {
|
||||||
|
auto result = PlatformUtil::GetDefaultPlatform();
|
||||||
|
TF_CHECK_OK(result.status()) << "could not get test platform";
|
||||||
|
return result.ValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
|
||||||
|
if (lhs.parameters_size() != rhs.parameters_size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < lhs.parameters_size(); i++) {
|
||||||
|
if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ShapeUtil::Equal(lhs.result(), rhs.result());
|
||||||
|
}
|
||||||
|
|
||||||
|
ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
||||||
|
ProgramShape program_shape;
|
||||||
|
const auto* entry = module.entry_computation();
|
||||||
|
for (const auto* param : entry->parameter_instructions()) {
|
||||||
|
*program_shape.add_parameters() = param->shape();
|
||||||
|
*program_shape.add_parameter_names() = param->name();
|
||||||
|
}
|
||||||
|
*program_shape.mutable_result() = entry->root_instruction()->shape();
|
||||||
|
return program_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
HloTestBase::HloTestBase()
|
||||||
|
: HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
|
||||||
|
|
||||||
|
HloTestBase::HloTestBase(se::Platform* test_platform,
|
||||||
|
se::Platform* reference_platform)
|
||||||
|
: test_runner_(test_platform), reference_runner_(reference_platform) {}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
|
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
|
||||||
HloModuleConfig config;
|
HloModuleConfig config;
|
||||||
|
config.set_debug_options(GetDebugOptionsForTest());
|
||||||
|
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
|
||||||
|
config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
|
||||||
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
|
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
|
||||||
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
|
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
|
||||||
debug_options.add_xla_disable_hlo_passes("constant_folding");
|
debug_options.add_xla_disable_hlo_passes("constant_folding");
|
||||||
|
return debug_options;
|
||||||
config.set_debug_options(debug_options);
|
|
||||||
|
|
||||||
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
|
|
||||||
config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
||||||
@ -49,25 +112,168 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
|||||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||||
arguments,
|
arguments,
|
||||||
Shape* result_shape) {
|
Shape* result_shape) {
|
||||||
return runner_.Execute(std::move(module), arguments, result_shape);
|
return test_runner_.Execute(std::move(module), arguments, result_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
|
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
|
||||||
return runner_.TransferToDevice(literal).ValueOrDie();
|
return test_runner_.TransferToDevice(literal).ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
|
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
|
||||||
const Shape& shape, se::DeviceMemoryBase device_base) {
|
const Shape& shape, se::DeviceMemoryBase device_base) {
|
||||||
return runner_.TransferFromDevice(shape, device_base).ValueOrDie();
|
return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
|
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
|
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
|
||||||
return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie();
|
return test_runner_.ExecuteAndTransfer(std::move(module), arguments)
|
||||||
|
.ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
Backend& HloTestBase::backend() { return runner_.backend(); }
|
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
|
||||||
|
const HloModule& test_module,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
std::unique_ptr<HloModule> reference_module = test_module.Clone();
|
||||||
|
const auto& program_shape = GetProgramShapeWithLayout(test_module);
|
||||||
|
|
||||||
|
if (reference_preprocessor != nullptr) {
|
||||||
|
reference_preprocessor(reference_module.get());
|
||||||
|
if (!ProgramShapesEqual(program_shape,
|
||||||
|
GetProgramShapeWithLayout(*reference_module))) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"reference preprocessor must not modify the program shape");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
|
||||||
|
reference_module.get()));
|
||||||
|
return std::move(reference_module);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||||
|
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||||
|
const optional<ErrorSpec>& error, bool run_hlo_passes,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
static_assert(
|
||||||
|
std::is_same<Literal*, LiteralPtr>::value ||
|
||||||
|
std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value,
|
||||||
|
"The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>.");
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
VerifyHloModule(*test_runner_.backend().platform(), module.get()));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto reference_module,
|
||||||
|
MakeReferenceModule(*module, reference_preprocessor));
|
||||||
|
|
||||||
|
// Execute on two backends.
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto test,
|
||||||
|
test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto reference,
|
||||||
|
reference_runner_.Execute(std::move(reference_module),
|
||||||
|
arguments, run_hlo_passes));
|
||||||
|
return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
|
||||||
|
error);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||||
|
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||||
|
const optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto result =
|
||||||
|
RunAndCompareInternal(std::move(module), arguments, error,
|
||||||
|
/*run_hlo_passes=*/true, reference_preprocessor);
|
||||||
|
if (!result.ok()) {
|
||||||
|
return ::testing::AssertionFailure() << result.status();
|
||||||
|
}
|
||||||
|
return result.ValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||||
|
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||||
|
const optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto result =
|
||||||
|
RunAndCompareInternal(std::move(module), arguments, error,
|
||||||
|
/*run_hlo_passes=*/false, reference_preprocessor);
|
||||||
|
if (!result.ok()) {
|
||||||
|
return ::testing::AssertionFailure() << result.status();
|
||||||
|
}
|
||||||
|
return result.ValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||||
|
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
const auto& fake_arguments =
|
||||||
|
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||||
|
return RunAndCompare<std::unique_ptr<Literal>>(
|
||||||
|
std::move(module), fake_arguments, error, reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||||
|
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
const auto& fake_arguments =
|
||||||
|
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||||
|
return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>(
|
||||||
|
std::move(module), fake_arguments, error, reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||||
|
const StringPiece hlo_string,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto module_or_status =
|
||||||
|
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||||
|
if (!module_or_status.ok()) {
|
||||||
|
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
|
||||||
|
}
|
||||||
|
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
|
||||||
|
reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
|
||||||
|
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto module_or_status =
|
||||||
|
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
|
||||||
|
if (!module_or_status.ok()) {
|
||||||
|
return ::testing::AssertionFailure()
|
||||||
|
<< "failed reading hlo module from file";
|
||||||
|
}
|
||||||
|
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
|
||||||
|
reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||||
|
const StringPiece hlo_string,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto module_or_status =
|
||||||
|
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||||
|
if (!module_or_status.ok()) {
|
||||||
|
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
|
||||||
|
}
|
||||||
|
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
|
||||||
|
reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
|
||||||
|
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||||
|
auto module_or_status =
|
||||||
|
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
|
||||||
|
if (!module_or_status.ok()) {
|
||||||
|
return ::testing::AssertionFailure()
|
||||||
|
<< "failed reading hlo module from file";
|
||||||
|
}
|
||||||
|
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
|
||||||
|
reference_preprocessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
Backend& HloTestBase::backend() { return test_runner_.backend(); }
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
string HloTestBase::TestName() {
|
string HloTestBase::TestName() {
|
||||||
|
@ -24,31 +24,74 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/optional.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// A base class for tests which build and run HLO code. This is a lower level of
|
// A base class for tests which build and/or run HLO code. The class includes
|
||||||
// abstraction than using the client interface and enables, for one, explicitly
|
// support for running an HLO module on two platforms and compare the results.
|
||||||
// building a graph of HLO instructions to run.
|
// This is a lower level of abstraction than using the client interface and
|
||||||
|
// enables, for one, explicitly building a graph of HLO instructions to run.
|
||||||
|
//
|
||||||
|
// This can also be used to write text/file-based test cases. Note that the test
|
||||||
|
// target is responsible for linking the needed backends. A covenient way to do
|
||||||
|
// this is to make it an xla_test: it will generate test targets linking with
|
||||||
|
// the respective backends, which will be used as the test backend; the
|
||||||
|
// interpreter backend is already linked with hlo_test_base so it will be the
|
||||||
|
// default reference backend. For example, if you want to compare both cpu vs.
|
||||||
|
// interpreter, and gpu vs. interpreter, you can:
|
||||||
|
//
|
||||||
|
// xla_test (
|
||||||
|
// name = "sample_text_test",
|
||||||
|
// srcs = ["sample_text_test.cc"],
|
||||||
|
// backends = [
|
||||||
|
// "cpu",
|
||||||
|
// "gpu",
|
||||||
|
// ],
|
||||||
|
// deps = [
|
||||||
|
// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
// ...
|
||||||
|
// ],
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// For a more detailed example, see "../tests/sample_text_test.cc".
|
||||||
class HloTestBase : public ::testing::Test {
|
class HloTestBase : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
HloTestBase() {}
|
// This uses the interpreter backend as the reference backend and
|
||||||
|
// automatically finds another supported backend as the test backend. If the
|
||||||
|
// interpreter is the only supported backend, it will be both the test backend
|
||||||
|
// and the reference backend.
|
||||||
|
HloTestBase();
|
||||||
|
|
||||||
|
// If your test doesn't use interpreter as the reference backend, you can use
|
||||||
|
// this constructor. Note that your test target is responsible for linking in
|
||||||
|
// both needed backends.
|
||||||
|
HloTestBase(::perftools::gputools::Platform* test_platform,
|
||||||
|
::perftools::gputools::Platform* reference_platform);
|
||||||
|
|
||||||
~HloTestBase() override {}
|
~HloTestBase() override {}
|
||||||
|
|
||||||
// Creates a new HLO module for a test. The module created will have
|
// Creates a new HLO module for a test. The module created will have
|
||||||
// TestName() for its name; it will also automatically populate its debug
|
// TestName() for its name; it will also automatically populate its debug
|
||||||
// options from command-line flags. It's recommended to use this method to
|
// options from command-line flags. If you want a fresh HloModule object and
|
||||||
// create all HloModules for tests.
|
// then add HloComputations to it, it's recommended to use this method in your
|
||||||
|
// tests.
|
||||||
static std::unique_ptr<HloModule> CreateNewModule();
|
static std::unique_ptr<HloModule> CreateNewModule();
|
||||||
|
|
||||||
|
// Populates debug options from command-line flags and adjusts the options for
|
||||||
|
// testing. It is recommended to use this when you need to pass in
|
||||||
|
// DebugOptions, e.g. when creating a module from a string or a file.
|
||||||
|
static DebugOptions GetDebugOptionsForTest();
|
||||||
|
|
||||||
// Executes the given module and returns a global data handle.
|
// Executes the given module and returns a global data handle.
|
||||||
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
|
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test {
|
|||||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||||
arguments);
|
arguments);
|
||||||
|
|
||||||
|
// Executes the given hlo module on two backends and compares results.
|
||||||
|
//
|
||||||
|
// 'arguments': the input of the hlo module. The LiteralPtr type accepts
|
||||||
|
// Literal* or std::unique_ptr<Literal>.
|
||||||
|
//
|
||||||
|
// 'error': if has value, expects the results to be near (within the error
|
||||||
|
// bound). Otherwise, expects the results to be equal.
|
||||||
|
//
|
||||||
|
// 'reference_preprocessor': the module should be ready to run on the test
|
||||||
|
// backend, but it might need to be tailored so that it is able to run on the
|
||||||
|
// reference backend. Note that the program shape of the module must not be
|
||||||
|
// modified.
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
::testing::AssertionResult RunAndCompare(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// Same as above, except that the module will be executed without Hlo
|
||||||
|
// optimization.
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// Executes an hlo module with fake inputs and compares the results.
|
||||||
|
::testing::AssertionResult RunAndCompare(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// Same as above, except that the module will be executed without Hlo
|
||||||
|
// optimization.
|
||||||
|
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// Convenient wrappers for executing and comparing an hlo module with fake
|
||||||
|
// input. Module can be passed in directly, or parsed from an hlo_string,
|
||||||
|
// or loaded from a file.
|
||||||
|
::testing::AssertionResult RunAndCompare(
|
||||||
|
const tensorflow::StringPiece hlo_string,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
::testing::AssertionResult RunAndCompareFromFile(
|
||||||
|
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||||
|
const tensorflow::StringPiece hlo_string,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
|
||||||
|
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||||
|
TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Convenience method to force the layout of a given parameter in a module.
|
// Convenience method to force the layout of a given parameter in a module.
|
||||||
// The layout of parameter number 'param_no' in the 'module' is set to
|
// The layout of parameter number 'param_no' in the 'module' is set to
|
||||||
// 'layout'.
|
// 'layout'.
|
||||||
@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test {
|
|||||||
|
|
||||||
static string TestName();
|
static string TestName();
|
||||||
|
|
||||||
// Returns the backend owned by the HloRunner.
|
// Returns the backend owned by the test runner.
|
||||||
Backend& backend();
|
Backend& backend();
|
||||||
|
|
||||||
HloRunner runner_;
|
HloRunner test_runner_;
|
||||||
|
HloRunner reference_runner_;
|
||||||
|
|
||||||
ErrorSpec error_spec_{0.0001};
|
ErrorSpec error_spec_{0.0001};
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Given the test module, makes a reference module that is ready to run on the
|
||||||
|
// reference platform. This assumes that the given module is ready to run on
|
||||||
|
// the test platform.
|
||||||
|
StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
|
||||||
|
const HloModule& test_module,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor);
|
||||||
|
|
||||||
|
// Runs the module on two platforms with or without running hlo passes and
|
||||||
|
// compares the results. Returns whether the results are near or equal. If any
|
||||||
|
// error happens before the results are computed, returns the error status.
|
||||||
|
template <typename LiteralPtr>
|
||||||
|
StatusOr<::testing::AssertionResult> RunAndCompareInternal(
|
||||||
|
std::unique_ptr<HloModule> module,
|
||||||
|
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
|
||||||
|
const std::function<void(HloModule*)>& reference_preprocessor);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
8
tensorflow/compiler/xla/tests/isolated_convolution.hlo
Normal file
8
tensorflow/compiler/xla/tests/isolated_convolution.hlo
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
HloModule convolution.167:
|
||||||
|
|
||||||
|
ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] {
|
||||||
|
%parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0)
|
||||||
|
%parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1)
|
||||||
|
ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f
|
||||||
|
}
|
||||||
|
|
@ -333,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
|
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple(
|
||||||
const Literal& actual) {
|
const Literal& expected, const Literal& actual) {
|
||||||
VLOG(1) << "expected: " << expected.ToString();
|
VLOG(1) << "expected: " << expected.ToString();
|
||||||
VLOG(1) << "actual: " << actual.ToString();
|
VLOG(1) << "actual: " << actual.ToString();
|
||||||
|
|
||||||
ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
|
if (!ShapeUtil::IsTuple(expected.shape()) ||
|
||||||
ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
|
!ShapeUtil::IsTuple(actual.shape())) {
|
||||||
|
return ::testing::AssertionFailure()
|
||||||
|
<< "tuples expected shape = " << expected.shape().ShortDebugString()
|
||||||
|
<< " actual shape = " << actual.shape().ShortDebugString();
|
||||||
|
}
|
||||||
AssertEqualShapes(expected.shape(), actual.shape());
|
AssertEqualShapes(expected.shape(), actual.shape());
|
||||||
for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
|
for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
|
||||||
const auto& expected_element = expected.tuple_literals(i);
|
const auto& expected_element = expected.tuple_literals(i);
|
||||||
const auto& actual_element = actual.tuple_literals(i);
|
const auto& actual_element = actual.tuple_literals(i);
|
||||||
if (ShapeUtil::IsTuple(expected_element.shape())) {
|
if (ShapeUtil::IsTuple(expected_element.shape())) {
|
||||||
ExpectEqualTuple(expected_element, actual_element);
|
auto ret = EqualTuple(expected_element, actual_element);
|
||||||
|
if (!ret) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
ExpectEqual(expected_element, actual_element);
|
return Equal(expected_element, actual_element);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ::testing::AssertionSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
|
||||||
|
const Literal& actual) {
|
||||||
|
EXPECT_TRUE(EqualTuple(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -615,8 +629,7 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
|
|||||||
if (!ShapeUtil::IsTuple(expected.shape()) ||
|
if (!ShapeUtil::IsTuple(expected.shape()) ||
|
||||||
!ShapeUtil::IsTuple(actual.shape())) {
|
!ShapeUtil::IsTuple(actual.shape())) {
|
||||||
return ::testing::AssertionFailure()
|
return ::testing::AssertionFailure()
|
||||||
<< "tuples expected expected shape = "
|
<< "tuples expected shape = " << expected.shape().ShortDebugString()
|
||||||
<< expected.shape().ShortDebugString()
|
|
||||||
<< " actual shape = " << actual.shape().ShortDebugString();
|
<< " actual shape = " << actual.shape().ShortDebugString();
|
||||||
}
|
}
|
||||||
AssertEqualShapes(expected.shape(), actual.shape());
|
AssertEqualShapes(expected.shape(), actual.shape());
|
||||||
@ -650,6 +663,32 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
|
|||||||
EXPECT_TRUE(NearTuple(expected, actual, error));
|
EXPECT_TRUE(NearTuple(expected, actual, error));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
|
||||||
|
const Literal& expected, const Literal& actual,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
||||||
|
bool is_tuple = ShapeUtil::IsTuple(expected.shape());
|
||||||
|
if (error.has_value()) {
|
||||||
|
if (is_tuple) {
|
||||||
|
VLOG(1) << "Expects near tuple";
|
||||||
|
return NearTuple(expected, actual, *error);
|
||||||
|
}
|
||||||
|
VLOG(1) << "Expects near";
|
||||||
|
return Near(expected, actual, *error);
|
||||||
|
}
|
||||||
|
if (is_tuple) {
|
||||||
|
VLOG(1) << "Expects equal tuple";
|
||||||
|
return EqualTuple(expected, actual);
|
||||||
|
}
|
||||||
|
VLOG(1) << "Expects equal";
|
||||||
|
return Equal(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
|
||||||
|
const Literal& expected, const Literal& actual,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
||||||
|
EXPECT_TRUE(NearOrEqual(expected, actual, error));
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ string LiteralTestUtil::MultiIndexAsString(
|
/* static */ string LiteralTestUtil::MultiIndexAsString(
|
||||||
tensorflow::gtl::ArraySlice<int64> multi_index) {
|
tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||||
return tensorflow::strings::StrCat(
|
return tensorflow::strings::StrCat(
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/optional.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -110,6 +111,10 @@ class LiteralTestUtil {
|
|||||||
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
|
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
|
||||||
const Literal& actual);
|
const Literal& actual);
|
||||||
|
|
||||||
|
// Returns whether the two tuples are equal.
|
||||||
|
static ::testing::AssertionResult EqualTuple(
|
||||||
|
const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// Expects that the values of the elements in the expected and actual tuples
|
// Expects that the values of the elements in the expected and actual tuples
|
||||||
// are equal. Tuples are matched recursively.
|
// are equal. Tuples are matched recursively.
|
||||||
static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
|
static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
|
||||||
@ -177,6 +182,19 @@ class LiteralTestUtil {
|
|||||||
static void ExpectNearTuple(const Literal& expected, const Literal& actual,
|
static void ExpectNearTuple(const Literal& expected, const Literal& actual,
|
||||||
const ErrorSpec& error);
|
const ErrorSpec& error);
|
||||||
|
|
||||||
|
// If the error spec is given, returns whether the expected and the actual are
|
||||||
|
// within the error bound; otherwise, returns whether they are equal. Tuples
|
||||||
|
// will be compared recursively.
|
||||||
|
static ::testing::AssertionResult NearOrEqual(
|
||||||
|
const Literal& expected, const Literal& actual,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// If the error spec is given, expects the expected and the actual to be near;
|
||||||
|
// otherwise, expects them to be equal. Tuples will be compared recursively.
|
||||||
|
static void ExpectNearOrEqual(
|
||||||
|
const Literal& expected, const Literal& actual,
|
||||||
|
const tensorflow::gtl::optional<ErrorSpec>& error);
|
||||||
|
|
||||||
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
|
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
|
||||||
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
|
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
|
||||||
// dimension 1 equal to 8.
|
// dimension 1 equal to 8.
|
||||||
|
51
tensorflow/compiler/xla/tests/sample_file_test.cc
Normal file
51
tensorflow/compiler/xla/tests/sample_file_test.cc
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This demonstrates how to use hlo_test_base to create a file based testcase
|
||||||
|
// and compare results on gpu and cpu.
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class SampleFileTest : public HloTestBase {
|
||||||
|
protected:
|
||||||
|
SampleFileTest()
|
||||||
|
: HloTestBase(
|
||||||
|
/*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(),
|
||||||
|
/*reference_platform=*/PlatformUtil::GetPlatform("cpu")
|
||||||
|
.ValueOrDie()) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(SampleFileTest, Convolution) {
|
||||||
|
const string& filename = "compiler/xla/tests/isolated_convolution.hlo";
|
||||||
|
string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
|
||||||
|
EXPECT_TRUE(RunAndCompareFromFile(
|
||||||
|
tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
66
tensorflow/compiler/xla/tests/sample_text_test.cc
Normal file
66
tensorflow/compiler/xla/tests/sample_text_test.cc
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This demonstrates how to use hlo_test_base to create textual IR based
|
||||||
|
// testcases.
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/optional.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::gtl::nullopt;
|
||||||
|
|
||||||
|
class SampleTextTest : public HloTestBase {};
|
||||||
|
|
||||||
|
TEST_F(SampleTextTest, Axpy) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule axpy_module:
|
||||||
|
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||||
|
%alpha = f32[] parameter(0)
|
||||||
|
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
|
||||||
|
%x = f32[2,4]{1,0} parameter(1)
|
||||||
|
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
|
||||||
|
%y = f32[2,4]{1,0} parameter(2)
|
||||||
|
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SampleTextTest, Tuple) {
|
||||||
|
const string& hlo_string = R"(
|
||||||
|
HloModule TupleCreate_module:
|
||||||
|
ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
|
||||||
|
%v1 = f32[] parameter(0)
|
||||||
|
%v2 = f32[3]{0} parameter(1)
|
||||||
|
%v3 = f32[2,3]{1,0} parameter(2)
|
||||||
|
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
EXPECT_TRUE(RunAndCompare(hlo_string, nullopt));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
Loading…
Reference in New Issue
Block a user