[XLA] Refactor CreateModuleConfig to share code between multiple call-sites.

Previously Service, LocalService and CompileOnlyService had their own code to
create a new HloModuleConfig, with much repetition (and some ommissions);
collect all these uses in a single method.

PiperOrigin-RevId: 163766869
This commit is contained in:
Eli Bendersky 2017-07-31 16:37:10 -07:00 committed by TensorFlower Gardener
parent 6150611ae2
commit 78a90370ef
5 changed files with 76 additions and 72 deletions

View File

@ -473,6 +473,7 @@ cc_library(
":shaped_buffer",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -480,7 +481,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],

View File

@ -96,29 +96,17 @@ CompileOnlyService::CompileAheadOfTime(
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
HloModuleConfig hlo_module_config(*program_shape);
hlo_module_config.set_debug_options(debug_options);
auto* computation_layout =
hlo_module_config.mutable_entry_computation_layout();
if (debug_options.xla_hlo_profile()) {
hlo_module_config.enable_hlo_profiling(true);
}
for (int i = 0; i < instance.argument_layouts.size(); ++i) {
const Shape& argument_layout = *instance.argument_layouts[i];
if (ShapeUtil::IsTuple(argument_layout)) {
return Unimplemented("tuple arguments not supported yet");
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
argument_layout));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
*instance.result_layout));
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, instance.argument_layouts,
&execution_options,
/*has_hybrid_result=*/false));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle, hlo_module_config,
versioned_handle, *module_config,
/*include_unreachable_instructions=*/true));
hlo_modules.push_back(std::move(hlo_module));
}

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@ -141,35 +141,19 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
ValidateResultShapeWithLayout(*result_layout, program_shape->result()));
}
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
module_config->set_has_hybrid_result(has_hybrid_result);
module_config->set_replica_count(options_.number_of_replicas());
module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
module_config->set_intra_op_parallelism_threads(
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
}
if (module_config->debug_options().xla_hlo_profile()) {
module_config->enable_hlo_profiling(true);
}
auto* computation_layout = module_config->mutable_entry_computation_layout();
for (int i = 0; i < argument_layouts.size(); ++i) {
const Shape& shape = *argument_layouts[i];
if (ShapeUtil::IsTuple(shape)) {
return Unimplemented("tuple arguments not supported yet");
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
shape));
}
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
if (result_layout != nullptr) {
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
*result_layout));
*execution_options.mutable_shape_with_output_layout() = *result_layout;
} else {
computation_layout->mutable_result_layout()->SetToDefaultLayout();
*execution_options.mutable_shape_with_output_layout() =
program_shape->result();
LayoutUtil::SetToDefaultLayout(
execution_options.mutable_shape_with_output_layout());
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
has_hybrid_result));
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
execute_backend_->stream_executor(device_ordinal));

View File

@ -286,51 +286,71 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Allocation*> arguments,
const ExecutionOptions& execution_options) {
auto module_config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = module_config->mutable_entry_computation_layout();
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options, bool has_hybrid_result) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != arguments.size()) {
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %zu given",
program_shape.parameters_size(), arguments.size());
program_shape.parameters_size(),
argument_shapes.size());
}
for (size_t i = 0; i < arguments.size(); ++i) {
for (int i = 0; i < argument_shapes.size(); ++i) {
// Verify that shape of arguments matches the shape of the arguments in the
// ProgramShape.
if (!ShapeUtil::Compatible(arguments[i]->shape(),
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
return InvalidArgument(
"computation expects parameter %lu to have shape %s, given shape %s",
"computation expects parameter %d to have shape %s, given shape %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(arguments[i]->shape()).c_str());
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
arguments[i]->shape()));
*argument_shapes[i]));
}
if (!execution_options.has_shape_with_output_layout()) {
computation_layout->mutable_result_layout()->Clear();
} else {
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
const auto& shape_with_output_layout =
execution_options.shape_with_output_layout();
execution_options->shape_with_output_layout();
TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
program_shape.result()));
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
computation_layout->mutable_result_layout()->Clear();
}
if (execution_options.debug_options().xla_hlo_profile()) {
module_config->enable_hlo_profiling(true);
config->set_replica_count(options_.number_of_replicas());
config->set_has_hybrid_result(has_hybrid_result);
if (execution_options != nullptr) {
config->set_seed(execution_options->seed());
config->set_debug_options(execution_options->debug_options());
config->enable_hlo_profiling(
execution_options->debug_options().xla_hlo_profile());
} else {
config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
}
module_config->set_replica_count(options_.number_of_replicas());
module_config->set_seed(execution_options.seed());
module_config->set_debug_options(execution_options.debug_options());
if (execute_backend_ != nullptr &&
execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
config->set_intra_op_parallelism_threads(
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
}
return std::move(config);
}
return std::move(module_config);
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Allocation*> arguments,
const ExecutionOptions& execution_options) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->shape());
}
return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(

View File

@ -245,6 +245,14 @@ class Service : public ServiceInterface {
const Backend& backend() const { return *execute_backend_; }
Backend* mutable_backend() { return execute_backend_.get(); }
private:
// A private overload for Service itself, used by other methods within this
// class.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Allocation*> arguments,
const ExecutionOptions& execution_options);
protected:
friend class LocalExecutable;
@ -263,10 +271,14 @@ class Service : public ServiceInterface {
const Backend* backend, int device_ordinal);
// Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.
// has_hybrid_result is used to initialize the same-named field in
// HloModuleConfig -- see that class for documentation.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Allocation*> arguments,
const ExecutionOptions& execution_options);
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options,
bool has_hybrid_result = false);
// Builds an Executable for the given parameters. If
// executable_for_compute_constant is true, then the executable is intended to