From 38a28a71c647bc707163efa4be4babf692feccff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Oct 2019 19:07:41 -0700 Subject: [PATCH] Save the numbers of replicas and logical cores in HLO proto. This will allow us to reproduce the same compilation process when there is data and model parallelism. PiperOrigin-RevId: 273001189 --- .../xla/service/compile_only_service.cc | 40 +++++++++++++----- tensorflow/compiler/xla/service/compiler.h | 3 ++ tensorflow/compiler/xla/service/dump.cc | 18 ++++++++ tensorflow/compiler/xla/service/dump.h | 5 +++ .../compiler/xla/service/hlo_computation.cc | 10 ++++- tensorflow/compiler/xla/service/hlo_module.cc | 42 ++++++++++++++++--- tensorflow/compiler/xla/service/hlo_module.h | 9 +++- .../compiler/xla/service/hlo_module_config.h | 10 ++++- tensorflow/compiler/xla/service/service.cc | 20 +++++---- tensorflow/compiler/xla/service/service.h | 7 ++-- tensorflow/compiler/xla/xla.proto | 4 ++ 11 files changed, 137 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 0908c3423ad..00b80f5b612 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -67,24 +67,21 @@ CompileOnlyService::CompileAheadOfTime( const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector> hlo_modules; + + const DebugOptions& debug_options = options.debug_options(); + ExecutionOptions execution_options; + *execution_options.mutable_debug_options() = debug_options; + for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_host_program_shape()); - - const DebugOptions& debug_options = options.debug_options(); - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = instance.result_layout->ToProto(); - if (options.has_static_device_assignment()) { - TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( - execution_options.mutable_device_assignment())); - } + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( ProgramShape(instance.computation.host_program_shape()), - instance.argument_layouts, &execution_options, - options.fusion_config_collection(), options.fusion_config())); + instance.argument_layouts, &execution_options, &options)); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, @@ -93,6 +90,29 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } + // Capture replica_count, num_cores, and device_assignment in ExecutionOptions + // to save in a proto dump. + if (options.replica_count() > 0) { + execution_options.set_num_replicas(options.replica_count()); + if (options.has_static_device_assignment()) { + CHECK_EQ(options.replica_count(), + options.static_device_assignment().replica_count()); + } + } + if (options.num_cores() > 0) { + execution_options.set_num_partitions(options.num_cores()); + if (options.has_static_device_assignment()) { + CHECK_EQ(options.num_cores(), + options.static_device_assignment().computation_count()); + } + } + if (options.has_static_device_assignment()) { + TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( + execution_options.mutable_device_assignment())); + } + execution_options.clear_shape_with_output_layout(); + DumpExecutionOptions(execution_options, debug_options); + return compiler_->CompileAheadOfTime( absl::make_unique(hlo_modules[0]->name(), absl::MakeSpan(hlo_modules)), diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 362f687e048..a0248839fdd 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -73,6 +73,9 @@ class AotCompilationOptions { // Returns the ID of the platform to which these options apply. virtual se::Platform::Id PlatformId() const = 0; + virtual int64 replica_count() const { return 0; } + virtual int64 num_cores() const { return 0; } + // Optional allocator that may be used for allocating temp space on the device // during compilation. se::DeviceMemoryAllocator* device_allocator() const { diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 331c935bdc9..d9d6330b59b 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" @@ -276,6 +277,23 @@ void DumpToFileInDirOrStdout(const HloModule& module, string_view suffix, CanonicalDebugOptions(module.config().debug_options())); } +void DumpExecutionOptions(const ExecutionOptions& execution_options, + const DebugOptions& debug_options) { + CanonicalDebugOptions opts(debug_options); + tensorflow::Env* env = tensorflow::Env::Default(); + const string& dir = opts.dump_to; + if (env->IsDirectory(dir).ok()) { + string filename = tensorflow::io::JoinPath(dir, "execution_options"); + if (opts.dump_as_text) { + TF_CHECK_OK(tensorflow::WriteTextProto( + env, absl::StrCat(filename, ".txt"), execution_options)); + } else { + TF_CHECK_OK(tensorflow::WriteBinaryProto( + env, absl::StrCat(filename, ".pb"), execution_options)); + } + } +} + void DumpHloModuleIfEnabled(const HloModule& module, string_view name) { CanonicalDebugOptions opts(module.config().debug_options()); if (opts.should_dump_module(module.name())) { diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h index d245ad582c4..cae4170d6af 100644 --- a/tensorflow/compiler/xla/service/dump.h +++ b/tensorflow/compiler/xla/service/dump.h @@ -49,6 +49,11 @@ void DumpToFileInDirOrStdout(const HloModule& module, absl::string_view file_suffix, absl::string_view contents); +// Dumps the given execution options if dumping is enabled. Exactly +// where and in what formats it's dumped is determined by the debug options. +void DumpExecutionOptions(const ExecutionOptions& execution_options, + const DebugOptions& debug_options); + // Dumps the given HLO module if dumping is enabled for the module. Exactly // where and in what formats it's dumped is determined by the module's config. // diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index cbdada0b46b..63e83e6b41f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -1014,8 +1014,14 @@ std::unique_ptr HloComputation::CloneWithReplacements( << operand->ToString() << ", used by " << instr->ToString(); new_operands.push_back(context->GetInstruction(replaced_operand)); } - instructions.push_back( - instr->CloneWithNewOperands(instr->shape(), new_operands, context)); + std::unique_ptr new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, context); + if (instr->opcode() == HloOpcode::kParameter && + instr->parameter_replicated_at_leaf_buffers().has_value()) { + new_instr->set_parameter_replicated_at_leaf_buffers( + instr->parameter_replicated_at_leaf_buffers().value()); + } + instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); for (auto& instr : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index ac74d5b0f65..d068b772664 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -383,14 +383,33 @@ StatusOr> HloModule::CreateFromProto( } /* static */ -StatusOr HloModule::CreateModuleConfigFromProto( - const HloModuleProto& module, const DebugOptions& debug_options) { - TF_RET_CHECK(module.has_host_program_shape()) - << "No program shape found in the proto"; - ProgramShape program_shape(module.host_program_shape()); - +StatusOr HloModule::CreateModuleConfigFromShape( + const ProgramShape& program_shape, const DebugOptions& debug_options, + const ExecutionOptions* execution_options) { HloModuleConfig module_config(ProgramShape{program_shape}); module_config.set_debug_options(debug_options); + if (execution_options) { + if (execution_options->num_replicas() > 0) { + module_config.set_replica_count(execution_options->num_replicas()); + } + if (execution_options->num_partitions() > 0) { + module_config.set_num_partitions(execution_options->num_partitions()); + } + if (execution_options->has_device_assignment()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, + DeviceAssignment::Deserialize( + execution_options->device_assignment())); + module_config.set_static_device_assignment(*device_assignment); + if (execution_options->num_replicas() > 0) { + CHECK_EQ(module_config.static_device_assignment().replica_count(), + module_config.replica_count()); + } + if (execution_options->num_partitions() > 0) { + CHECK_EQ(module_config.static_device_assignment().computation_count(), + module_config.num_partitions()); + } + } + } // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. @@ -406,6 +425,17 @@ StatusOr HloModule::CreateModuleConfigFromProto( return module_config; } +/* static */ +StatusOr HloModule::CreateModuleConfigFromProto( + const HloModuleProto& module, const DebugOptions& debug_options, + const ExecutionOptions* execution_options) { + TF_RET_CHECK(module.has_host_program_shape()) + << "No program shape found in the proto"; + ProgramShape program_shape(module.host_program_shape()); + return CreateModuleConfigFromShape(program_shape, debug_options, + execution_options); +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 6116a0cb648..745ad56d01a 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -221,7 +221,14 @@ class HloModule { // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( - const HloModuleProto& module, const DebugOptions& debug_options); + const HloModuleProto& module, const DebugOptions& debug_options, + const ExecutionOptions* execution_options = nullptr); + + // Creates and returns an HloModuleConfig with an appropriate program shape + // for the HLO module in the given proto. + static StatusOr CreateModuleConfigFromShape( + const ProgramShape& program_shape, const DebugOptions& debug_options, + const ExecutionOptions* execution_options = nullptr); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 0db3792ae44..dee601d9e96 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -113,6 +113,11 @@ class HloModuleConfig { } int64 replica_count() const { return replica_count_; } + void set_num_partitions(int64 num_partitions) { + num_partitions_ = num_partitions; + } + int64 num_partitions() const { return num_partitions_; } + // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. @@ -186,9 +191,12 @@ class HloModuleConfig { // Module/graph-level seed handle. uint64 seed_ = 0; - // The number of replicas to compile this binary for. + // The number of replicas (data parallelism) to compile this binary for. int64 replica_count_ = 1; + // The number of partitions (model parallelism) to compile this binary for. + int64 num_partitions_ = 1; + // The target maximum parallelism at which to partition HLOs for parallel // execution on the CPU backend. int64 intra_op_parallelism_threads_ = -1; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index c94db226de1..345a077e321 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -266,8 +266,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, - FusionConfigCollection fusion_config_collection, - const std::vector>& fusion_config) { + const AotCompilationOptions* aot_options) { auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); @@ -311,6 +310,9 @@ StatusOr> Service::CreateModuleConfig( } else { config->set_replica_count(options_.number_of_replicas()); } + if (execution_options->num_partitions() > 0) { + config->set_num_partitions(execution_options->num_partitions()); + } config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { @@ -334,9 +336,11 @@ StatusOr> Service::CreateModuleConfig( config->set_alias_passthrough_params( execution_options->alias_passthrough_params()); - if (fusion_config_collection != FusionConfigCollection::kOff) { - config->set_fusion_config_collection(fusion_config_collection); - *config->mutable_fusion_config() = fusion_config; + if (aot_options != nullptr && + aot_options->fusion_config_collection() != FusionConfigCollection::kOff) { + config->set_fusion_config_collection( + aot_options->fusion_config_collection()); + *config->mutable_fusion_config() = aot_options->fusion_config(); } return std::move(config); @@ -345,12 +349,14 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, absl::Span arguments, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, + const AotCompilationOptions* aot_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options, + aot_options); } StatusOr>> Service::BuildExecutables( diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index a37b78eed5b..3a4e17d7f44 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -189,9 +189,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, absl::Span argument_shapes, const ExecutionOptions* execution_options, - FusionConfigCollection fusion_config_collection = - FusionConfigCollection::kOff, - const std::vector>& fusion_config = {}); + const AotCompilationOptions* aot_options = nullptr); private: // A private overload for Service itself, used by other methods within this @@ -199,7 +197,8 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, absl::Span arguments, - const ExecutionOptions& execution_options); + const ExecutionOptions& execution_options, + const AotCompilationOptions* aot_options = nullptr); // Prepare the executors for executing parallel. StatusOr> GetExecutors( diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index a3bc092ac83..32b243bdec5 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -339,6 +339,10 @@ message ExecutionOptions { // Alias input and output buffers for parameters that are passed-through XLA // modules without being changed. bool alias_passthrough_params = 8; + + // Number of partitions of the computation to run (model parallelism). + // If zero, uses the default number of partitions for the XLA service. + int32 num_partitions = 9; } message GetDeviceHandlesRequest {