Move CreateModuleConfig to a new hlo_module_util header.

Move CreateExecutionOptions to xla/client/executable_build_options.h.

PiperOrigin-RevId: 347478735
Change-Id: Iab07bcd31e89ee8e24f2385d0a515dba03e505ad
Qiao Zhang 2020-12-14 15:10:33 -08:00 committed by TensorFlower Gardener
parent fe6d0cf3f9
commit e7e9a0c449
8 changed files with 238 additions and 108 deletions
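For orientation, the two relocated helpers end up at the following locations (an illustrative summary of the commit message, not part of the diffs below):

// After this change, callers reach the helpers via:
#include "tensorflow/compiler/xla/client/executable_build_options.h"  // CreateExecutionOptions
#include "tensorflow/compiler/xla/service/hlo_module_util.h"          // CreateModuleConfig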

tensorflow/compiler/xla/client/BUILD

@@ -95,6 +95,7 @@ cc_library(
hdrs = ["executable_build_options.h"],
deps = [
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",

tensorflow/compiler/xla/client/executable_build_options.cc

@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
@@ -99,4 +100,34 @@ string ExecutableBuildOptions::ToString() const {
device_ordinal_, result_layout, num_replicas_);
}
ExecutionOptions CreateExecutionOptions(
const ExecutableBuildOptions& build_options,
const ProgramShape* program_shape) {
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
if (build_options.has_debug_options()) {
*execution_options.mutable_debug_options() = build_options.debug_options();
}
if (build_options.result_layout() != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
build_options.result_layout()->ToProto();
} else {
Shape result_shape(program_shape->result());
LayoutUtil::SetToDefaultLayout(&result_shape);
*execution_options.mutable_shape_with_output_layout() =
result_shape.ToProto();
}
execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions());
execution_options.set_use_spmd_partitioning(
build_options.use_spmd_partitioning());
execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
if (build_options.has_device_assignment()) {
TF_CHECK_OK(build_options.device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
execution_options.set_alias_passthrough_params(
build_options.alias_passthrough_params());
return execution_options;
}
} // namespace xla

tensorflow/compiler/xla/client/executable_build_options.h

@@ -141,6 +141,12 @@ class ExecutableBuildOptions {
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
};
// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
// ProgramShape.
ExecutionOptions CreateExecutionOptions(
const ExecutableBuildOptions& build_options,
const ProgramShape* program_shape);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
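A minimal call-site sketch for the relocated CreateExecutionOptions; the wrapper function MakeOptionsForShape and the chosen replica count are hypothetical, and the ProgramShape is assumed to come from an already-built computation:

#include "tensorflow/compiler/xla/client/executable_build_options.h"

// Hypothetical helper: derive ExecutionOptions from build options.
xla::ExecutionOptions MakeOptionsForShape(
    const xla::ProgramShape& program_shape) {
  xla::ExecutableBuildOptions build_options;
  build_options.set_num_replicas(2);
  // No result layout is set here, so CreateExecutionOptions falls back to
  // the default layout of program_shape.result().
  return xla::CreateExecutionOptions(build_options, &program_shape);
}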

tensorflow/compiler/xla/service/BUILD

@@ -936,6 +936,7 @@ cc_library(
":hlo_evaluator",
":hlo_execution_profile",
":hlo_module_config",
":hlo_module_util",
":hlo_proto_util",
":platform_util",
":source_map_util",
@@ -977,6 +978,7 @@ cc_library(
":hlo",
":hlo_execution_profile",
":hlo_module_config",
":hlo_module_util",
":platform_util",
":service",
":shaped_buffer",
@@ -1528,6 +1530,21 @@ cc_library(
],
)
cc_library(
name = "hlo_module_util",
srcs = ["hlo_module_util.cc"],
hdrs = ["hlo_module_util.h"],
deps = [
":compiler",
":hlo_module_config",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "hlo_module_group_util",
srcs = ["hlo_module_group_util.cc"],

tensorflow/compiler/xla/service/hlo_module_util.cc

@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
namespace {
Status ValidateResultShape(const Shape& client_shape,
const Shape& result_shape) {
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
if (!ShapeUtil::Compatible(client_shape, result_shape)) {
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
"with result shape %s",
ShapeUtil::HumanStringWithLayout(client_shape),
ShapeUtil::HumanString(result_shape));
}
return Status::OK();
}
} // namespace
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options, int default_num_replicas,
absl::optional<int> num_threads, const AotCompilationOptions* aot_options) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
const int64 argument_shapes_size = argument_shapes.size();
if (program_shape.parameters_size() != argument_shapes_size) {
return InvalidArgument("computation takes %d parameters, but %u given",
program_shape.parameters_size(),
argument_shapes.size());
}
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
// Verify that the shape of each argument matches the shape of the
// corresponding parameter in the ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
return InvalidArgument(
"Argument does not match shape of computation parameter %d: want "
"%s, got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)),
ShapeUtil::HumanString(*argument_shapes[i]));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
*argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
const Shape shape_with_output_layout(
execution_options->shape_with_output_layout());
TF_RETURN_IF_ERROR(
ValidateResultShape(shape_with_output_layout, program_shape.result()));
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
// If the result layout is not set, then choose the default.
computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
if (execution_options != nullptr) {
if (execution_options->num_replicas() > 0) {
config->set_replica_count(execution_options->num_replicas());
} else {
config->set_replica_count(default_num_replicas);
}
if (execution_options->num_partitions() > 0) {
config->set_num_partitions(execution_options->num_partitions());
}
config->set_use_spmd_partitioning(
execution_options->use_spmd_partitioning());
config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
config->set_seed(execution_options->seed());
config->set_launch_id(execution_options->launch_id());
config->set_debug_options(execution_options->debug_options());
} else {
config->set_replica_count(default_num_replicas);
config->set_debug_options(GetDebugOptionsFromFlags());
}
if (num_threads.has_value()) {
config->set_intra_op_parallelism_threads(*num_threads);
}
if (execution_options != nullptr &&
execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(
auto device_assignment,
DeviceAssignment::Deserialize(execution_options->device_assignment()));
config->set_static_device_assignment(*device_assignment);
}
  // execution_options may be null (see the else branch above), so guard
  // this dereference.
  if (execution_options != nullptr) {
    config->set_alias_passthrough_params(
        execution_options->alias_passthrough_params());
  }
if (aot_options != nullptr &&
aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
config->set_fusion_config_collection(
aot_options->fusion_config_collection());
*config->mutable_fusion_config() = aot_options->fusion_config();
}
return std::move(config);
}
} // namespace xla
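To illustrate the argument validation above, a hedged sketch of a mismatched call; DemoMismatch is hypothetical, while ShapeUtil::MakeProgramShape and ShapeUtil::MakeShape are existing XLA helpers:

#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::Status DemoMismatch() {
  // One f32[2,2] parameter and an f32[2,2] result.
  xla::ProgramShape program_shape = xla::ShapeUtil::MakeProgramShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {2, 2})},
      xla::ShapeUtil::MakeShape(xla::F32, {2, 2}));
  // f32[4] is incompatible with parameter 0, so CreateModuleConfig returns
  // an InvalidArgument status instead of a config.
  xla::Shape bad_arg = xla::ShapeUtil::MakeShape(xla::F32, {4});
  const xla::Shape* args[] = {&bad_arg};
  auto config_or = xla::CreateModuleConfig(program_shape, args,
                                           /*execution_options=*/nullptr,
                                           /*default_num_replicas=*/1);
  return config_or.status();
}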

tensorflow/compiler/xla/service/hlo_module_util.h

@@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
#include <memory>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// Creates an HloModuleConfig for a given program shape and arguments.
// If execution_options does not set num_replicas, default_num_replicas is
// used. num_threads is optional; when unset, intra_op_parallelism_threads is
// left at its default. aot_options is optional; when set, its fusion config
// is copied into the resulting module config.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options, int default_num_replicas,
absl::optional<int> num_threads = absl::nullopt,
const AotCompilationOptions* aot_options = nullptr);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
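And a happy-path sketch mirroring how Service::CreateModuleConfig delegates in the last diff below; BuildModuleConfig and the replica count are hypothetical:

#include "tensorflow/compiler/xla/service/hlo_module_util.h"

xla::StatusOr<std::unique_ptr<xla::HloModuleConfig>> BuildModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape* const> argument_shapes,
    const xla::ExecutionOptions* execution_options) {
  // num_threads and aot_options keep their defaults (nullopt / nullptr).
  return xla::CreateModuleConfig(program_shape, argument_shapes,
                                 execution_options,
                                 /*default_num_replicas=*/1);
}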

tensorflow/compiler/xla/service/local_service.cc

@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -94,36 +95,6 @@ absl::optional<const OpMetadata*> ParameterMetadata(
return absl::nullopt;
}
ExecutionOptions CreateExecutionOptions(
const ExecutableBuildOptions& build_options,
const ProgramShape* program_shape) {
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
if (build_options.has_debug_options()) {
*execution_options.mutable_debug_options() = build_options.debug_options();
}
if (build_options.result_layout() != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
build_options.result_layout()->ToProto();
} else {
Shape result_shape(program_shape->result());
LayoutUtil::SetToDefaultLayout(&result_shape);
*execution_options.mutable_shape_with_output_layout() =
result_shape.ToProto();
}
execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions());
execution_options.set_use_spmd_partitioning(
build_options.use_spmd_partitioning());
execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
if (build_options.has_device_assignment()) {
TF_CHECK_OK(build_options.device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
execution_options.set_alias_passthrough_params(
build_options.alias_passthrough_params());
return execution_options;
}
} // namespace
StatusOr<std::vector<std::unique_ptr<Executable>>>

tensorflow/compiler/xla/service/service.cc

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
@@ -256,88 +257,16 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options,
const AotCompilationOptions* aot_options) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
const int64 argument_shapes_size = argument_shapes.size();
if (program_shape.parameters_size() != argument_shapes_size) {
return InvalidArgument("computation takes %d parameters, but %u given",
program_shape.parameters_size(),
argument_shapes.size());
}
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
// Verify that shape of arguments matches the shape of the arguments in the
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
return InvalidArgument(
"Argument does not match shape of computation parameter %d: want "
"%s, got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)),
ShapeUtil::HumanString(*argument_shapes[i]));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
*argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
const Shape shape_with_output_layout(
execution_options->shape_with_output_layout());
TF_RETURN_IF_ERROR(
ValidateResultShape(shape_with_output_layout, program_shape.result()));
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
// If the result layout is not set, then choose the default.
computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
if (execution_options != nullptr) {
if (execution_options->num_replicas() > 0) {
config->set_replica_count(execution_options->num_replicas());
} else {
config->set_replica_count(options_.number_of_replicas());
}
if (execution_options->num_partitions() > 0) {
config->set_num_partitions(execution_options->num_partitions());
}
config->set_use_spmd_partitioning(
execution_options->use_spmd_partitioning());
config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
config->set_seed(execution_options->seed());
config->set_launch_id(execution_options->launch_id());
config->set_debug_options(execution_options->debug_options());
} else {
config->set_replica_count(options_.number_of_replicas());
config->set_debug_options(GetDebugOptionsFromFlags());
}
int default_num_replicas = options_.number_of_replicas();
absl::optional<int> num_threads;
if (execute_backend_ != nullptr &&
execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
config->set_intra_op_parallelism_threads(
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
}
if (execution_options != nullptr &&
execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(
auto device_assignment,
DeviceAssignment::Deserialize(execution_options->device_assignment()));
config->set_static_device_assignment(*device_assignment);
}
config->set_alias_passthrough_params(
execution_options->alias_passthrough_params());
if (aot_options != nullptr &&
aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
config->set_fusion_config_collection(
aot_options->fusion_config_collection());
*config->mutable_fusion_config() = aot_options->fusion_config();
}
return std::move(config);
return xla::CreateModuleConfig(program_shape, argument_shapes,
execution_options, default_num_replicas,
num_threads, aot_options);
}
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(