diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f1fa2ce3a52..edbc0078869 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4459,6 +4459,30 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_runner_interface", + srcs = ["hlo_runner_interface.cc"], + hdrs = ["hlo_runner_interface.h"], + deps = [ + ":compiler", + ":computation_placer", + ":hlo", + ":hlo_module_group", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "hlo_runner", srcs = ["hlo_runner.cc"], @@ -4471,6 +4495,7 @@ cc_library( ":hlo", ":hlo_module_group", ":hlo_parser", + ":hlo_runner_interface", ":transfer_manager", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 0d71c6d49ed..86ff41ba273 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -34,58 +34,6 @@ limitations under the License. namespace xla { -/*static*/ StatusOr> -HloRunner::CreateModuleFromString(const absl::string_view hlo_string, - const DebugOptions& debug_options) { - HloModuleConfig config; - config.set_debug_options(debug_options); - return ParseAndReturnUnverifiedModule(hlo_string, config); -} - -namespace { - -// Creates an HloModule from the given proto. 
-StatusOr> HloProtoToModule( - const HloProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN(HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto.hlo_module(), - debug_options)); - TF_ASSIGN_OR_RETURN(auto module, - HloModule::CreateFromProto(proto.hlo_module(), config)); - return std::move(module); -} - -} // namespace - -/*static*/ StatusOr> -HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename, - const DebugOptions& debug_options) { - HloProto proto; - TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), - filename, &proto)); - return HloProtoToModule(proto, debug_options); -} - -/*static*/ StatusOr> -HloRunner::ReadModuleFromTextProtoFile(const std::string& filename, - const DebugOptions& debug_options) { - HloProto proto; - TF_RETURN_IF_ERROR( - tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto)); - return HloProtoToModule(proto, debug_options); -} - -/*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextFile(const std::string& filename, - const DebugOptions& debug_options) { - string hlo_string; - TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), - filename, &hlo_string)); - HloModuleConfig config; - config.set_debug_options(debug_options); - return ParseAndReturnUnverifiedModule(hlo_string, config); -} - HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) { BackendOptions backend_options; backend_options.set_platform(platform); @@ -155,26 +103,9 @@ StatusOr HloRunner::Execute(std::unique_ptr module, return TransferLiteralFromDevice(result.Result()); } -StatusOr HloRunner::Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes, - ExecutionProfile* profile) { - // Construct a vector of plain pointers for the arguments. 
- std::vector argument_pointers; - argument_pointers.reserve(arguments.size()); - for (const auto& argument : arguments) { - argument_pointers.push_back(&argument); - } - return Execute( - /*module=*/std::move(module), - /*arguments=*/argument_pointers, - /*run_hlo_passes=*/run_hlo_passes, - /*profile=*/profile); -} - -StatusOr HloRunner::Execute(std::unique_ptr executable, - absl::Span arguments, - ExecutionProfile* profile) { +StatusOr HloRunner::ExecuteWithExecutable( + std::unique_ptr executable, absl::Span arguments, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ExecutionOutput result, diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 733bb8bff54..721a50232cc 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner_interface.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -42,48 +43,8 @@ namespace xla { // certain backend directly without using the client interface. HloModule can be // explicitly built, or loaded from a serialization file (e.g., hlo proto // file), or parsed from a hlo textual IR string. -class HloRunner { +class HloRunner : public HloRunnerInterface { public: - // The options used to configure a ExecuteReplicated() call. - struct ReplicatedExecuteOptions { - // The number of devices the HLO module should be replicated onto. - int64 num_replicas = 1; - - // The arguments to be fed to each replica. 
Since this is used for a - // replicated execution, all the arguments are the same for all replicas. - std::vector arguments; - - // If the HLO module being run has an infeed instruction, this will be the - // data which will be fed to it, for as many as infeed_steps steps. - const Literal* infeed = nullptr; - - // The number of times the infeed literal should be fed to the HLO module. - // For a clean exit, this should match the iterations-per-loop parameter - // used when generating the HLO module proto (that is usually the main - // while boundary counter). A value higher then iterations-per-loop would - // lead to infeed threads feeding to a gone computation, while a lower - // value would trigger a stuck ExecuteReplicated() call (the computation - // will be trying to infeed data which will never come). - int64 infeed_steps = -1; - - // The shape of the outfeed operation. If empty, the HLO module does not - // generate any outfeed. - Shape outfeed_shape; - - // A pointer to a vector where the outfeed values will be stored. If - // nullptr, the values will be read and discarded. - std::vector* outfeed_values = nullptr; - - // Whether the HLO passes should be run on the input module. Usually - // saved modules are coming from after the HLO pass pipeline, so triggering - // another run will likely cause errors. - bool run_hlo_passes = false; - - // If true, executes on multiple threads using se::Stream::ExecuteOnStream. - // Otherwise, executes using xla::Executable::ExecuteOnStreams. - bool use_threads = false; - }; - // intra_op_parallelism_threads: For the CPU backend only. It is the thread // pool size for parallel execution of an individual operator. 
The default // value of -1 will result in initializing the thread pool with the number of @@ -92,24 +53,7 @@ class HloRunner { explicit HloRunner(se::Platform* platform, int intra_op_parallelism_threads = -1); - ~HloRunner(); - - // Converts an HloModule from the given hlo textual IR string (in - // HloModule::ToString format). - static StatusOr> CreateModuleFromString( - const absl::string_view hlo_string, const DebugOptions& debug_options); - - // Reads the proto file in xla.HloProto format, creates and returns the - // HloModule. - static StatusOr> ReadModuleFromBinaryProtoFile( - const std::string& filename, const DebugOptions& debug_options); - static StatusOr> ReadModuleFromTextProtoFile( - const std::string& filename, const DebugOptions& debug_options); - - // Reads the hlo text dump file in HloModule::ToString format, creates and - // returns the HloModule. - static StatusOr> ReadModuleFromHloTextFile( - const std::string& filename, const DebugOptions& debug_options); + ~HloRunner() override; // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); @@ -124,19 +68,17 @@ class HloRunner { // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. + + using HloRunnerInterface::Execute; + StatusOr Execute(std::unique_ptr module, absl::Span arguments, - bool run_hlo_passes = true, - ExecutionProfile* profile = nullptr); + bool run_hlo_passes, + ExecutionProfile* profile) override; - StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes = true, - ExecutionProfile* profile = nullptr); - - StatusOr Execute(std::unique_ptr executable, - absl::Span arguments, - ExecutionProfile* profile = nullptr); + StatusOr ExecuteWithExecutable( + std::unique_ptr executable, + absl::Span arguments, ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. @@ -159,13 +101,13 @@ class HloRunner { // value. 
StatusOr> ExecuteReplicated( std::unique_ptr module, - const ReplicatedExecuteOptions& options); + const ReplicatedExecuteOptions& options) override; // Same as above, but with specified device assignment. StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, - DeviceAssignment* device_assignment); + DeviceAssignment* device_assignment) override; // Same as above, but with a reusable Executable. This may update the profile // information in *executable. diff --git a/tensorflow/compiler/xla/service/hlo_runner_interface.cc b/tensorflow/compiler/xla/service/hlo_runner_interface.cc new file mode 100644 index 00000000000..7359f1f08b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner_interface.cc @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_runner_interface.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { + +/*static*/ StatusOr> +HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return ParseAndReturnUnverifiedModule(hlo_string, config); +} + +namespace { + +// Creates an HloModule from the given proto. 
+StatusOr> HloProtoToModule( + const HloProto& proto, const DebugOptions& debug_options) { + TF_ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module(), + debug_options)); + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(proto.hlo_module(), config)); + return std::move(module); +} + +} // namespace + +/*static*/ StatusOr> +HloRunnerInterface::ReadModuleFromBinaryProtoFile( + const std::string& filename, const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunnerInterface::ReadModuleFromTextProtoFile( + const std::string& filename, const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR( + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunnerInterface::ReadModuleFromHloTextFile( + const std::string& filename, const DebugOptions& debug_options) { + string hlo_string; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + filename, &hlo_string)); + HloModuleConfig config; + config.set_debug_options(debug_options); + return ParseAndReturnUnverifiedModule(hlo_string, config); +} + +StatusOr HloRunnerInterface::Execute( + std::unique_ptr module, absl::Span arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. 
+ std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner_interface.h b/tensorflow/compiler/xla/service/hlo_runner_interface.h new file mode 100644 index 00000000000..bee8349ac71 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner_interface.h @@ -0,0 +1,142 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// A base class for running an HloModule. 
This executes the given HloModule on a +// certain backend directly without using the client interface. HloModule can be +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. +class HloRunnerInterface { + public: + // The options used to configure an ExecuteReplicated() call. + struct ReplicatedExecuteOptions { + // The number of devices the HLO module should be replicated onto. + int64 num_replicas = 1; + + // The arguments to be fed to each replica. Since this is used for a + // replicated execution, all the arguments are the same for all replicas. + std::vector arguments; + + // If the HLO module being run has an infeed instruction, this will be the + // data which will be fed to it, for as many as infeed_steps steps. + const Literal* infeed = nullptr; + + // The number of times the infeed literal should be fed to the HLO module. + // For a clean exit, this should match the iterations-per-loop parameter + // used when generating the HLO module proto (that is usually the main + // while boundary counter). A value higher than iterations-per-loop would + // lead to infeed threads feeding to a gone computation, while a lower + // value would trigger a stuck ExecuteReplicated() call (the computation + // will be trying to infeed data which will never come). + int64 infeed_steps = -1; + + // The shape of the outfeed operation. If empty, the HLO module does not + // generate any outfeed. + Shape outfeed_shape; + + // A pointer to a vector where the outfeed values will be stored. If + // nullptr, the values will be read and discarded. + std::vector* outfeed_values = nullptr; + + // Whether the HLO passes should be run on the input module. Usually + // saved modules are coming from after the HLO pass pipeline, so triggering + // another run will likely cause errors. + bool run_hlo_passes = false; + + // If true, executes on multiple threads using se::Stream::ExecuteOnStream. 
+ // Otherwise, executes using xla::Executable::ExecuteOnStreams. + bool use_threads = false; + }; + + HloRunnerInterface() = default; + + virtual ~HloRunnerInterface() = default; + + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const absl::string_view hlo_string, const DebugOptions& debug_options); + + // Reads the proto file in xla.HloProto format, creates and returns the + // HloModule. + static StatusOr> ReadModuleFromBinaryProtoFile( + const std::string& filename, const DebugOptions& debug_options); + static StatusOr> ReadModuleFromTextProtoFile( + const std::string& filename, const DebugOptions& debug_options); + + // Reads the hlo text dump file in HloModule::ToString format, creates and + // returns the HloModule. + static StatusOr> ReadModuleFromHloTextFile( + const std::string& filename, const DebugOptions& debug_options); + + // Executes the given module with given literals as input and returns the + // result as a Literal. + // + // If run_hlo_passes is false, the module will be executed without Hlo + // optimization + StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes = true) { + return Execute(std::move(module), arguments, run_hlo_passes, nullptr); + } + + StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); + + virtual StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) = 0; + + // Executes a given HLO module into a set of replicas, and returns a map + // with the replica number as key, and the corresponding returned literal as + // value. + // TODO(b/172931928): change to non-virtual function. + virtual StatusOr> ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options) = 0; + + // Same as above, but with specified device assignment. 
+ virtual StatusOr> ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 6c062deb363..c9d08cef857 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -507,9 +507,9 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( absl::optional canonical_output; for (int i = 0; i < n; ++i) { - StatusOr output = - test_runner_.Execute(std::move(executables[i]), fake_arguments[i], - /*profile=*/&((*profiles)[i])); + StatusOr output = test_runner_.ExecuteWithExecutable( + std::move(executables[i]), fake_arguments[i], + /*profile=*/&((*profiles)[i])); if (!output.ok()) { return ::testing::AssertionFailure() << output.status().error_message(); }