[XLA] Refactor CreateModuleConfig to share code between multiple call-sites.

Previously Service, LocalService and CompileOnlyService each had their own code to create a new HloModuleConfig, with much repetition (and some omissions); collect all these uses in a single method.

PiperOrigin-RevId: 163766869
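For orientation before the diff, here is a minimal sketch of the consolidated helper and an illustrative call-site, assembled from the hunks below. The surrounding variables (program_shape, argument_layouts, debug_options) are assumed to be in scope, and the call itself is illustrative rather than copied verbatim from any single file:

// Shared entry point on Service (declared in the service header hunk below).
// argument_shapes are the per-parameter shapes-with-layout; execution_options
// may be null, in which case defaults (including flag-derived debug options)
// are used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape,
    tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
    const ExecutionOptions* execution_options,
    bool has_hybrid_result = false);

// Illustrative call-site in the new style:
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
*execution_options.mutable_debug_options() = debug_options;
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                    CreateModuleConfig(*program_shape, argument_layouts,
                                       &execution_options,
                                       /*has_hybrid_result=*/false));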
parent 6150611ae2
commit 78a90370ef
tensorflow/compiler/xla/service
@@ -473,6 +473,7 @@ cc_library(
         ":shaped_buffer",
         ":user_computation",
         ":versioned_computation_handle",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:shape_layout",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -480,7 +481,6 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
     ],
@@ -96,29 +96,17 @@ CompileOnlyService::CompileAheadOfTime(
         std::shared_ptr<const ProgramShape> program_shape,
         user_computation->ComputeProgramShape(versioned_handle.version));

-    HloModuleConfig hlo_module_config(*program_shape);
-    hlo_module_config.set_debug_options(debug_options);
-    auto* computation_layout =
-        hlo_module_config.mutable_entry_computation_layout();
-    if (debug_options.xla_hlo_profile()) {
-      hlo_module_config.enable_hlo_profiling(true);
-    }
-    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
-      const Shape& argument_layout = *instance.argument_layouts[i];
-      if (ShapeUtil::IsTuple(argument_layout)) {
-        return Unimplemented("tuple arguments not supported yet");
-      }
-      TF_RETURN_IF_ERROR(
-          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-              argument_layout));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *instance.result_layout));
+    ExecutionOptions execution_options;
+    *execution_options.mutable_debug_options() = debug_options;
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloModuleConfig> module_config,
+        CreateModuleConfig(*program_shape, instance.argument_layouts,
+                           &execution_options,
+                           /*has_hybrid_result=*/false));

     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
                         computation_tracker_.BuildHloModule(
-                            versioned_handle, hlo_module_config,
+                            versioned_handle, *module_config,
                             /*include_unreachable_instructions=*/true));
     hlo_modules.push_back(std::move(hlo_module));
   }
@@ -19,7 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>

-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -141,35 +141,19 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
         ValidateResultShapeWithLayout(*result_layout, program_shape->result()));
   }

-  // Construct computation layout from the argument layouts.
-  auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
-  module_config->set_has_hybrid_result(has_hybrid_result);
-  module_config->set_replica_count(options_.number_of_replicas());
-  module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
-  if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
-    module_config->set_intra_op_parallelism_threads(
-        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
-  }
-  if (module_config->debug_options().xla_hlo_profile()) {
-    module_config->enable_hlo_profiling(true);
-  }
-  auto* computation_layout = module_config->mutable_entry_computation_layout();
-  for (int i = 0; i < argument_layouts.size(); ++i) {
-    const Shape& shape = *argument_layouts[i];
-    if (ShapeUtil::IsTuple(shape)) {
-      return Unimplemented("tuple arguments not supported yet");
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            shape));
-  }
+  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
   if (result_layout != nullptr) {
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *result_layout));
+    *execution_options.mutable_shape_with_output_layout() = *result_layout;
   } else {
-    computation_layout->mutable_result_layout()->SetToDefaultLayout();
+    *execution_options.mutable_shape_with_output_layout() =
+        program_shape->result();
+    LayoutUtil::SetToDefaultLayout(
+        execution_options.mutable_shape_with_output_layout());
   }
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
+                         has_hybrid_result));

   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       execute_backend_->stream_executor(device_ordinal));
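The substantive change in this hunk is that LocalService no longer writes layouts into the ComputationLayout itself; the requested output layout travels through ExecutionOptions, and CreateModuleConfig applies it. A condensed sketch of that flow, using a hypothetical wrapper name (MakeResultExecutionOptions) purely for illustration:

// Hypothetical helper showing how the desired output layout is encoded.
ExecutionOptions MakeResultExecutionOptions(const Shape* result_layout,
                                            const ProgramShape& program_shape) {
  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
  if (result_layout != nullptr) {
    // Caller asked for a specific result layout.
    *execution_options.mutable_shape_with_output_layout() = *result_layout;
  } else {
    // Otherwise fall back to the program's result shape with default layout.
    *execution_options.mutable_shape_with_output_layout() =
        program_shape.result();
    LayoutUtil::SetToDefaultLayout(
        execution_options.mutable_shape_with_output_layout());
  }
  return execution_options;
}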
@@ -286,51 +286,71 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(

 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
-    tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
-  auto module_config = MakeUnique<HloModuleConfig>(program_shape);
-  auto* computation_layout = module_config->mutable_entry_computation_layout();
+    tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+    const ExecutionOptions* execution_options, bool has_hybrid_result) {
+  auto config = MakeUnique<HloModuleConfig>(program_shape);
+  auto* computation_layout = config->mutable_entry_computation_layout();

-  if (program_shape.parameters_size() != arguments.size()) {
+  if (program_shape.parameters_size() != argument_shapes.size()) {
     return InvalidArgument("computation takes %d parameters, but %zu given",
-                           program_shape.parameters_size(), arguments.size());
+                           program_shape.parameters_size(),
+                           argument_shapes.size());
   }

-  for (size_t i = 0; i < arguments.size(); ++i) {
+  for (int i = 0; i < argument_shapes.size(); ++i) {
     // Verify that shape of arguments matches the shape of the arguments in the
     // ProgramShape.
-    if (!ShapeUtil::Compatible(arguments[i]->shape(),
+    if (!ShapeUtil::Compatible(*argument_shapes[i],
                                program_shape.parameters(i))) {
       return InvalidArgument(
-          "computation expects parameter %lu to have shape %s, given shape %s",
+          "computation expects parameter %d to have shape %s, given shape %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
-          ShapeUtil::HumanString(arguments[i]->shape()).c_str());
+          ShapeUtil::HumanString(*argument_shapes[i]).c_str());
     }
     TF_RETURN_IF_ERROR(
         computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-            arguments[i]->shape()));
+            *argument_shapes[i]));
   }
-  if (!execution_options.has_shape_with_output_layout()) {
-    computation_layout->mutable_result_layout()->Clear();
-  } else {
+  if (execution_options != nullptr &&
+      execution_options->has_shape_with_output_layout()) {
     const auto& shape_with_output_layout =
-        execution_options.shape_with_output_layout();
+        execution_options->shape_with_output_layout();
     TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
                                                      program_shape.result()));
     TF_RETURN_IF_ERROR(
         computation_layout->mutable_result_layout()->CopyLayoutFromShape(
             shape_with_output_layout));
+  } else {
+    computation_layout->mutable_result_layout()->Clear();
   }

-  if (execution_options.debug_options().xla_hlo_profile()) {
-    module_config->enable_hlo_profiling(true);
+  config->set_replica_count(options_.number_of_replicas());
+  config->set_has_hybrid_result(has_hybrid_result);
+  if (execution_options != nullptr) {
+    config->set_seed(execution_options->seed());
+    config->set_debug_options(execution_options->debug_options());
+    config->enable_hlo_profiling(
+        execution_options->debug_options().xla_hlo_profile());
+  } else {
+    config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
   }

-  module_config->set_replica_count(options_.number_of_replicas());
-  module_config->set_seed(execution_options.seed());
-  module_config->set_debug_options(execution_options.debug_options());
+  if (execute_backend_ != nullptr &&
+      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
+    config->set_intra_op_parallelism_threads(
+        execute_backend_->eigen_intra_op_thread_pool()->NumThreads());
+  }
+  return std::move(config);
+}

-  return std::move(module_config);
+StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
+    const ProgramShape& program_shape,
+    tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+    const ExecutionOptions& execution_options) {
+  std::vector<const Shape*> argument_shapes;
+  for (const auto* arg : arguments) {
+    argument_shapes.push_back(&arg->shape());
+  }
+  return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
 }

 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
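To make the two paths through the consolidated helper explicit, here is an illustrative pair of calls (not taken from the commit); program_shape, argument_shapes, and result_shape_with_layout are assumed to be in scope:

// With explicit options: seed, debug options, profiling flag and output
// layout all come from the ExecutionOptions proto.
ExecutionOptions opts = CreateDefaultExecutionOptions();
*opts.mutable_shape_with_output_layout() = result_shape_with_layout;
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> config_with_opts,
                    CreateModuleConfig(program_shape, argument_shapes, &opts));

// With no options: debug options fall back to
// legacy_flags::GetDebugOptionsFromFlags() and the result layout is cleared.
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> config_default,
                    CreateModuleConfig(program_shape, argument_shapes,
                                       /*execution_options=*/nullptr));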
@@ -245,6 +245,14 @@ class Service : public ServiceInterface {
   const Backend& backend() const { return *execute_backend_; }
   Backend* mutable_backend() { return execute_backend_.get(); }

+ private:
+  // A private overload for Service itself, used by other methods within this
+  // class.
+  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+      const ProgramShape& program_shape,
+      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+      const ExecutionOptions& execution_options);
+
  protected:
   friend class LocalExecutable;

@@ -263,10 +271,14 @@ class Service : public ServiceInterface {
       const Backend* backend, int device_ordinal);

   // Create a Hlo module config for the given program shape and arguments.
+  // execution_options is optional; if not given a default is used.
+  // has_hybrid_result is used to initialize the same-named field in
+  // HloModuleConfig -- see that class for documentation.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
-      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+      const ExecutionOptions* execution_options,
+      bool has_hybrid_result = false);

   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to