replace PFLR DeviceGetContext hardcode with Device::IsRemoteCallAllowed
This delegates decision of whether a particular device type can host a remote call to the device implementation itself. Change-Id: I189e9d965ade7386b0c24da6c90d369b1ed72c3e
This commit is contained in:
parent
270db009cf
commit
35d8bef702
@ -53,6 +53,8 @@ class CompositeDevice : public Device {
|
||||
const std::vector<string>& underlying_devices, const string& device_name,
|
||||
Status* status);
|
||||
|
||||
bool IsRemoteCallAllowed() const override { return false; }
|
||||
|
||||
private:
|
||||
CompositeDevice(const DeviceAttributes& device_attributes,
|
||||
const std::vector<string>& underlying_devices)
|
||||
|
@ -200,13 +200,15 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||
// "TPU_SYSTEM" indicates that `device` is a CPU.
|
||||
return Status::OK();
|
||||
}
|
||||
if (device_type == "GPU" || device_type == "TPU") {
|
||||
|
||||
if (device->IsRemoteCallAllowed()) {
|
||||
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
|
||||
if (dev_info) {
|
||||
*device_context = dev_info->default_context;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
return errors::Internal("Device type: ", device_type,
|
||||
" is currently unsupported for remote ",
|
||||
"function executions");
|
||||
|
@ -149,6 +149,8 @@ class RenamedDevice : public Device {
|
||||
|
||||
bool IsLocal() const override { return underlying_device_->IsLocal(); }
|
||||
|
||||
bool IsRemoteCallAllowed() const override { return underlying_device_->IsRemoteCallAllowed(); }
|
||||
|
||||
private:
|
||||
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
|
||||
bool owns_underlying, bool isolate_session_state,
|
||||
|
@ -50,6 +50,8 @@ class RemoteDevice : public Device {
|
||||
|
||||
bool IsLocal() const override { return false; }
|
||||
|
||||
bool IsRemoteCallAllowed() const override { return true; }
|
||||
|
||||
private:
|
||||
const string local_dev_name_;
|
||||
|
||||
|
@ -54,4 +54,22 @@ DeviceAttributes Device::BuildDeviceAttributes(
|
||||
return da;
|
||||
}
|
||||
|
||||
|
||||
bool Device::IsRemoteCallAllowed() const {
|
||||
auto &type = parsed_name_.type;
|
||||
if (type == "TPU") {
|
||||
return true;
|
||||
}
|
||||
if (type == "TPU_SYSTEM") {
|
||||
return true;
|
||||
}
|
||||
if (type == "CPU") {
|
||||
return true;
|
||||
}
|
||||
if (type == "GPU") {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -178,6 +178,9 @@ class Device : public DeviceBase {
|
||||
|
||||
virtual bool IsLocal() const { return true; }
|
||||
|
||||
// Informs if this Device can be used as a caller in RemoteCall operation.
|
||||
virtual bool IsRemoteCallAllowed() const;
|
||||
|
||||
protected:
|
||||
void DeleteResourceMgr() {
|
||||
delete rmgr_;
|
||||
|
Loading…
Reference in New Issue
Block a user