[XLA] Service::CreateModuleConfig should take the backend as a parameter, rather than assuming it's creating a module for the execute backend. This fixes the TODO from service.cc that I added in my previous commit.
Change: 155389451
commit dcb6e92b1d (parent e6bf33e8eb)
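The change threads the target backend through `CreateModuleConfig` as an explicit parameter instead of reading the `execute_backend_` member. Below is a minimal, self-contained C++ sketch of that pattern; the types are illustrative stand-ins, not the real `xla::Backend` or `HloModuleConfig` APIs.

    #include <cstddef>
    #include <iostream>
    #include <memory>
    #include <vector>

    // Stand-in for xla::Backend: only the replica list matters for this sketch.
    struct Backend {
      std::vector<int> replicas;
      const std::vector<int>& Replicas() const { return replicas; }
    };

    struct HloModuleConfig {
      std::size_t replica_count = 1;
    };

    class Service {
     public:
      Service()
          : execute_backend_(new Backend{{0, 1}}),          // two replicas
            compute_constant_backend_(new Backend{{0}}) {}  // one replica

      // After this commit: the caller names the backend the module is for,
      // so the same code path serves both backends correctly.
      std::unique_ptr<HloModuleConfig> CreateModuleConfig(Backend* backend) {
        auto config = std::make_unique<HloModuleConfig>();
        config->replica_count = backend->Replicas().size();
        return config;
      }

      Backend* execute_backend() { return execute_backend_.get(); }
      Backend* compute_constant_backend() {
        return compute_constant_backend_.get();
      }

     private:
      std::unique_ptr<Backend> execute_backend_;
      std::unique_ptr<Backend> compute_constant_backend_;
    };

    int main() {
      Service service;
      std::cout << service.CreateModuleConfig(service.execute_backend())
                       ->replica_count << "\n";  // prints 2
      std::cout << service.CreateModuleConfig(service.compute_constant_backend())
                       ->replica_count << "\n";  // prints 1
    }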
tensorflow/compiler/xla/service/service.cc

@@ -290,7 +290,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
     tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
+    const ExecutionOptions& execution_options, Backend* backend) {
   auto module_config = MakeUnique<HloModuleConfig>(program_shape);
   auto* computation_layout = module_config->mutable_entry_computation_layout();
 
@@ -330,15 +330,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  // TODO(bmoses): Fix this properly. This value is wrong if we are creating a
-  // module for use with the compute_constant_backend_. However, so long as the
-  // execute_backend_ exists, it works out because we always use a CPU backend
-  // for the compute_constant_backend_ and CPU backends ignore this value. We
-  // do need to ensure that the execute_backend_ exists, however, to avoid a
-  // segfault when computing constants in a CompileOnlyService.
-  if (execute_backend_) {
-    module_config->set_replica_count(execute_backend_->Replicas().size());
-  }
+  module_config->set_replica_count(backend->Replicas().size());
   module_config->set_fast_math_disabled(execution_options.disable_fast_math());
   module_config->set_seed(execution_options.seed());
 
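The deleted guard existed because a CompileOnlyService never creates an execute backend, so dereferencing the null member while computing constants would segfault; the guard then silently left a default replica count that only happened to be right for CPU backends. With the backend passed in, there is nothing to guess. A hedged sketch of that failure mode, using hypothetical stand-in types:

    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <vector>

    struct Backend {
      std::vector<int> replicas{0};  // one replica by default
      std::size_t ReplicaCount() const { return replicas.size(); }
    };

    struct CompileOnlyService {
      // In a compile-only service the execute backend is never created...
      std::unique_ptr<Backend> execute_backend_;  // stays null
      // ...but a CPU backend still exists for constant folding.
      std::unique_ptr<Backend> compute_constant_backend_ =
          std::make_unique<Backend>();

      // Old shape of the code: guard against the missing member, otherwise
      // fall back to a default that was only accidentally correct.
      std::size_t ReplicaCountOld() const {
        if (execute_backend_) return execute_backend_->ReplicaCount();
        return 1;
      }

      // New shape: the caller passes the backend it means; no guard, no guess.
      std::size_t ReplicaCountNew(Backend* backend) const {
        return backend->ReplicaCount();
      }
    };

    int main() {
      CompileOnlyService service;
      assert(service.ReplicaCountOld() == 1);  // works only via the fallback
      assert(service.ReplicaCountNew(service.compute_constant_backend_.get()) == 1);
    }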
@@ -486,7 +478,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
       std::unique_ptr<Executable> executable_unique_ptr,
       BuildExecutable(versioned_handle, std::move(module_config),
                       /*executable_for_compute_constant=*/false, arguments,
-                      execute_backend_.get(), executor));
+                      backend, executor));
 
   if (profile != nullptr) {
     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -587,15 +579,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   perftools::gputools::DeviceMemoryBase result;
   if (backend->Replicas().size() == 1) {
     TF_ASSIGN_OR_RETURN(
-        result,
-        ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
-            executable, &run_options[0], profile, execute_backend_.get(),
-            [&arguments](Executable* executable,
-                         const ServiceExecutableRunOptions* run_options,
-                         HloExecutionProfile* hlo_execution_profile) {
-              return executable->ExecuteOnStream(run_options, arguments,
-                                                 hlo_execution_profile);
-            }));
+        result, ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
+                    executable, &run_options[0], profile, backend,
+                    [&arguments](Executable* executable,
+                                 const ServiceExecutableRunOptions* run_options,
+                                 HloExecutionProfile* hlo_execution_profile) {
+                      return executable->ExecuteOnStream(run_options, arguments,
+                                                         hlo_execution_profile);
+                    }));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
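In this hunk only the backend argument to `ExecuteOnStreamWrapper` changes: the wrapper receives the backend alongside a callable that performs the actual execution. A minimal sketch of that wrapper shape, with hypothetical types in place of the real XLA signatures:

    #include <functional>
    #include <iostream>
    #include <string>

    struct Backend {
      std::string name;
    };

    struct ServiceRunOptions {
      Backend* backend = nullptr;
    };

    // The wrapper owns the shared setup (streams, timers, profiling in the
    // real code) and derives it from whichever backend the caller passes;
    // the lambda supplies the actual execution step.
    template <typename ReturnT>
    ReturnT ExecuteOnStreamWrapper(
        Backend* backend,
        const std::function<ReturnT(const ServiceRunOptions&)>& body) {
      ServiceRunOptions options;
      options.backend = backend;  // state comes from the parameter, not a member
      return body(options);
    }

    int main() {
      Backend cpu{"cpu"};
      int result =
          ExecuteOnStreamWrapper<int>(&cpu, [](const ServiceRunOptions& opts) {
            std::cout << "executing on " << opts.backend->name << "\n";
            return 42;
          });
      std::cout << result << "\n";
    }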
@@ -678,7 +669,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     // the program and the argument allocations.
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                         CreateModuleConfig(*program_shape, arg_allocations,
-                                           request.execution_options()));
+                                           request.execution_options(),
+                                           execute_backend_.get()));
     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
             << module_config->entry_computation_layout().ToString();
 
@@ -763,9 +755,10 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "Execute created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
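These call sites all go through `TF_ASSIGN_OR_RETURN`, which binds the value of a `StatusOr` expression on success and early-returns the error otherwise. A minimal sketch of that pattern with a toy `StatusOr` stand-in (the real TensorFlow macro is considerably more careful about hygiene and error types):

    #include <iostream>
    #include <optional>
    #include <string>
    #include <utility>

    template <typename T>
    struct StatusOr {
      std::optional<T> value;
      std::string error;
      bool ok() const { return value.has_value(); }
    };

    // Evaluate rexpr once; on failure, propagate the error to the caller;
    // on success, bind the unwrapped value to lhs.
    #define ASSIGN_OR_RETURN(lhs, rexpr)         \
      auto lhs##_or = (rexpr);                   \
      if (!lhs##_or.ok()) return lhs##_or.error; \
      auto lhs = std::move(*lhs##_or.value);

    StatusOr<int> ParsePositive(int x) {
      if (x <= 0) return {std::nullopt, "not positive"};
      return {x, ""};
    }

    std::string Demo(int x) {
      ASSIGN_OR_RETURN(value, ParsePositive(x));  // early-returns on failure
      return "got " + std::to_string(value);
    }

    int main() {
      std::cout << Demo(3) << "\n";   // got 3
      std::cout << Demo(-1) << "\n";  // not positive
    }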
@@ -830,9 +823,10 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -1153,7 +1147,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   }
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(program_shape, {}, execution_options));
+                      CreateModuleConfig(program_shape, {}, execution_options,
+                                         compute_constant_backend_.get()));
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<Executable> executable,
tensorflow/compiler/xla/service/service.h

@@ -265,11 +265,11 @@ class Service : public ServiceInterface {
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
       const Backend* backend, int device_ordinal);
 
-  // Create a Hlo module config foe the given program shape and arguments.
+  // Create a Hlo module config for the given program shape and arguments.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      const ExecutionOptions& execution_options, Backend* backend);
 
   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to
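With the backend in the declaration, every call site must state which backend the module config targets; nothing silently defaults to the execute backend, and forgetting the argument is a compile error. A compilable sketch of the new header shape, with stand-ins for the XLA types:

    #include <memory>

    struct ProgramShape {};
    struct ExecutionOptions {};
    struct Backend {
      int replica_count = 1;
    };

    struct HloModuleConfig {
      int replica_count = 1;
    };

    // Mirrors the new declaration: the backend is a required, explicit input.
    std::unique_ptr<HloModuleConfig> CreateModuleConfig(
        const ProgramShape& /*program_shape*/,
        const ExecutionOptions& /*execution_options*/, Backend* backend) {
      auto config = std::make_unique<HloModuleConfig>();
      config->replica_count = backend->replica_count;
      return config;
    }

    int main() {
      Backend execute_backend{2};
      Backend compute_constant_backend{1};
      // Each caller now states its intent explicitly.
      auto a = CreateModuleConfig({}, {}, &execute_backend);
      auto b = CreateModuleConfig({}, {}, &compute_constant_backend);
      return (a->replica_count == 2 && b->replica_count == 1) ? 0 : 1;
    }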