[XLA] Fix fallout from CompileOnlyClient et al refactoring. That change broke computation of constants in the XLA2TF bridge, for a couple of reasons, which are described in the comments adjacent to the fixes.
Change: 155270959
This commit is contained in:
parent
166171df59
commit
1959ae333f
@ -53,7 +53,10 @@ class CompileOnlyService : public Service {
|
|||||||
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||||
const AotCompilationOptions& Options);
|
const AotCompilationOptions& Options);
|
||||||
|
|
||||||
// Override Service methods that require an execute backend.
|
// Override Service methods that require or imply the existence of an
|
||||||
|
// execute backend. Note that this does not include TransferToClient and
|
||||||
|
// TransferToClientInProcess, as computing contants produces global data
|
||||||
|
// that we may wish to transfer.
|
||||||
tensorflow::Status Execute(const ExecuteRequest* arg,
|
tensorflow::Status Execute(const ExecuteRequest* arg,
|
||||||
ExecuteResponse* result) override {
|
ExecuteResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support execution.");
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
@ -76,35 +79,29 @@ class CompileOnlyService : public Service {
|
|||||||
WaitForExecutionResponse* result) override {
|
WaitForExecutionResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support execution.");
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
}
|
}
|
||||||
tensorflow::Status TransferToClient(
|
|
||||||
const TransferToClientRequest* arg,
|
|
||||||
TransferToClientResponse* result) override {
|
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
|
||||||
}
|
|
||||||
tensorflow::Status TransferToClientInProcess(
|
|
||||||
const TransferToClientInProcessRequest* arg,
|
|
||||||
TransferToClientInProcessResponse* result) override {
|
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
|
||||||
}
|
|
||||||
tensorflow::Status TransferToServer(
|
tensorflow::Status TransferToServer(
|
||||||
const TransferToServerRequest* arg,
|
const TransferToServerRequest* arg,
|
||||||
TransferToServerResponse* result) override {
|
TransferToServerResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
return Unimplemented(
|
||||||
|
"CompileOnlyService does not support device data transfers.");
|
||||||
}
|
}
|
||||||
tensorflow::Status TransferToInfeed(
|
tensorflow::Status TransferToInfeed(
|
||||||
const TransferToInfeedRequest* arg,
|
const TransferToInfeedRequest* arg,
|
||||||
TransferToInfeedResponse* result) override {
|
TransferToInfeedResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
return Unimplemented(
|
||||||
|
"CompileOnlyService does not support device data transfers.");
|
||||||
}
|
}
|
||||||
tensorflow::Status TransferFromOutfeed(
|
tensorflow::Status TransferFromOutfeed(
|
||||||
const TransferFromOutfeedRequest* arg,
|
const TransferFromOutfeedRequest* arg,
|
||||||
TransferFromOutfeedResponse* result) override {
|
TransferFromOutfeedResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
return Unimplemented(
|
||||||
|
"CompileOnlyService does not support device data transfers.");
|
||||||
}
|
}
|
||||||
tensorflow::Status TransferToServerInProcess(
|
tensorflow::Status TransferToServerInProcess(
|
||||||
const TransferToServerInProcessRequest* arg,
|
const TransferToServerInProcessRequest* arg,
|
||||||
TransferToServerInProcessResponse* result) override {
|
TransferToServerInProcessResponse* result) override {
|
||||||
return Unimplemented("CompileOnlyService does not support data transfers.");
|
return Unimplemented(
|
||||||
|
"CompileOnlyService does not support device data transfers.");
|
||||||
}
|
}
|
||||||
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
||||||
ResetDeviceResponse* result) override {
|
ResetDeviceResponse* result) override {
|
||||||
|
@ -330,7 +330,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
module_config->enable_hlo_profiling(true);
|
module_config->enable_hlo_profiling(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
module_config->set_replica_count(execute_backend_->Replicas().size());
|
// TODO(bmoses): Fix this properly. This value is wrong if we are creating a
|
||||||
|
// module for use with the compute_constant_backend_. However, so long as the
|
||||||
|
// execute_backend_ exists, it works out because we always use a CPU backend
|
||||||
|
// for the compute_constant_backend_ and CPU backends ignore this value. We
|
||||||
|
// do need to ensure that the execute_backend_ exists, however, to avoid a
|
||||||
|
// segfault when computing constants in a CompileOnlyService.
|
||||||
|
if (execute_backend_) {
|
||||||
|
module_config->set_replica_count(execute_backend_->Replicas().size());
|
||||||
|
}
|
||||||
module_config->set_fast_math_disabled(execution_options.disable_fast_math());
|
module_config->set_fast_math_disabled(execution_options.disable_fast_math());
|
||||||
module_config->set_seed(execution_options.seed());
|
module_config->set_seed(execution_options.seed());
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user