diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 6b6d48233a7..4a7caf3ebd8 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -39,6 +39,14 @@ namespace se = ::perftools::gputools;
 
 namespace xla {
 
+/*static*/ StatusOr<std::unique_ptr<HloModule>>
+HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
+                                  const DebugOptions& debug_options) {
+  HloModuleConfig config;
+  config.set_debug_options(debug_options);
+  return tools::Parse(hlo_string, config);
+}
+
 /*static*/ StatusOr<std::unique_ptr<HloModule>>
 HloRunner::ReadModuleFromHloProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 95cddafc91f..a65c66fd4b6 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -35,7 +35,8 @@ namespace xla {
 
 // A base class for running an HloModule. This executes the given HloModule on a
 // certain backend directly without using the client interface. HloModule can be
-// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
+// explicitly built, or loaded from a serialization file (e.g., hlo proto
+// file), or parsed from an HLO textual IR string.
 class HloRunner {
  public:
   HloRunner();
@@ -44,6 +45,12 @@ class HloRunner {
 
   ~HloRunner();
 
+  // Creates an HloModule from the given HLO textual IR string (in
+  // HloModule::ToString format).
+  static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
+      const tensorflow::StringPiece hlo_string,
+      const DebugOptions& debug_options);
+
   // Reads the proto file in xla.HloProto format, creates and returns the
   // HloModule. Will try to parse the filename as binary proto, then try as
   // text proto if that fails.
@@ -65,7 +72,8 @@ class HloRunner {
   // Executes the given module with given literals as input and returns the
   // result as a Literal. The LiteralPtr type accepts Literal* or
   // std::unique_ptr<Literal>.
-  // If run_hlo_passes is true, the module will be executed without Hlo
+  //
+  // If run_hlo_passes is false, the module will be executed without Hlo
   // optimization.
   template <typename LiteralPtr>
   StatusOr<std::unique_ptr<Literal>> Execute(
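For illustration, a minimal usage sketch of the new entry point (hypothetical helper, not part of this change). It assumes MakeFakeArguments from tests/test_utils.h and the flag-derived DebugOptions used elsewhere in this patch; error handling is collapsed into ConsumeValueOrDie():

    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
    #include "tensorflow/compiler/xla/service/hlo_runner.h"
    #include "tensorflow/compiler/xla/tests/test_utils.h"

    // Parses HLO textual IR, feeds it fake inputs, and runs it with HLO
    // passes enabled on the runner's default backend.
    std::unique_ptr<xla::Literal> RunHloText(const std::string& hlo_text) {
      std::unique_ptr<xla::HloModule> module =
          xla::HloRunner::CreateModuleFromString(
              hlo_text, xla::legacy_flags::GetDebugOptionsFromFlags())
              .ConsumeValueOrDie();
      // One fake literal per entry-computation parameter.
      std::vector<std::unique_ptr<xla::Literal>> args =
          xla::MakeFakeArguments(module.get()).ConsumeValueOrDie();
      xla::HloRunner runner;
      return runner.Execute(std::move(module), args, /*run_hlo_passes=*/true)
          .ConsumeValueOrDie();
    }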
+ if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } -/*static*/ StatusOr PlatformUtil::GetPlatform( - const string& platform_name) { - using tensorflow::str_util::Lowercase; - string platform_str = Lowercase(platform_name); - // "cpu" and "host" mean the same thing. - if (platform_str == "cpu") { - platform_str = "host"; - } - // "gpu" and "cuda" mean the same thing. - if (platform_str == "gpu") { - platform_str = "cuda"; +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } } + // Multiple platforms present and we can't pick a reasonable default. 
+  string platforms_string = tensorflow::str_util::Join(
+      platforms, ", ",
+      [](string* out, const se::Platform* p) { out->append(p->Name()); });
+  return InvalidArgument(
+      "must specify platform because more than one platform (except for the "
+      "interpreter platform) found: %s",
+      platforms_string.c_str());
+}
+
+/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
+    const string& platform_name) {
+  string platform_str = CanonicalPlatformName(platform_name);
   TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
   for (se::Platform* platform : platforms) {
     if (Lowercase(platform->Name()) == platform_str) {
@@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() {
   return InvalidArgument("platform %s not found", platform_name.c_str());
 }
 
+/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
+    const string& platform_name) {
+  string platform_str = CanonicalPlatformName(platform_name);
+
+  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
+  std::vector<se::Platform*> matched;
+  for (se::Platform* platform : platforms) {
+    if (Lowercase(platform->Name()) != platform_str) {
+      matched.push_back(platform);
+    }
+  }
+  if (matched.empty()) {
+    return InvalidArgument("unable to find platform that is not %s",
+                           platform_name.c_str());
+  }
+  if (matched.size() == 1) {
+    return matched[0];
+  }
+  string matched_string = tensorflow::str_util::Join(
+      matched, ", ",
+      [](string* out, const se::Platform* p) { out->append(p->Name()); });
+  return InvalidArgument(
+      "found multiple platforms %s, but expected one platform except for %s",
+      matched_string.c_str(), platform_name.c_str());
+}
+
 // Returns whether the device underlying the given StreamExecutor is supported
 // by XLA.
 static bool IsDeviceSupported(se::StreamExecutor* executor) {
diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h
index a59d4ffe87f..69188820a70 100644
--- a/tensorflow/compiler/xla/service/platform_util.h
+++ b/tensorflow/compiler/xla/service/platform_util.h
@@ -37,16 +37,28 @@ class PlatformUtil {
   static StatusOr<std::vector<perftools::gputools::Platform*>>
   GetSupportedPlatforms();
 
-  // Convenience function which returns the default supported platform. If
+  // Convenience function which returns the default supported platform for
+  // tests. If exactly one supported platform is present, then this platform is
+  // the default platform. If exactly two platforms are present and one of them
+  // is the interpreter platform, then the other platform is the default
+  // platform. Otherwise returns an error.
+  static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
+
+  // Convenience function which returns the sole supported platform. If
   // exactly one supported platform is present, then this platform is the
   // default platform. Otherwise returns an error.
-  static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
+  static StatusOr<perftools::gputools::Platform*> GetSolePlatform();
 
   // Returns the platform according to the given name. Returns error if there is
   // no such platform.
   static StatusOr<perftools::gputools::Platform*> GetPlatform(
       const string& platform_name);
 
+  // Returns exactly one platform that does not have the given name. Returns
+  // error if there is no such platform, or there are multiple such platforms.
+  static StatusOr<perftools::gputools::Platform*> GetPlatformExceptFor(
+      const string& platform_name);
+
   // Returns a vector of StreamExecutors for the given platform. The vector is
   // indexed by device ordinal (device numbering used by StreamExecutor). If an
   // element is nullptr, then the device is present but not supported by XLA.
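For illustration, here is how the new helpers compose (hypothetical snippet, not part of this change): a harness can pin the reference backend to the interpreter and take whatever single other backend is linked in as the test backend.

    #include <utility>

    #include "tensorflow/compiler/xla/service/platform_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/compiler/xla/statusor.h"

    namespace se = ::perftools::gputools;

    // Returns {test platform, reference platform}. GetPlatformExceptFor()
    // fails if zero, or more than one, non-interpreter platform is linked in.
    xla::StatusOr<std::pair<se::Platform*, se::Platform*>> PickPlatforms() {
      TF_ASSIGN_OR_RETURN(se::Platform* reference,
                          xla::PlatformUtil::GetPlatform("interpreter"));
      TF_ASSIGN_OR_RETURN(
          se::Platform* test,
          xla::PlatformUtil::GetPlatformExceptFor("interpreter"));
      return std::make_pair(test, reference);
    }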
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b99e046b9bc..24f4a9d05ad 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -105,7 +105,9 @@ cc_library(
     hdrs = ["hlo_test_base.h"],
     deps = [
         ":literal_test_util",
+        ":test_utils",
         "//tensorflow/compiler/xla:shape_layout",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -115,6 +117,9 @@ cc_library(
         "//tensorflow/compiler/xla/service:computation_layout",
         "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_runner",
+        "//tensorflow/compiler/xla/service:interpreter_plugin",  # reference backend
+        "//tensorflow/compiler/xla/service:platform_util",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core:test",
@@ -1678,6 +1683,45 @@ xla_test(
     ],
 )
 
+# A demo of a textual-IR-based test.
+xla_test(
+    name = "sample_text_test",
+    srcs = ["sample_text_test.cc"],
+    # You can leave this empty if you want to test all supported backends.
+    backends = [
+        "cpu",
+        "gpu",
+    ],
+    deps = [
+        ":hlo_test_base",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+    ],
+)
+
+# A demo of a test that loads an HLO module from a file and compares results
+# on GPU and CPU.
+tf_cc_test(
+    name = "sample_file_test",
+    srcs = ["sample_file_test.cc"],
+    data = ["isolated_convolution.hlo"],
+    tags = ["requires-gpu-sm35"],
+    deps = [
+        ":hlo_test_base",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/service:cpu_plugin",  # reference backend
+        "//tensorflow/compiler/xla/service:gpu_plugin",  # test backend
+        "//tensorflow/compiler/xla/service:platform_util",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
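Side note: xla_test expands into one test target per listed backend. Assuming the macro's usual <name>_<backend> naming convention (an assumption; check the generated targets), the demos above would be invoked roughly as:

    bazel test //tensorflow/compiler/xla/tests:sample_text_test_cpu
    bazel test //tensorflow/compiler/xla/tests:sample_text_test_gpu
    bazel test //tensorflow/compiler/xla/tests:sample_file_test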
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,18 +39,72 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) {} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
   debug_options.add_xla_disable_hlo_passes("constant_folding");
-
-  config.set_debug_options(debug_options);
-
-  return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
-                               config);
+  return debug_options;
 }
 
 StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
@@ -49,25 +112,168 @@ StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
     Shape* result_shape) {
-  return runner_.Execute(std::move(module), arguments, result_shape);
+  return test_runner_.Execute(std::move(module), arguments, result_shape);
 }
 
 se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
-  return runner_.TransferToDevice(literal).ValueOrDie();
+  return test_runner_.TransferToDevice(literal).ValueOrDie();
 }
 
 std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
     const Shape& shape, se::DeviceMemoryBase device_base) {
-  return runner_.TransferFromDevice(shape, device_base).ValueOrDie();
+  return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie();
 }
 
 std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
     std::unique_ptr<HloModule> module,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
-  return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie();
+  return test_runner_.ExecuteAndTransfer(std::move(module), arguments)
+      .ValueOrDie();
 }
 
-Backend& HloTestBase::backend() { return runner_.backend(); }
+StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
+    const HloModule& test_module,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  std::unique_ptr<HloModule> reference_module = test_module.Clone();
+  const auto& program_shape = GetProgramShapeWithLayout(test_module);
+
+  if (reference_preprocessor != nullptr) {
+    reference_preprocessor(reference_module.get());
+    if (!ProgramShapesEqual(program_shape,
+                            GetProgramShapeWithLayout(*reference_module))) {
+      return InvalidArgument(
+          "reference preprocessor must not modify the program shape");
+    }
+  }
+  TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
+                                     reference_module.get()));
+  return std::move(reference_module);
+}
+
+template <typename LiteralPtr>
+StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
+    std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
+    const optional<ErrorSpec>& error, bool run_hlo_passes,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  static_assert(
+      std::is_same<Literal*, LiteralPtr>::value ||
+          std::is_same<std::unique_ptr<Literal>, LiteralPtr>::value,
+      "The LiteralPtr type only accepts Literal* or std::unique_ptr<Literal>.");
+  TF_RETURN_IF_ERROR(
+      VerifyHloModule(*test_runner_.backend().platform(), module.get()));
+  TF_ASSIGN_OR_RETURN(auto reference_module,
+                      MakeReferenceModule(*module, reference_preprocessor));
+
+  // Execute on two backends.
+  TF_ASSIGN_OR_RETURN(
+      auto test,
+      test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
+  TF_ASSIGN_OR_RETURN(auto reference,
+                      reference_runner_.Execute(std::move(reference_module),
+                                                arguments, run_hlo_passes));
+  return LiteralTestUtil::NearOrEqual(/*expected=*/*reference,
+                                      /*actual=*/*test, error);
+}
+
+template <typename LiteralPtr>
+::testing::AssertionResult HloTestBase::RunAndCompare(
+    std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
+    const optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto result =
+      RunAndCompareInternal(std::move(module), arguments, error,
+                            /*run_hlo_passes=*/true, reference_preprocessor);
+  if (!result.ok()) {
+    return ::testing::AssertionFailure() << result.status();
+  }
+  return result.ValueOrDie();
+}
+
+template <typename LiteralPtr>
+::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
+    std::unique_ptr<HloModule> module, const ArraySlice<LiteralPtr> arguments,
+    const optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto result =
+      RunAndCompareInternal(std::move(module), arguments, error,
+                            /*run_hlo_passes=*/false, reference_preprocessor);
+  if (!result.ok()) {
+    return ::testing::AssertionFailure() << result.status();
+  }
+  return result.ValueOrDie();
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompare(
+    std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  const auto& fake_arguments =
+      MakeFakeArguments(module.get()).ConsumeValueOrDie();
+  return RunAndCompare<std::unique_ptr<Literal>>(
+      std::move(module), fake_arguments, error, reference_preprocessor);
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
+    std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  const auto& fake_arguments =
+      MakeFakeArguments(module.get()).ConsumeValueOrDie();
+  return RunAndCompareNoHloPasses<std::unique_ptr<Literal>>(
+      std::move(module), fake_arguments, error, reference_preprocessor);
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompare(
+    const StringPiece hlo_string,
+    const tensorflow::gtl::optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto module_or_status =
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
+  if (!module_or_status.ok()) {
+    return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
+  }
+  return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
+                       reference_preprocessor);
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
+    const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto module_or_status =
+      HloRunner::ReadModule(filename, GetDebugOptionsForTest());
+  if (!module_or_status.ok()) {
+    return ::testing::AssertionFailure()
+           << "failed reading hlo module from file";
+  }
+  return RunAndCompare(module_or_status.ConsumeValueOrDie(), error,
+                       reference_preprocessor);
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
+    const StringPiece hlo_string,
+    const tensorflow::gtl::optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto module_or_status =
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
+  if (!module_or_status.ok()) {
+    return ::testing::AssertionFailure() << "failed parsing hlo textual IR";
+  }
+  return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
+                                  reference_preprocessor);
+}
+
+::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
+    const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+    const std::function<void(HloModule*)>& reference_preprocessor) {
+  auto module_or_status =
+      HloRunner::ReadModule(filename, GetDebugOptionsForTest());
+  if (!module_or_status.ok()) {
+    return ::testing::AssertionFailure()
+           << "failed reading hlo module from file";
+  }
+  return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error,
+                                  reference_preprocessor);
+}
+
+Backend& HloTestBase::backend() { return test_runner_.backend(); }
 
 /* static */ string HloTestBase::TestName() {
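To illustrate the reference_preprocessor hook threaded through the functions above, here is a hypothetical test (fixture name and HLO text are ours, not part of this change). The callback receives the cloned reference module and may rewrite it, but MakeReferenceModule rejects any change to the program shape:

    #include <string>

    #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
    #include "tensorflow/core/lib/gtl/optional.h"

    namespace xla {
    namespace {

    class PreprocessorDemoTest : public HloTestBase {};

    TEST_F(PreprocessorDemoTest, IdentityMap) {
      const std::string hlo_string = R"(
    HloModule demo_module:

    ENTRY %demo.v1 (p: f32[4]) -> f32[4] {
      ROOT %p = f32[4]{0} parameter(0)
    }
    )";
      // The callback only logs the reference module here; a real preprocessor
      // might strip features the reference backend cannot execute.
      EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt,
                                [](HloModule* reference_module) {
                                  VLOG(1) << reference_module->ToString();
                                }));
    }

    }  // namespace
    }  // namespace xla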
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 7f068dce36b..3cbbb7aa247 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -24,31 +24,74 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
 
-// A base class for tests which build and run HLO code. This is a lower level of
-// abstraction than using the client interface and enables, for one, explicitly
-// building a graph of HLO instructions to run.
+// A base class for tests which build and/or run HLO code. The class includes
+// support for running an HLO module on two platforms and comparing the
+// results. This is a lower level of abstraction than using the client
+// interface and enables, for one, explicitly building a graph of HLO
+// instructions to run.
+//
+// This can also be used to write text/file-based test cases. Note that the
+// test target is responsible for linking the needed backends. A convenient
+// way to do this is to make it an xla_test: it will generate test targets
+// linking with the respective backends, which will be used as the test
+// backend; the interpreter backend is already linked with hlo_test_base so it
+// will be the default reference backend. For example, if you want to compare
+// both cpu vs. interpreter, and gpu vs. interpreter, you can:
+//
+//  xla_test(
+//    name = "sample_text_test",
+//    srcs = ["sample_text_test.cc"],
+//    backends = [
+//      "cpu",
+//      "gpu",
+//    ],
+//    deps = [
+//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
+//      ...
+//    ],
+//  )
+//
+// For a more detailed example, see "../tests/sample_text_test.cc".
 class HloTestBase : public ::testing::Test {
  protected:
-  HloTestBase() {}
+  // This uses the interpreter backend as the reference backend and
+  // automatically finds another supported backend as the test backend. If the
+  // interpreter is the only supported backend, it will be both the test
+  // backend and the reference backend.
+  HloTestBase();
+
+  // If your test doesn't use the interpreter as the reference backend, you
+  // can use this constructor. Note that your test target is responsible for
+  // linking in both needed backends.
+  HloTestBase(::perftools::gputools::Platform* test_platform,
+              ::perftools::gputools::Platform* reference_platform);
 
   ~HloTestBase() override {}
 
   // Creates a new HLO module for a test. The module created will have
   // TestName() for its name; it will also automatically populate its debug
-  // options from command-line flags. It's recommended to use this method to
-  // create all HloModules for tests.
+  // options from command-line flags. If you want a fresh HloModule object and
+  // then add HloComputations to it, it's recommended to use this method in
+  // your tests.
   static std::unique_ptr<HloModule> CreateNewModule();
 
+  // Populates debug options from command-line flags and adjusts the options
+  // for testing. It is recommended to use this when you need to pass in
+  // DebugOptions, e.g. when creating a module from a string or a file.
+  static DebugOptions GetDebugOptionsForTest();
+
   // Executes the given module and returns a global data handle.
   StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
       std::unique_ptr<HloModule> module,
@@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test {
       tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
          arguments);
 
+  // Executes the given HLO module on two backends and compares results.
+  //
+  // 'arguments': the input of the HLO module. The LiteralPtr type accepts
+  // Literal* or std::unique_ptr<Literal>.
+  //
+  // 'error': if it has a value, expects the results to be near (within the
+  // error bound). Otherwise, expects the results to be equal.
+  //
+  // 'reference_preprocessor': the module should be ready to run on the test
+  // backend, but it might need to be tailored so that it is able to run on
+  // the reference backend. Note that the program shape of the module must not
+  // be modified.
+  template <typename LiteralPtr>
+  ::testing::AssertionResult RunAndCompare(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+
+  // Same as above, except that the module will be executed without Hlo
+  // optimization.
+  template <typename LiteralPtr>
+  ::testing::AssertionResult RunAndCompareNoHloPasses(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+
+  // Executes an HLO module with fake inputs and compares the results.
+  ::testing::AssertionResult RunAndCompare(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+
+  // Same as above, except that the module will be executed without Hlo
+  // optimization.
+  ::testing::AssertionResult RunAndCompareNoHloPasses(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+
+  // Convenient wrappers for executing and comparing an HLO module with fake
+  // input. The module can be passed in directly, parsed from an hlo_string,
+  // or loaded from a file.
+  ::testing::AssertionResult RunAndCompare(
+      const tensorflow::StringPiece hlo_string,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+  ::testing::AssertionResult RunAndCompareFromFile(
+      const string& filename,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+  ::testing::AssertionResult RunAndCompareNoHloPasses(
+      const tensorflow::StringPiece hlo_string,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+  ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
+      const string& filename,
+      const tensorflow::gtl::optional<ErrorSpec>& error,
+      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
+      TF_MUST_USE_RESULT;
+
   // Convenience method to force the layout of a given parameter in a module.
   // The layout of parameter number 'param_no' in the 'module' is set to
   // 'layout'.
@@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test {
 
   static string TestName();
 
-  // Returns the backend owned by the HloRunner.
+  // Returns the backend owned by the test runner.
   Backend& backend();
 
-  HloRunner runner_;
+  HloRunner test_runner_;
+  HloRunner reference_runner_;
 
   ErrorSpec error_spec_{0.0001};
+
+ private:
+  // Given the test module, makes a reference module that is ready to run on
+  // the reference platform. This assumes that the given module is ready to
+  // run on the test platform.
+  StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
+      const HloModule& test_module,
+      const std::function<void(HloModule*)>& reference_preprocessor);
+
+  // Runs the module on two platforms with or without running hlo passes and
+  // compares the results. Returns whether the results are near or equal. If
+  // any error happens before the results are computed, returns the error
+  // status.
+  template <typename LiteralPtr>
+  StatusOr<::testing::AssertionResult> RunAndCompareInternal(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::ArraySlice<LiteralPtr> arguments,
+      const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
+      const std::function<void(HloModule*)>& reference_preprocessor);
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo
new file mode 100644
index 00000000000..9452780930e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo
@@ -0,0 +1,8 @@
+HloModule convolution.167:
+
+ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] {
+  %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0)
+  %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1)
+  ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f
+}
+
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index e1a948c096a..bf6631a4310 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -333,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
   return result;
 }
 
-/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
-                                                    const Literal& actual) {
+/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple(
+    const Literal& expected, const Literal& actual) {
   VLOG(1) << "expected: " << expected.ToString();
   VLOG(1) << "actual: " << actual.ToString();
 
-  ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
-  ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
+  if (!ShapeUtil::IsTuple(expected.shape()) ||
+      !ShapeUtil::IsTuple(actual.shape())) {
+    return ::testing::AssertionFailure()
+           << "tuples expected shape = " << expected.shape().ShortDebugString()
+           << " actual shape = " << actual.shape().ShortDebugString();
+  }
   AssertEqualShapes(expected.shape(), actual.shape());
   for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
     const auto& expected_element = expected.tuple_literals(i);
     const auto& actual_element = actual.tuple_literals(i);
     if (ShapeUtil::IsTuple(expected_element.shape())) {
-      ExpectEqualTuple(expected_element, actual_element);
+      auto ret = EqualTuple(expected_element, actual_element);
+      if (!ret) {
+        return ret;
+      }
     } else {
-      ExpectEqual(expected_element, actual_element);
+      auto ret = Equal(expected_element, actual_element);
+      if (!ret) {
+        return ret;
+      }
     }
   }
+
+  return ::testing::AssertionSuccess();
+}
+
+/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
+                                                    const Literal& actual) {
+  EXPECT_TRUE(EqualTuple(expected, actual));
 }
 
 namespace {
@@ -615,8 +629,7 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected,
   if (!ShapeUtil::IsTuple(expected.shape()) ||
       !ShapeUtil::IsTuple(actual.shape())) {
     return ::testing::AssertionFailure()
-           << "tuples expected expected shape = "
-           << expected.shape().ShortDebugString()
+           << "tuples expected shape = " << expected.shape().ShortDebugString()
            << " actual shape = " << actual.shape().ShortDebugString();
   }
   AssertEqualShapes(expected.shape(), actual.shape());
@@ -650,6 +663,32 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected,
   EXPECT_TRUE(NearTuple(expected, actual, error));
 }
 
+/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
+    const Literal& expected, const Literal& actual,
+    const tensorflow::gtl::optional<ErrorSpec>& error) {
+  bool is_tuple = ShapeUtil::IsTuple(expected.shape());
+  if (error.has_value()) {
+    if (is_tuple) {
+      VLOG(1) << "Expects near tuple";
+      return NearTuple(expected, actual, *error);
+    }
+    VLOG(1) << "Expects near";
+    return Near(expected, actual, *error);
+  }
+  if (is_tuple) {
+    VLOG(1) << "Expects equal tuple";
+    return EqualTuple(expected, actual);
+  }
+  VLOG(1) << "Expects equal";
+  return Equal(expected, actual);
+}
+
+/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
+    const Literal& expected, const Literal& actual,
+    const tensorflow::gtl::optional<ErrorSpec>& error) {
+  EXPECT_TRUE(NearOrEqual(expected, actual, error));
+}
+
 /* static */ string LiteralTestUtil::MultiIndexAsString(
     tensorflow::gtl::ArraySlice<int64> multi_index) {
   return tensorflow::strings::StrCat(
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index bf8c92f16dd..f53553c7017 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
@@ -110,6 +111,10 @@ class LiteralTestUtil {
   static void ExpectR4EqualArray4D(const Array4D<float>& expected,
                                    const Literal& actual);
 
+  // Returns whether the two tuples are equal.
+  static ::testing::AssertionResult EqualTuple(
+      const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+
   // Expects that the values of the elements in the expected and actual tuples
   // are equal. Tuples are matched recursively.
   static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
@@ -177,6 +182,19 @@ class LiteralTestUtil {
   static void ExpectNearTuple(const Literal& expected, const Literal& actual,
                               const ErrorSpec& error);
 
+  // If the error spec is given, returns whether the expected and the actual
+  // are within the error bound; otherwise, returns whether they are equal.
+  // Tuples will be compared recursively.
+  static ::testing::AssertionResult NearOrEqual(
+      const Literal& expected, const Literal& actual,
+      const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
+
+  // If the error spec is given, expects the expected and the actual to be
+  // near; otherwise, expects them to be equal. Tuples will be compared
+  // recursively.
+  static void ExpectNearOrEqual(
+      const Literal& expected, const Literal& actual,
+      const tensorflow::gtl::optional<ErrorSpec>& error);
+
   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
   // be returned for a 2-dimensional index with dimension 0 index equal to 7,
   // dimension 1 equal to 8.
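A small illustration of the new dispatch (hypothetical helper; the ErrorSpec value is only an example):

    #include "tensorflow/compiler/xla/tests/literal_test_util.h"
    #include "tensorflow/core/lib/gtl/optional.h"

    // Near-comparison when a tolerance is supplied, exact equality otherwise;
    // tuples recurse either way.
    ::testing::AssertionResult CheckLiterals(const xla::Literal& expected,
                                             const xla::Literal& actual,
                                             bool approximate) {
      tensorflow::gtl::optional<xla::ErrorSpec> error;
      if (approximate) {
        error = xla::ErrorSpec{0.001};  // example tolerance
      }
      return xla::LiteralTestUtil::NearOrEqual(expected, actual, error);
    }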
diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc
new file mode 100644
index 00000000000..31b104f4e37
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/sample_file_test.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This demonstrates how to use hlo_test_base to create a file-based test case
+// and compare results on GPU and CPU.
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SampleFileTest : public HloTestBase {
+ protected:
+  SampleFileTest()
+      : HloTestBase(
+            /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(),
+            /*reference_platform=*/PlatformUtil::GetPlatform("cpu")
+                .ValueOrDie()) {}
+};
+
+TEST_F(SampleFileTest, Convolution) {
+  const string& filename = "compiler/xla/tests/isolated_convolution.hlo";
+  string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
+  EXPECT_TRUE(RunAndCompareFromFile(
+      tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01}));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc
new file mode 100644
index 00000000000..b4f2b74e3dc
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/sample_text_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This demonstrates how to use hlo_test_base to create textual-IR-based
+// test cases.
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class SampleTextTest : public HloTestBase {};
+
+TEST_F(SampleTextTest, Axpy) {
+  const string& hlo_string = R"(
+HloModule axpy_module:
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %x = f32[2,4]{1,0} parameter(1)
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  %y = f32[2,4]{1,0} parameter(2)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001}));
+}
+
+TEST_F(SampleTextTest, Tuple) {
+  const string& hlo_string = R"(
+HloModule TupleCreate_module:
+
+ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
+  %v1 = f32[] parameter(0)
+  %v2 = f32[3]{0} parameter(1)
+  %v3 = f32[2,3]{1,0} parameter(2)
+  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
+}
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_string, nullopt));
+}
+
+}  // namespace
+}  // namespace xla
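One last note: the text format these tests consume is just HloModule::ToString output (as documented on CreateModuleFromString), so snapshots like isolated_convolution.hlo can be produced with a helper as small as this hypothetical one:

    #include <fstream>
    #include <string>

    #include "tensorflow/compiler/xla/service/hlo_module.h"

    // Hypothetical helper: dump a module in the textual IR consumed by
    // CreateModuleFromString()/RunAndCompareFromFile().
    void DumpHloText(const xla::HloModule& module, const std::string& path) {
      std::ofstream out(path);
      out << module.ToString();
    }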