Move CreateModuleConfig to a new hlo_module_util header.
Move CreateExecutionOptions to xla/client/executable_build_options.h. PiperOrigin-RevId: 347478735 Change-Id: Iab07bcd31e89ee8e24f2385d0a515dba03e505ad
This commit is contained in:
parent
fe6d0cf3f9
commit
e7e9a0c449
tensorflow/compiler/xla
@ -95,6 +95,7 @@ cc_library(
|
||||
hdrs = ["executable_build_options.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -99,4 +100,34 @@ string ExecutableBuildOptions::ToString() const {
|
||||
device_ordinal_, result_layout, num_replicas_);
|
||||
}
|
||||
|
||||
ExecutionOptions CreateExecutionOptions(
|
||||
const ExecutableBuildOptions& build_options,
|
||||
const ProgramShape* program_shape) {
|
||||
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
|
||||
if (build_options.has_debug_options()) {
|
||||
*execution_options.mutable_debug_options() = build_options.debug_options();
|
||||
}
|
||||
if (build_options.result_layout() != nullptr) {
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
build_options.result_layout()->ToProto();
|
||||
} else {
|
||||
Shape result_shape(program_shape->result());
|
||||
LayoutUtil::SetToDefaultLayout(&result_shape);
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
result_shape.ToProto();
|
||||
}
|
||||
execution_options.set_num_replicas(build_options.num_replicas());
|
||||
execution_options.set_num_partitions(build_options.num_partitions());
|
||||
execution_options.set_use_spmd_partitioning(
|
||||
build_options.use_spmd_partitioning());
|
||||
execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
|
||||
if (build_options.has_device_assignment()) {
|
||||
TF_CHECK_OK(build_options.device_assignment().Serialize(
|
||||
execution_options.mutable_device_assignment()));
|
||||
}
|
||||
execution_options.set_alias_passthrough_params(
|
||||
build_options.alias_passthrough_params());
|
||||
return execution_options;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -141,6 +141,12 @@ class ExecutableBuildOptions {
|
||||
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
|
||||
};
|
||||
|
||||
// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
|
||||
// ProgramShape.
|
||||
ExecutionOptions CreateExecutionOptions(
|
||||
const ExecutableBuildOptions& build_options,
|
||||
const ProgramShape* program_shape);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
|
||||
|
@ -936,6 +936,7 @@ cc_library(
|
||||
":hlo_evaluator",
|
||||
":hlo_execution_profile",
|
||||
":hlo_module_config",
|
||||
":hlo_module_util",
|
||||
":hlo_proto_util",
|
||||
":platform_util",
|
||||
":source_map_util",
|
||||
@ -977,6 +978,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_execution_profile",
|
||||
":hlo_module_config",
|
||||
":hlo_module_util",
|
||||
":platform_util",
|
||||
":service",
|
||||
":shaped_buffer",
|
||||
@ -1528,6 +1530,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_module_util",
|
||||
srcs = ["hlo_module_util.cc"],
|
||||
hdrs = ["hlo_module_util.h"],
|
||||
deps = [
|
||||
":compiler",
|
||||
":hlo_module_config",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_module_group_util",
|
||||
srcs = ["hlo_module_group_util.cc"],
|
||||
|
131
tensorflow/compiler/xla/service/hlo_module_util.cc
Normal file
131
tensorflow/compiler/xla/service/hlo_module_util.cc
Normal file
@ -0,0 +1,131 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ValidateResultShape(const Shape& client_shape,
|
||||
const Shape& result_shape) {
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
|
||||
if (!ShapeUtil::Compatible(client_shape, result_shape)) {
|
||||
return InvalidArgument(
|
||||
"Shape used to set computation result layout %s is not compatible "
|
||||
"with result shape %s",
|
||||
ShapeUtil::HumanStringWithLayout(client_shape),
|
||||
ShapeUtil::HumanString(result_shape));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options, int default_num_replicas,
|
||||
absl::optional<int> num_threads, const AotCompilationOptions* aot_options) {
|
||||
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
||||
ComputationLayout* computation_layout =
|
||||
config->mutable_entry_computation_layout();
|
||||
const int64 argument_shapes_size = argument_shapes.size();
|
||||
if (program_shape.parameters_size() != argument_shapes_size) {
|
||||
return InvalidArgument("computation takes %d parameters, but %u given",
|
||||
program_shape.parameters_size(),
|
||||
argument_shapes.size());
|
||||
}
|
||||
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
|
||||
// Verify that shape of arguments matches the shape of the arguments in the
|
||||
// ProgramShape.
|
||||
if (!ShapeUtil::Compatible(*argument_shapes[i],
|
||||
program_shape.parameters(i))) {
|
||||
return InvalidArgument(
|
||||
"Argument does not match shape of computation parameter %d: want "
|
||||
"%s, got %s",
|
||||
i, ShapeUtil::HumanString(program_shape.parameters(i)),
|
||||
ShapeUtil::HumanString(*argument_shapes[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||
*argument_shapes[i]));
|
||||
}
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_shape_with_output_layout()) {
|
||||
const Shape shape_with_output_layout(
|
||||
execution_options->shape_with_output_layout());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateResultShape(shape_with_output_layout, program_shape.result()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||
shape_with_output_layout));
|
||||
} else {
|
||||
// If the result layout is not set, then choose the default.
|
||||
computation_layout->mutable_result_layout()->SetToDefaultLayout();
|
||||
}
|
||||
|
||||
if (execution_options != nullptr) {
|
||||
if (execution_options->num_replicas() > 0) {
|
||||
config->set_replica_count(execution_options->num_replicas());
|
||||
} else {
|
||||
config->set_replica_count(default_num_replicas);
|
||||
}
|
||||
if (execution_options->num_partitions() > 0) {
|
||||
config->set_num_partitions(execution_options->num_partitions());
|
||||
}
|
||||
config->set_use_spmd_partitioning(
|
||||
execution_options->use_spmd_partitioning());
|
||||
config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
|
||||
config->set_seed(execution_options->seed());
|
||||
config->set_launch_id(execution_options->launch_id());
|
||||
config->set_debug_options(execution_options->debug_options());
|
||||
} else {
|
||||
config->set_replica_count(default_num_replicas);
|
||||
config->set_debug_options(GetDebugOptionsFromFlags());
|
||||
}
|
||||
|
||||
if (num_threads.has_value()) {
|
||||
config->set_intra_op_parallelism_threads(*num_threads);
|
||||
}
|
||||
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_device_assignment()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto device_assignment,
|
||||
DeviceAssignment::Deserialize(execution_options->device_assignment()));
|
||||
config->set_static_device_assignment(*device_assignment);
|
||||
}
|
||||
config->set_alias_passthrough_params(
|
||||
execution_options->alias_passthrough_params());
|
||||
|
||||
if (aot_options != nullptr &&
|
||||
aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
|
||||
config->set_fusion_config_collection(
|
||||
aot_options->fusion_config_collection());
|
||||
*config->mutable_fusion_config() = aot_options->fusion_config();
|
||||
}
|
||||
|
||||
return std::move(config);
|
||||
}
|
||||
|
||||
} // namespace xla
|
44
tensorflow/compiler/xla/service/hlo_module_util.h
Normal file
44
tensorflow/compiler/xla/service/hlo_module_util.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Creates an HloModuleConfig for a given program shape and arguments.
|
||||
// If execution_options does not set num_replicas, default_num_replicas is used.
|
||||
// num_threads is optional; if not given, intra_op_parallelism_threads not set.
|
||||
// aot_options is optional; if not given a default is used.
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options, int default_num_replicas,
|
||||
absl::optional<int> num_threads = absl::nullopt,
|
||||
const AotCompilationOptions* aot_options = nullptr);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_UTIL_H_
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -94,36 +95,6 @@ absl::optional<const OpMetadata*> ParameterMetadata(
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
ExecutionOptions CreateExecutionOptions(
|
||||
const ExecutableBuildOptions& build_options,
|
||||
const ProgramShape* program_shape) {
|
||||
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
|
||||
if (build_options.has_debug_options()) {
|
||||
*execution_options.mutable_debug_options() = build_options.debug_options();
|
||||
}
|
||||
if (build_options.result_layout() != nullptr) {
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
build_options.result_layout()->ToProto();
|
||||
} else {
|
||||
Shape result_shape(program_shape->result());
|
||||
LayoutUtil::SetToDefaultLayout(&result_shape);
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
result_shape.ToProto();
|
||||
}
|
||||
execution_options.set_num_replicas(build_options.num_replicas());
|
||||
execution_options.set_num_partitions(build_options.num_partitions());
|
||||
execution_options.set_use_spmd_partitioning(
|
||||
build_options.use_spmd_partitioning());
|
||||
execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
|
||||
if (build_options.has_device_assignment()) {
|
||||
TF_CHECK_OK(build_options.device_assignment().Serialize(
|
||||
execution_options.mutable_device_assignment()));
|
||||
}
|
||||
execution_options.set_alias_passthrough_params(
|
||||
build_options.alias_passthrough_params());
|
||||
return execution_options;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>>
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/service/source_map_util.h"
|
||||
@ -256,88 +257,16 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options,
|
||||
const AotCompilationOptions* aot_options) {
|
||||
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
||||
ComputationLayout* computation_layout =
|
||||
config->mutable_entry_computation_layout();
|
||||
const int64 argument_shapes_size = argument_shapes.size();
|
||||
if (program_shape.parameters_size() != argument_shapes_size) {
|
||||
return InvalidArgument("computation takes %d parameters, but %u given",
|
||||
program_shape.parameters_size(),
|
||||
argument_shapes.size());
|
||||
}
|
||||
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
|
||||
// Verify that shape of arguments matches the shape of the arguments in the
|
||||
// ProgramShape.
|
||||
if (!ShapeUtil::Compatible(*argument_shapes[i],
|
||||
program_shape.parameters(i))) {
|
||||
return InvalidArgument(
|
||||
"Argument does not match shape of computation parameter %d: want "
|
||||
"%s, got %s",
|
||||
i, ShapeUtil::HumanString(program_shape.parameters(i)),
|
||||
ShapeUtil::HumanString(*argument_shapes[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||
*argument_shapes[i]));
|
||||
}
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_shape_with_output_layout()) {
|
||||
const Shape shape_with_output_layout(
|
||||
execution_options->shape_with_output_layout());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateResultShape(shape_with_output_layout, program_shape.result()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||
shape_with_output_layout));
|
||||
} else {
|
||||
// If the result layout is not set, then choose the default.
|
||||
computation_layout->mutable_result_layout()->SetToDefaultLayout();
|
||||
}
|
||||
|
||||
if (execution_options != nullptr) {
|
||||
if (execution_options->num_replicas() > 0) {
|
||||
config->set_replica_count(execution_options->num_replicas());
|
||||
} else {
|
||||
config->set_replica_count(options_.number_of_replicas());
|
||||
}
|
||||
if (execution_options->num_partitions() > 0) {
|
||||
config->set_num_partitions(execution_options->num_partitions());
|
||||
}
|
||||
config->set_use_spmd_partitioning(
|
||||
execution_options->use_spmd_partitioning());
|
||||
config->set_deduplicate_hlo(execution_options->deduplicate_hlo());
|
||||
config->set_seed(execution_options->seed());
|
||||
config->set_launch_id(execution_options->launch_id());
|
||||
config->set_debug_options(execution_options->debug_options());
|
||||
} else {
|
||||
config->set_replica_count(options_.number_of_replicas());
|
||||
config->set_debug_options(GetDebugOptionsFromFlags());
|
||||
}
|
||||
|
||||
int default_num_replicas = options_.number_of_replicas();
|
||||
absl::optional<int> num_threads;
|
||||
if (execute_backend_ != nullptr &&
|
||||
execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
|
||||
config->set_intra_op_parallelism_threads(
|
||||
execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
|
||||
num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
|
||||
}
|
||||
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_device_assignment()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto device_assignment,
|
||||
DeviceAssignment::Deserialize(execution_options->device_assignment()));
|
||||
config->set_static_device_assignment(*device_assignment);
|
||||
}
|
||||
config->set_alias_passthrough_params(
|
||||
execution_options->alias_passthrough_params());
|
||||
|
||||
if (aot_options != nullptr &&
|
||||
aot_options->fusion_config_collection() != FusionConfigCollection::kOff) {
|
||||
config->set_fusion_config_collection(
|
||||
aot_options->fusion_config_collection());
|
||||
*config->mutable_fusion_config() = aot_options->fusion_config();
|
||||
}
|
||||
|
||||
return std::move(config);
|
||||
return xla::CreateModuleConfig(program_shape, argument_shapes,
|
||||
execution_options, default_num_replicas,
|
||||
num_threads, aot_options);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
|
Loading…
Reference in New Issue
Block a user