From 78a90370ef214015508103bb21cef2962f041c5c Mon Sep 17 00:00:00 2001
From: Eli Bendersky
Date: Mon, 31 Jul 2017 16:37:10 -0700
Subject: [PATCH] [XLA] Refactor CreateModuleConfig to share code between
 multiple call-sites.

Previously Service, LocalService and CompileOnlyService had their own code to
create a new HloModuleConfig, with much repetition (and some omissions);
collect all these uses in a single method.

PiperOrigin-RevId: 163766869
---
 tensorflow/compiler/xla/service/BUILD            |  2 +-
 .../xla/service/compile_only_service.cc          | 28 +++-----
 .../compiler/xla/service/local_service.cc        | 38 ++++------
 tensorflow/compiler/xla/service/service.cc       | 64 ++++++++++++-------
 tensorflow/compiler/xla/service/service.h        | 16 ++++-
 5 files changed, 76 insertions(+), 72 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 353d492a428..757a1a9046b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -473,6 +473,7 @@ cc_library(
         ":shaped_buffer",
         ":user_computation",
         ":versioned_computation_handle",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:shape_layout",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -480,7 +481,6 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index d43dc5b214a..1dfe4a73b31 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -96,29 +96,17 @@ CompileOnlyService::CompileAheadOfTime(
         std::shared_ptr<const ProgramShape> program_shape,
         user_computation->ComputeProgramShape(versioned_handle.version));
 
-    HloModuleConfig hlo_module_config(*program_shape);
-    hlo_module_config.set_debug_options(debug_options);
-    auto* computation_layout =
-        hlo_module_config.mutable_entry_computation_layout();
-    if (debug_options.xla_hlo_profile()) {
-      hlo_module_config.enable_hlo_profiling(true);
-    }
-    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
-      const Shape& argument_layout = *instance.argument_layouts[i];
-      if (ShapeUtil::IsTuple(argument_layout)) {
-        return Unimplemented("tuple arguments not supported yet");
-      }
-      TF_RETURN_IF_ERROR(
-          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-              argument_layout));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *instance.result_layout));
+    ExecutionOptions execution_options;
+    *execution_options.mutable_debug_options() = debug_options;
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloModuleConfig> module_config,
+        CreateModuleConfig(*program_shape, instance.argument_layouts,
+                           &execution_options,
+                           /*has_hybrid_result=*/false));
 
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
                         computation_tracker_.BuildHloModule(
-                            versioned_handle, hlo_module_config,
+                            versioned_handle, *module_config,
                             /*include_unreachable_instructions=*/true));
     hlo_modules.push_back(std::move(hlo_module));
   }
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 45e37c6f65e..2042558a29b 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -141,35 +141,19 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
         ValidateResultShapeWithLayout(*result_layout, program_shape->result()));
   }
 
-  // Construct computation layout from the argument layouts.
-  auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
-  module_config->set_has_hybrid_result(has_hybrid_result);
-  module_config->set_replica_count(options_.number_of_replicas());
-  module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
-  if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
-    module_config->set_intra_op_parallelism_threads(
-        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
-  }
-  if (module_config->debug_options().xla_hlo_profile()) {
-    module_config->enable_hlo_profiling(true);
-  }
-  auto* computation_layout = module_config->mutable_entry_computation_layout();
-  for (int i = 0; i < argument_layouts.size(); ++i) {
-    const Shape& shape = *argument_layouts[i];
-    if (ShapeUtil::IsTuple(shape)) {
-      return Unimplemented("tuple arguments not supported yet");
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            shape));
-  }
+  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
   if (result_layout != nullptr) {
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *result_layout));
+    *execution_options.mutable_shape_with_output_layout() = *result_layout;
   } else {
-    computation_layout->mutable_result_layout()->SetToDefaultLayout();
+    *execution_options.mutable_shape_with_output_layout() =
+        program_shape->result();
+    LayoutUtil::SetToDefaultLayout(
+        execution_options.mutable_shape_with_output_layout());
   }
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
+                         has_hybrid_result));
 
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       execute_backend_->stream_executor(device_ordinal));
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 25e3f57dfb1..ef29c7d5d13 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -286,51 +286,71 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
 
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
-    tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
-  auto module_config = MakeUnique<HloModuleConfig>(program_shape);
-  auto* computation_layout = module_config->mutable_entry_computation_layout();
+    tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+    const ExecutionOptions* execution_options, bool has_hybrid_result) {
+  auto config = MakeUnique<HloModuleConfig>(program_shape);
+  auto* computation_layout = config->mutable_entry_computation_layout();
 
-  if (program_shape.parameters_size() != arguments.size()) {
+  if (program_shape.parameters_size() != argument_shapes.size()) {
     return InvalidArgument("computation takes %d parameters, but %zu given",
-                           program_shape.parameters_size(), arguments.size());
+                           program_shape.parameters_size(),
+                           argument_shapes.size());
   }
-
-  for (size_t i = 0; i < arguments.size(); ++i) {
+  for (int i = 0; i < argument_shapes.size(); ++i) {
     // Verify that shape of arguments matches the shape of the arguments in the
     // ProgramShape.
-    if (!ShapeUtil::Compatible(arguments[i]->shape(),
+    if (!ShapeUtil::Compatible(*argument_shapes[i],
                                program_shape.parameters(i))) {
       return InvalidArgument(
-          "computation expects parameter %lu to have shape %s, given shape %s",
+          "computation expects parameter %d to have shape %s, given shape %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
-          ShapeUtil::HumanString(arguments[i]->shape()).c_str());
+          ShapeUtil::HumanString(*argument_shapes[i]).c_str());
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            arguments[i]->shape()));
+            *argument_shapes[i]));
  }
 
-  if (!execution_options.has_shape_with_output_layout()) {
-    computation_layout->mutable_result_layout()->Clear();
-  } else {
+  if (execution_options != nullptr &&
+      execution_options->has_shape_with_output_layout()) {
    const auto& shape_with_output_layout =
-        execution_options.shape_with_output_layout();
+        execution_options->shape_with_output_layout();
    TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
                                                      program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            shape_with_output_layout));
+  } else {
+    computation_layout->mutable_result_layout()->Clear();
  }
 
-  if (execution_options.debug_options().xla_hlo_profile()) {
-    module_config->enable_hlo_profiling(true);
+  config->set_replica_count(options_.number_of_replicas());
+  config->set_has_hybrid_result(has_hybrid_result);
+  if (execution_options != nullptr) {
+    config->set_seed(execution_options->seed());
+    config->set_debug_options(execution_options->debug_options());
+    config->enable_hlo_profiling(
+        execution_options->debug_options().xla_hlo_profile());
+  } else {
+    config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
  }
 
-  module_config->set_replica_count(options_.number_of_replicas());
-  module_config->set_seed(execution_options.seed());
-  module_config->set_debug_options(execution_options.debug_options());
+  if (execute_backend_ != nullptr &&
+      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
+    config->set_intra_op_parallelism_threads(
+        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
+  }
+  return std::move(config);
+}
 
-  return std::move(module_config);
+StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
+    const ProgramShape& program_shape,
+    tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+    const ExecutionOptions& execution_options) {
+  std::vector<const Shape*> argument_shapes;
+  for (const auto* arg : arguments) {
+    argument_shapes.push_back(&arg->shape());
+  }
+  return CreateModuleConfig(program_shape, argument_shapes,
+                            &execution_options);
 }
 
 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index ee1ed08436b..a07f7cd0426 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -245,6 +245,14 @@ class Service : public ServiceInterface {
   const Backend& backend() const { return *execute_backend_; }
   Backend* mutable_backend() { return execute_backend_.get(); }
 
+ private:
+  // A private overload for Service itself, used by other methods within this
+  // class.
+  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+      const ProgramShape& program_shape,
+      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+      const ExecutionOptions& execution_options);
+
 protected:
   friend class LocalExecutable;
 
@@ -263,10 +271,14 @@ class Service : public ServiceInterface {
       const Backend* backend, int device_ordinal);
 
   // Create a Hlo module config for the given program shape and arguments.
+  // execution_options is optional; if not given a default is used.
+  // has_hybrid_result is used to initialize the same-named field in
+  // HloModuleConfig -- see that class for documentation.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
-      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+      const ExecutionOptions* execution_options,
+      bool has_hybrid_result = false);
 
   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to
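
Note on the pattern: after this change, all three services funnel through one
config builder that takes the argument shapes, an optional ExecutionOptions
pointer (null meaning "use defaults"), and a has_hybrid_result flag, with a
thin private overload delegating to it. For readers outside the XLA tree, the
standalone sketch below distills that shape. The types and fields here
(ExecutionOptions, ModuleConfig, seed, hlo_profile, CreateConfig) are
simplified stand-ins invented for illustration, not the real XLA classes.

  // Minimal sketch of the refactoring pattern, not XLA code (C++14).
  #include <iostream>
  #include <memory>

  // Stand-in for the ExecutionOptions proto.
  struct ExecutionOptions {
    int seed = 0;
    bool hlo_profile = false;
  };

  // Stand-in for HloModuleConfig.
  struct ModuleConfig {
    bool has_hybrid_result = false;
    int seed = 0;
    bool profiling = false;
  };

  // The single place that knows how to populate every field. A null
  // execution_options means "use defaults", matching the contract of the
  // new CreateModuleConfig above.
  std::unique_ptr<ModuleConfig> CreateConfig(
      const ExecutionOptions* execution_options,
      bool has_hybrid_result = false) {
    auto config = std::make_unique<ModuleConfig>();
    config->has_hybrid_result = has_hybrid_result;
    if (execution_options != nullptr) {
      config->seed = execution_options->seed;
      config->profiling = execution_options->hlo_profile;
    }
    return config;
  }

  // Thin overload for callers that already hold options by reference,
  // mirroring the private Service overload added in service.h.
  std::unique_ptr<ModuleConfig> CreateConfig(
      const ExecutionOptions& execution_options) {
    return CreateConfig(&execution_options);
  }

  int main() {
    ExecutionOptions options;
    options.seed = 42;
    auto config = CreateConfig(options);  // delegates to the pointer overload
    std::cout << "seed=" << config->seed
              << " profiling=" << config->profiling << "\n";
    return 0;
  }

Using a nullable pointer for the optional parameter keeps a single function
body to maintain, which is exactly the duplication this patch removes;
presumably a pointer is used rather than an optional wrapper because the
codebase of that era targeted pre-C++17 toolchains.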