[XLA] Service::CreateModuleConfig should take the backend as a parameter, rather than assuming it's creating a module for the execute backend. This fixes the TODO from service.cc that I added in my previous commit.

Change: 155389451
A. Unique TensorFlower 2017-05-08 08:46:22 -08:00 committed by TensorFlower Gardener
parent e6bf33e8eb
commit dcb6e92b1d
2 changed files with 25 additions and 30 deletions
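
For orientation, the change below can be reduced to the following minimal, self-contained sketch. All types here (Backend, HloModuleConfig, Service) are simplified stand-ins invented for illustration, not the real XLA classes; only the replica-count plumbing that this commit touches is modeled.

#include <memory>
#include <string>
#include <vector>

struct Backend {
  std::string platform;              // e.g. "cpu" or "gpu" (illustrative only)
  std::vector<int> replica_devices;  // one entry per replica device
  const std::vector<int>& Replicas() const { return replica_devices; }
};

struct HloModuleConfig {
  int replica_count = 1;
  void set_replica_count(int count) { replica_count = count; }
};

class Service {
 public:
  // After this commit: the caller states which backend the module is for,
  // so replica_count always matches the backend that will run the module.
  std::unique_ptr<HloModuleConfig> CreateModuleConfig(Backend* backend) {
    auto config = std::make_unique<HloModuleConfig>();
    config->set_replica_count(static_cast<int>(backend->Replicas().size()));
    return config;
  }

  std::unique_ptr<Backend> execute_backend_;
  std::unique_ptr<Backend> compute_constant_backend_;
};

int main() {
  Service service;
  service.execute_backend_.reset(new Backend{"gpu", {0, 1}});
  service.compute_constant_backend_.reset(new Backend{"cpu", {0}});

  // Execute/ExecuteParallel/ExecuteAsync pass the execute backend ...
  auto run_config = service.CreateModuleConfig(service.execute_backend_.get());
  // ... while ComputeConstant passes the constant-folding backend, which
  // previously inherited the execute backend's replica count by accident.
  auto constant_config =
      service.CreateModuleConfig(service.compute_constant_backend_.get());
  return 0;
}

Before this change, CreateModuleConfig read execute_backend_ unconditionally (guarded by a null check to keep CompileOnlyService from segfaulting), which was wrong whenever the module was being built for compute_constant_backend_.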

tensorflow/compiler/xla/service/service.cc

@@ -290,7 +290,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
     tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
+    const ExecutionOptions& execution_options, Backend* backend) {
   auto module_config = MakeUnique<HloModuleConfig>(program_shape);
   auto* computation_layout = module_config->mutable_entry_computation_layout();
 
@@ -330,15 +330,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  // TODO(bmoses): Fix this properly. This value is wrong if we are creating a
-  // module for use with the compute_constant_backend_. However, so long as the
-  // execute_backend_ exists, it works out because we always use a CPU backend
-  // for the compute_constant_backend_ and CPU backends ignore this value. We
-  // do need to ensure that the execute_backend_ exists, however, to avoid a
-  // segfault when computing constants in a CompileOnlyService.
-  if (execute_backend_) {
-    module_config->set_replica_count(execute_backend_->Replicas().size());
-  }
+  module_config->set_replica_count(backend->Replicas().size());
   module_config->set_fast_math_disabled(execution_options.disable_fast_math());
   module_config->set_seed(execution_options.seed());
 
@@ -486,7 +478,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
       std::unique_ptr<Executable> executable_unique_ptr,
       BuildExecutable(versioned_handle, std::move(module_config),
                       /*executable_for_compute_constant=*/false, arguments,
-                      execute_backend_.get(), executor));
+                      backend, executor));
 
   if (profile != nullptr) {
     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -587,15 +579,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   perftools::gputools::DeviceMemoryBase result;
   if (backend->Replicas().size() == 1) {
     TF_ASSIGN_OR_RETURN(
-        result,
-        ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
-            executable, &run_options[0], profile, execute_backend_.get(),
-            [&arguments](Executable* executable,
-                         const ServiceExecutableRunOptions* run_options,
-                         HloExecutionProfile* hlo_execution_profile) {
-              return executable->ExecuteOnStream(run_options, arguments,
-                                                 hlo_execution_profile);
-            }));
+        result, ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
+                    executable, &run_options[0], profile, backend,
+                    [&arguments](Executable* executable,
+                                 const ServiceExecutableRunOptions* run_options,
+                                 HloExecutionProfile* hlo_execution_profile) {
+                      return executable->ExecuteOnStream(run_options, arguments,
+                                                         hlo_execution_profile);
+                    }));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
@@ -678,7 +669,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     //  the program and the argument allocations.
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                         CreateModuleConfig(*program_shape, arg_allocations,
-                                           request.execution_options()));
+                                           request.execution_options(),
+                                           execute_backend_.get()));
 
     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
             << module_config->entry_computation_layout().ToString();
@@ -763,9 +755,10 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "Execute created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -830,9 +823,10 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -1153,7 +1147,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   }
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(program_shape, {}, execution_options));
+                      CreateModuleConfig(program_shape, {}, execution_options,
+                                         compute_constant_backend_.get()));
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<Executable> executable,

tensorflow/compiler/xla/service/service.h

@@ -265,11 +265,11 @@ class Service : public ServiceInterface {
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
       const Backend* backend, int device_ordinal);
 
-  // Create a Hlo module config foe the given program shape and arguments.
+  // Create a Hlo module config for the given program shape and arguments.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      const ExecutionOptions& execution_options, Backend* backend);
 
   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to