[XLA] Make hlo_test_base support running and comparing an hlo module.
Also, - Add examples that shows how to use the hlo_test_base as to create text/file-based testcases. - Change the behavior of GetDefaultPlatform: when only one platform is present, returns that one; when two platforms are present and one of them is the interpreter, returns the other one. This is because some tests included both hlo_test_base and client_library_test_base, but now the hlo_test_base is linked with interpreter by default, which makes client_library_test_base fail getting the default platform. PiperOrigin-RevId: 178309022
This commit is contained in:
parent
7b0458a789
commit
029109b4e1
@ -39,6 +39,14 @@ namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
|
||||
const DebugOptions& debug_options) {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
return tools::Parse(hlo_string, config);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
|
@ -35,7 +35,8 @@ namespace xla {
|
||||
|
||||
// A base class for running an HloModule. This executes the given HloModule on a
|
||||
// certain backend directly without using the client interface. HloModule can be
|
||||
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
|
||||
// explicitly built, or loaded from a serialization file (e.g., hlo proto
|
||||
// file), or parsed from a hlo textual IR string.
|
||||
class HloRunner {
|
||||
public:
|
||||
HloRunner();
|
||||
@ -44,6 +45,12 @@ class HloRunner {
|
||||
|
||||
~HloRunner();
|
||||
|
||||
// Converts an HloModule from the given hlo textual IR string (in
|
||||
// HloModule::ToString format).
|
||||
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
|
||||
const tensorflow::StringPiece hlo_string,
|
||||
const DebugOptions& debug_options);
|
||||
|
||||
// Reads the proto file in xla.HloProto format, creates and returns the
|
||||
// HloModule. Will try to parse the filename as binary proto, then try as
|
||||
// text proto if that fails.
|
||||
@ -65,7 +72,8 @@ class HloRunner {
|
||||
// Executes the given module with given literals as input and returns the
|
||||
// result as a Literal. The LiteralPtr type accepts Literal* or
|
||||
// std::unique_ptr<Literal>.
|
||||
// If run_hlo_passes is true, the module will be executed without Hlo
|
||||
//
|
||||
// If run_hlo_passes is false, the module will be executed without Hlo
|
||||
// optimization.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<std::unique_ptr<Literal>> Execute(
|
||||
|
@ -33,10 +33,32 @@ namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
using tensorflow::str_util::Lowercase;
|
||||
|
||||
// Minimum supported CUDA compute capability is 3.5.
|
||||
constexpr int kMinCudaComputeCapabilityMajor = 3;
|
||||
constexpr int kMinCudaComputeCapabilityMinor = 5;
|
||||
|
||||
// The name of the interpreter platform.
|
||||
constexpr char kInterpreter[] = "interpreter";
|
||||
|
||||
namespace {
|
||||
|
||||
string CanonicalPlatformName(const string& name) {
|
||||
string platform_str = Lowercase(name);
|
||||
// "cpu" and "host" mean the same thing.
|
||||
if (platform_str == "cpu") {
|
||||
platform_str = "host";
|
||||
}
|
||||
// "gpu" and "cuda" mean the same thing.
|
||||
if (platform_str == "gpu") {
|
||||
platform_str = "cuda";
|
||||
}
|
||||
return platform_str;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ StatusOr<std::vector<se::Platform*>>
|
||||
PlatformUtil::GetSupportedPlatforms() {
|
||||
se::MultiPlatformManager::PlatformMap platform_map;
|
||||
@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() {
|
||||
return platforms;
|
||||
}
|
||||
|
||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetSolePlatform() {
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
||||
if (platforms.empty()) {
|
||||
return NotFound("no platforms found");
|
||||
@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() {
|
||||
}
|
||||
|
||||
// Multiple platforms present and we can't pick a reasonable default.
|
||||
auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); };
|
||||
string platforms_string = tensorflow::str_util::Join(platforms, ", ", l);
|
||||
string platforms_string = tensorflow::str_util::Join(
|
||||
platforms, ", ",
|
||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||
return InvalidArgument(
|
||||
"must specify platform because more than one platform found: %s",
|
||||
platforms_string.c_str());
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
||||
const string& platform_name) {
|
||||
using tensorflow::str_util::Lowercase;
|
||||
string platform_str = Lowercase(platform_name);
|
||||
// "cpu" and "host" mean the same thing.
|
||||
if (platform_str == "cpu") {
|
||||
platform_str = "host";
|
||||
}
|
||||
// "gpu" and "cuda" mean the same thing.
|
||||
if (platform_str == "gpu") {
|
||||
platform_str = "cuda";
|
||||
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
|
||||
if (platforms.empty()) {
|
||||
return NotFound("no platforms found");
|
||||
} else if (platforms.size() == 1) {
|
||||
return platforms[0];
|
||||
} else if (platforms.size() == 2) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
if (Lowercase(platforms[i]->Name()) == kInterpreter &&
|
||||
Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
|
||||
return platforms[1 - i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Multiple platforms present and we can't pick a reasonable default.
|
||||
string platforms_string = tensorflow::str_util::Join(
|
||||
platforms, ", ",
|
||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||
return InvalidArgument(
|
||||
"must specify platform because more than one platform (except for the "
|
||||
"interpreter platform) found: %s",
|
||||
platforms_string.c_str());
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
|
||||
const string& platform_name) {
|
||||
string platform_str = CanonicalPlatformName(platform_name);
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||
for (se::Platform* platform : platforms) {
|
||||
if (Lowercase(platform->Name()) == platform_str) {
|
||||
@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() {
|
||||
return InvalidArgument("platform %s not found", platform_name.c_str());
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
|
||||
const string& platform_name) {
|
||||
string platform_str = CanonicalPlatformName(platform_name);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
|
||||
std::vector<se::Platform*> matched;
|
||||
for (se::Platform* platform : platforms) {
|
||||
if (Lowercase(platform->Name()) != platform_name) {
|
||||
matched.push_back(platform);
|
||||
}
|
||||
}
|
||||
if (matched.empty()) {
|
||||
return InvalidArgument("unable to find platform that is not %s",
|
||||
platform_name.c_str());
|
||||
}
|
||||
if (matched.size() == 1) {
|
||||
return matched[0];
|
||||
}
|
||||
string matched_string = tensorflow::str_util::Join(
|
||||
matched, ", ",
|
||||
[](string* out, const se::Platform* p) { out->append(p->Name()); });
|
||||
return InvalidArgument(
|
||||
"found multiple platforms %s, but expected one platform except for %s",
|
||||
matched_string.c_str(), platform_name.c_str());
|
||||
}
|
||||
|
||||
// Returns whether the device underlying the given StreamExecutor is supported
|
||||
// by XLA.
|
||||
static bool IsDeviceSupported(se::StreamExecutor* executor) {
|
||||
|
@ -37,16 +37,28 @@ class PlatformUtil {
|
||||
static StatusOr<std::vector<perftools::gputools::Platform*>>
|
||||
GetSupportedPlatforms();
|
||||
|
||||
// Convenience function which returns the default supported platform. If
|
||||
// Convenience function which returns the default supported platform for
|
||||
// tests. If exactly one supported platform is present, then this platform is
|
||||
// the default platform. If exactly two platforms are present and one of them
|
||||
// is the interpreter platform, then the other platform is the default
|
||||
// platform. Otherwise returns an error.
|
||||
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
|
||||
|
||||
// Convenience function which returns the sole supported platform. If
|
||||
// exactly one supported platform is present, then this platform is the
|
||||
// default platform. Otherwise returns an error.
|
||||
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
|
||||
static StatusOr<perftools::gputools::Platform*> GetSolePlatform();
|
||||
|
||||
// Returns the platform according to the given name. Returns error if there is
|
||||
// no such platform.
|
||||
static StatusOr<perftools::gputools::Platform*> GetPlatform(
|
||||
const string& platform_name);
|
||||
|
||||
// Returns exactly one platform that does not have given name. Returns error
|
||||
// if there is no such platform, or there are multiple such platforms.
|
||||
static StatusOr<perftools::gputools::Platform*> GetPlatformExceptFor(
|
||||
const string& platform_name);
|
||||
|
||||
// Returns a vector of StreamExecutors for the given platform. The vector is
|
||||
// indexed by device ordinal (device numbering used by StreamExecutor). If an
|
||||
// element is nullptr, then the device is present by not supported by XLA.
|
||||
|
@ -105,7 +105,9 @@ cc_library(
|
||||
hdrs = ["hlo_test_base.h"],
|
||||
deps = [
|
||||
":literal_test_util",
|
||||
":test_utils",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -115,6 +117,9 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:computation_layout",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_runner",
|
||||
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:test",
|
||||
@ -1678,6 +1683,45 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
# A demo of textual IR based test.
|
||||
xla_test(
|
||||
name = "sample_text_test",
|
||||
srcs = ["sample_text_test.cc"],
|
||||
# You can leave this empty if you want to test all supported backends.
|
||||
backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# A demo of test that loads an hlo module from a file and compares results on gpu and cpu.
|
||||
tf_cc_test(
|
||||
name = "sample_file_test",
|
||||
srcs = ["sample_file_test.cc"],
|
||||
data = ["isolated_convolution.hlo"],
|
||||
tags = ["requires-gpu-sm35"],
|
||||
deps = [
|
||||
":hlo_test_base",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin", # reference backend
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin", # test backend
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
|
@ -15,13 +15,22 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -30,18 +39,72 @@ namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::StringPiece;
|
||||
using tensorflow::gtl::ArraySlice;
|
||||
using tensorflow::gtl::optional;
|
||||
|
||||
constexpr char kInterpreter[] = "interpreter";
|
||||
|
||||
// Helper functions to get test and reference platforms.
|
||||
se::Platform* GetReferencePlatform() {
|
||||
auto result = PlatformUtil::GetPlatform(kInterpreter);
|
||||
TF_CHECK_OK(result.status()) << "could not get interpreter platform";
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
|
||||
se::Platform* GetTestPlatform() {
|
||||
auto result = PlatformUtil::GetDefaultPlatform();
|
||||
TF_CHECK_OK(result.status()) << "could not get test platform";
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
|
||||
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
|
||||
if (lhs.parameters_size() != rhs.parameters_size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < lhs.parameters_size(); i++) {
|
||||
if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return ShapeUtil::Equal(lhs.result(), rhs.result());
|
||||
}
|
||||
|
||||
ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
||||
ProgramShape program_shape;
|
||||
const auto* entry = module.entry_computation();
|
||||
for (const auto* param : entry->parameter_instructions()) {
|
||||
*program_shape.add_parameters() = param->shape();
|
||||
*program_shape.add_parameter_names() = param->name();
|
||||
}
|
||||
*program_shape.mutable_result() = entry->root_instruction()->shape();
|
||||
return program_shape;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
HloTestBase::HloTestBase()
|
||||
: HloTestBase(GetTestPlatform(), GetReferencePlatform()) {}
|
||||
|
||||
HloTestBase::HloTestBase(se::Platform* test_platform,
|
||||
se::Platform* reference_platform)
|
||||
: test_runner_(test_platform), reference_runner_(reference_platform) {}
|
||||
|
||||
/* static */
|
||||
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(GetDebugOptionsForTest());
|
||||
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
|
||||
config);
|
||||
}
|
||||
|
||||
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
|
||||
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
|
||||
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
|
||||
debug_options.add_xla_disable_hlo_passes("constant_folding");
|
||||
|
||||
config.set_debug_options(debug_options);
|
||||
|
||||
return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
|
||||
config);
|
||||
return debug_options;
|
||||
}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
||||
@ -49,25 +112,168 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments,
|
||||
Shape* result_shape) {
|
||||
return runner_.Execute(std::move(module), arguments, result_shape);
|
||||
return test_runner_.Execute(std::move(module), arguments, result_shape);
|
||||
}
|
||||
|
||||
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
|
||||
return runner_.TransferToDevice(literal).ValueOrDie();
|
||||
return test_runner_.TransferToDevice(literal).ValueOrDie();
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
|
||||
const Shape& shape, se::DeviceMemoryBase device_base) {
|
||||
return runner_.TransferFromDevice(shape, device_base).ValueOrDie();
|
||||
return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie();
|
||||
}
|
||||
|
||||
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
|
||||
std::unique_ptr<HloModule> module,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
|
||||
return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie();
|
||||
return test_runner_.ExecuteAndTransfer(std::move(module), arguments)
|
||||
.ValueOrDie();
|
||||
}
|
||||
|
||||
Backend& HloTestBase::backend() { return runner_.backend(); }
|
||||
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
|
||||
const HloModule& test_module,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
std::unique_ptr<HloModule> reference_module = test_module.Clone();
|
||||
const auto& program_shape = GetProgramShapeWithLayout(test_module);
|
||||
|
||||
if (reference_preprocessor != nullptr) {
|
||||
reference_preprocessor(reference_module.get());
|
||||
if (!ProgramShapesEqual(program_shape,
|
||||
GetProgramShapeWithLayout(*reference_module))) {
|
||||
return InvalidArgument(
|
||||
"reference preprocessor must not modify the program shape");
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
|
||||
reference_module.get()));
|
||||
return std::move(reference_module);
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||
const optional<ErrorSpec>& error, bool run_hlo_passes,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
static_assert(
|
||||
std::is_same<Literal*, LiteralPtr>::value ||
|
||||
std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value,
|
||||
"The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>.");
|
||||
TF_RETURN_IF_ERROR(
|
||||
VerifyHloModule(*test_runner_.backend().platform(), module.get()));
|
||||
TF_ASSIGN_OR_RETURN(auto reference_module,
|
||||
MakeReferenceModule(*module, reference_preprocessor));
|
||||
|
||||
// Execute on two backends.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto test,
|
||||
test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
|
||||
TF_ASSIGN_OR_RETURN(auto reference,
|
||||
reference_runner_.Execute(std::move(reference_module),
|
||||
arguments, run_hlo_passes));
|
||||
return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
|
||||
error);
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||
const optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto result =
|
||||
RunAndCompareInternal(std::move(module), arguments, error,
|
||||
/*run_hlo_passes=*/true, reference_preprocessor);
|
||||
if (!result.ok()) {
|
||||
return ::testing::AssertionFailure() << result.status();
|
||||
}
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
|
||||
template <typename LiteralPtr>
|
||||
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||
std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
|
||||
const optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto result =
|
||||
RunAndCompareInternal(std::move(module), arguments, error,
|
||||
/*run_hlo_passes=*/false, reference_preprocessor);
|
||||
if (!result.ok()) {
|
||||
return ::testing::AssertionFailure() << result.status();
|
||||
}
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
const auto& fake_arguments =
|
||||
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||
return RunAndCompare<std::unique_ptr<Literal>>(
|
||||
std::move(module), fake_arguments, error, reference_preprocessor);
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
const auto& fake_arguments =
|
||||
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||
return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>(
|
||||
std::move(module), fake_arguments, error, reference_preprocessor);
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompare(
|
||||
const StringPiece hlo_string,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto module_or_status =
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||
if (!module_or_status.ok()) {
|
||||
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
|
||||
}
|
||||
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
|
||||
reference_preprocessor);
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
|
||||
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto module_or_status =
|
||||
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
|
||||
if (!module_or_status.ok()) {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "failed reading hlo module from file";
|
||||
}
|
||||
return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
|
||||
reference_preprocessor);
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
|
||||
const StringPiece hlo_string,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto module_or_status =
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||
if (!module_or_status.ok()) {
|
||||
return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
|
||||
}
|
||||
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
|
||||
reference_preprocessor);
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
|
||||
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor) {
|
||||
auto module_or_status =
|
||||
HloRunner::ReadModule(filename, GetDebugOptionsForTest());
|
||||
if (!module_or_status.ok()) {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "failed reading hlo module from file";
|
||||
}
|
||||
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
|
||||
reference_preprocessor);
|
||||
}
|
||||
|
||||
Backend& HloTestBase::backend() { return test_runner_.backend(); }
|
||||
|
||||
/* static */
|
||||
string HloTestBase::TestName() {
|
||||
|
@ -24,31 +24,74 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A base class for tests which build and run HLO code. This is a lower level of
|
||||
// abstraction than using the client interface and enables, for one, explicitly
|
||||
// building a graph of HLO instructions to run.
|
||||
// A base class for tests which build and/or run HLO code. The class includes
|
||||
// support for running an HLO module on two platforms and compare the results.
|
||||
// This is a lower level of abstraction than using the client interface and
|
||||
// enables, for one, explicitly building a graph of HLO instructions to run.
|
||||
//
|
||||
// This can also be used to write text/file-based test cases. Note that the test
|
||||
// target is responsible for linking the needed backends. A covenient way to do
|
||||
// this is to make it an xla_test: it will generate test targets linking with
|
||||
// the respective backends, which will be used as the test backend; the
|
||||
// interpreter backend is already linked with hlo_test_base so it will be the
|
||||
// default reference backend. For example, if you want to compare both cpu vs.
|
||||
// interpreter, and gpu vs. interpreter, you can:
|
||||
//
|
||||
// xla_test (
|
||||
// name = "sample_text_test",
|
||||
// srcs = ["sample_text_test.cc"],
|
||||
// backends = [
|
||||
// "cpu",
|
||||
// "gpu",
|
||||
// ],
|
||||
// deps = [
|
||||
// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
// ...
|
||||
// ],
|
||||
// )
|
||||
//
|
||||
// For a more detailed example, see "../tests/sample_text_test.cc".
|
||||
class HloTestBase : public ::testing::Test {
|
||||
protected:
|
||||
HloTestBase() {}
|
||||
// This uses the interpreter backend as the reference backend and
|
||||
// automatically finds another supported backend as the test backend. If the
|
||||
// interpreter is the only supported backend, it will be both the test backend
|
||||
// and the reference backend.
|
||||
HloTestBase();
|
||||
|
||||
// If your test doesn't use interpreter as the reference backend, you can use
|
||||
// this constructor. Note that your test target is responsible for linking in
|
||||
// both needed backends.
|
||||
HloTestBase(::perftools::gputools::Platform* test_platform,
|
||||
::perftools::gputools::Platform* reference_platform);
|
||||
|
||||
~HloTestBase() override {}
|
||||
|
||||
// Creates a new HLO module for a test. The module created will have
|
||||
// TestName() for its name; it will also automatically populate its debug
|
||||
// options from command-line flags. It's recommended to use this method to
|
||||
// create all HloModules for tests.
|
||||
// options from command-line flags. If you want a fresh HloModule object and
|
||||
// then add HloComputations to it, it's recommended to use this method in your
|
||||
// tests.
|
||||
static std::unique_ptr<HloModule> CreateNewModule();
|
||||
|
||||
// Populates debug options from command-line flags and adjusts the options for
|
||||
// testing. It is recommended to use this when you need to pass in
|
||||
// DebugOptions, e.g. when creating a module from a string or a file.
|
||||
static DebugOptions GetDebugOptionsForTest();
|
||||
|
||||
// Executes the given module and returns a global data handle.
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
|
||||
std::unique_ptr<HloModule> module,
|
||||
@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test {
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments);
|
||||
|
||||
// Executes the given hlo module on two backends and compares results.
|
||||
//
|
||||
// 'arguments': the input of the hlo module. The LiteralPtr type accepts
|
||||
// Literal* or std::unique_ptr<Literal>.
|
||||
//
|
||||
// 'error': if has value, expects the results to be near (within the error
|
||||
// bound). Otherwise, expects the results to be equal.
|
||||
//
|
||||
// 'reference_preprocessor': the module should be ready to run on the test
|
||||
// backend, but it might need to be tailored so that it is able to run on the
|
||||
// reference backend. Note that the program shape of the module must not be
|
||||
// modified.
|
||||
template <typename LiteralPtr>
|
||||
::testing::AssertionResult RunAndCompare(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
// Same as above, except that the module will be executed without Hlo
|
||||
// optimization.
|
||||
template <typename LiteralPtr>
|
||||
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
// Executes an hlo module with fake inputs and compares the results.
|
||||
::testing::AssertionResult RunAndCompare(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
// Same as above, except that the module will be executed without Hlo
|
||||
// optimization.
|
||||
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
// Convenient wrappers for executing and comparing an hlo module with fake
|
||||
// input. Module can be passed in directly, or parsed from an hlo_string,
|
||||
// or loaded from a file.
|
||||
::testing::AssertionResult RunAndCompare(
|
||||
const tensorflow::StringPiece hlo_string,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
::testing::AssertionResult RunAndCompareFromFile(
|
||||
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
::testing::AssertionResult RunAndCompareNoHloPasses(
|
||||
const tensorflow::StringPiece hlo_string,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
|
||||
const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
|
||||
TF_MUST_USE_RESULT;
|
||||
|
||||
// Convenience method to force the layout of a given parameter in a module.
|
||||
// The layout of parameter number 'param_no' in the 'module' is set to
|
||||
// 'layout'.
|
||||
@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test {
|
||||
|
||||
static string TestName();
|
||||
|
||||
// Returns the backend owned by the HloRunner.
|
||||
// Returns the backend owned by the test runner.
|
||||
Backend& backend();
|
||||
|
||||
HloRunner runner_;
|
||||
HloRunner test_runner_;
|
||||
HloRunner reference_runner_;
|
||||
|
||||
ErrorSpec error_spec_{0.0001};
|
||||
|
||||
private:
|
||||
// Given the test module, makes a reference module that is ready to run on the
|
||||
// reference platform. This assumes that the given module is ready to run on
|
||||
// the test platform.
|
||||
StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
|
||||
const HloModule& test_module,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor);
|
||||
|
||||
// Runs the module on two platforms with or without running hlo passes and
|
||||
// compares the results. Returns whether the results are near or equal. If any
|
||||
// error happens before the results are computed, returns the error status.
|
||||
template <typename LiteralPtr>
|
||||
StatusOr<::testing::AssertionResult> RunAndCompareInternal(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
|
||||
const std::function<void(HloModule*)>& reference_preprocessor);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
8
tensorflow/compiler/xla/tests/isolated_convolution.hlo
Normal file
8
tensorflow/compiler/xla/tests/isolated_convolution.hlo
Normal file
@ -0,0 +1,8 @@
|
||||
HloModule convolution.167:
|
||||
|
||||
ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] {
|
||||
%parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0)
|
||||
%parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1)
|
||||
ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f
|
||||
}
|
||||
|
@ -333,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
|
||||
return result;
|
||||
}
|
||||
|
||||
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
|
||||
const Literal& actual) {
|
||||
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple(
|
||||
const Literal& expected, const Literal& actual) {
|
||||
VLOG(1) << "expected: " << expected.ToString();
|
||||
VLOG(1) << "actual: " << actual.ToString();
|
||||
|
||||
ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
|
||||
ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
|
||||
if (!ShapeUtil::IsTuple(expected.shape()) ||
|
||||
!ShapeUtil::IsTuple(actual.shape())) {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "tuples expected shape = " << expected.shape().ShortDebugString()
|
||||
<< " actual shape = " << actual.shape().ShortDebugString();
|
||||
}
|
||||
AssertEqualShapes(expected.shape(), actual.shape());
|
||||
for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
|
||||
const auto& expected_element = expected.tuple_literals(i);
|
||||
const auto& actual_element = actual.tuple_literals(i);
|
||||
if (ShapeUtil::IsTuple(expected_element.shape())) {
|
||||
ExpectEqualTuple(expected_element, actual_element);
|
||||
auto ret = EqualTuple(expected_element, actual_element);
|
||||
if (!ret) {
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
ExpectEqual(expected_element, actual_element);
|
||||
return Equal(expected_element, actual_element);
|
||||
}
|
||||
}
|
||||
|
||||
return ::testing::AssertionSuccess();
|
||||
}
|
||||
|
||||
/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
|
||||
const Literal& actual) {
|
||||
EXPECT_TRUE(EqualTuple(expected, actual));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -615,8 +629,7 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
|
||||
if (!ShapeUtil::IsTuple(expected.shape()) ||
|
||||
!ShapeUtil::IsTuple(actual.shape())) {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "tuples expected expected shape = "
|
||||
<< expected.shape().ShortDebugString()
|
||||
<< "tuples expected shape = " << expected.shape().ShortDebugString()
|
||||
<< " actual shape = " << actual.shape().ShortDebugString();
|
||||
}
|
||||
AssertEqualShapes(expected.shape(), actual.shape());
|
||||
@ -650,6 +663,32 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
|
||||
EXPECT_TRUE(NearTuple(expected, actual, error));
|
||||
}
|
||||
|
||||
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
|
||||
const Literal& expected, const Literal& actual,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
||||
bool is_tuple = ShapeUtil::IsTuple(expected.shape());
|
||||
if (error.has_value()) {
|
||||
if (is_tuple) {
|
||||
VLOG(1) << "Expects near tuple";
|
||||
return NearTuple(expected, actual, *error);
|
||||
}
|
||||
VLOG(1) << "Expects near";
|
||||
return Near(expected, actual, *error);
|
||||
}
|
||||
if (is_tuple) {
|
||||
VLOG(1) << "Expects equal tuple";
|
||||
return EqualTuple(expected, actual);
|
||||
}
|
||||
VLOG(1) << "Expects equal";
|
||||
return Equal(expected, actual);
|
||||
}
|
||||
|
||||
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
|
||||
const Literal& expected, const Literal& actual,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error) {
|
||||
EXPECT_TRUE(NearOrEqual(expected, actual, error));
|
||||
}
|
||||
|
||||
/* static */ string LiteralTestUtil::MultiIndexAsString(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return tensorflow::strings::StrCat(
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -110,6 +111,10 @@ class LiteralTestUtil {
|
||||
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
|
||||
const Literal& actual);
|
||||
|
||||
// Returns whether the two tuples are equal.
|
||||
static ::testing::AssertionResult EqualTuple(
|
||||
const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
|
||||
|
||||
// Expects that the values of the elements in the expected and actual tuples
|
||||
// are equal. Tuples are matched recursively.
|
||||
static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
|
||||
@ -177,6 +182,19 @@ class LiteralTestUtil {
|
||||
static void ExpectNearTuple(const Literal& expected, const Literal& actual,
|
||||
const ErrorSpec& error);
|
||||
|
||||
// If the error spec is given, returns whether the expected and the actual are
|
||||
// within the error bound; otherwise, returns whether they are equal. Tuples
|
||||
// will be compared recursively.
|
||||
static ::testing::AssertionResult NearOrEqual(
|
||||
const Literal& expected, const Literal& actual,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
|
||||
|
||||
// If the error spec is given, expects the expected and the actual to be near;
|
||||
// otherwise, expects them to be equal. Tuples will be compared recursively.
|
||||
static void ExpectNearOrEqual(
|
||||
const Literal& expected, const Literal& actual,
|
||||
const tensorflow::gtl::optional<ErrorSpec>& error);
|
||||
|
||||
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
|
||||
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
|
||||
// dimension 1 equal to 8.
|
||||
|
51
tensorflow/compiler/xla/tests/sample_file_test.cc
Normal file
51
tensorflow/compiler/xla/tests/sample_file_test.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This demonstrates how to use hlo_test_base to create a file based testcase
|
||||
// and compare results on gpu and cpu.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class SampleFileTest : public HloTestBase {
|
||||
protected:
|
||||
SampleFileTest()
|
||||
: HloTestBase(
|
||||
/*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(),
|
||||
/*reference_platform=*/PlatformUtil::GetPlatform("cpu")
|
||||
.ValueOrDie()) {}
|
||||
};
|
||||
|
||||
TEST_F(SampleFileTest, Convolution) {
|
||||
const string& filename = "compiler/xla/tests/isolated_convolution.hlo";
|
||||
string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
|
||||
EXPECT_TRUE(RunAndCompareFromFile(
|
||||
tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
66
tensorflow/compiler/xla/tests/sample_text_test.cc
Normal file
66
tensorflow/compiler/xla/tests/sample_text_test.cc
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This demonstrates how to use hlo_test_base to create textual IR based
|
||||
// testcases.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using tensorflow::gtl::nullopt;
|
||||
|
||||
class SampleTextTest : public HloTestBase {};
|
||||
|
||||
TEST_F(SampleTextTest, Axpy) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule axpy_module:
|
||||
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
%alpha = f32[] parameter(0)
|
||||
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
|
||||
%x = f32[2,4]{1,0} parameter(1)
|
||||
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
|
||||
%y = f32[2,4]{1,0} parameter(2)
|
||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001}));
|
||||
}
|
||||
|
||||
TEST_F(SampleTextTest, Tuple) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule TupleCreate_module:
|
||||
ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
|
||||
%v1 = f32[] parameter(0)
|
||||
%v2 = f32[3]{0} parameter(1)
|
||||
%v3 = f32[2,3]{1,0} parameter(2)
|
||||
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompare(hlo_string, nullopt));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user