replace PFLR DeviceGetContext hardcode with Device::IsRemoteCallAllowed

This delegates decision of whether a particular device type
can host a remote call to the device implementation itself.

Change-Id: I189e9d965ade7386b0c24da6c90d369b1ed72c3e
This commit is contained in:
Pawel Piskorski 2020-08-31 19:12:57 +02:00
parent 270db009cf
commit 35d8bef702
6 changed files with 30 additions and 1 deletions

View File

@ -53,6 +53,8 @@ class CompositeDevice : public Device {
const std::vector<string>& underlying_devices, const string& device_name,
Status* status);
bool IsRemoteCallAllowed() const override { return false; }
private:
CompositeDevice(const DeviceAttributes& device_attributes,
const std::vector<string>& underlying_devices)

View File

@ -200,13 +200,15 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
// "TPU_SYSTEM" indicates that `device` is a CPU.
return Status::OK();
}
if (device_type == "GPU" || device_type == "TPU") {
if (device->IsRemoteCallAllowed()) {
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
if (dev_info) {
*device_context = dev_info->default_context;
return Status::OK();
}
}
return errors::Internal("Device type: ", device_type,
" is currently unsupported for remote ",
"function executions");

View File

@ -149,6 +149,8 @@ class RenamedDevice : public Device {
bool IsLocal() const override { return underlying_device_->IsLocal(); }
bool IsRemoteCallAllowed() const override { return underlying_device_->IsRemoteCallAllowed(); }
private:
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
bool owns_underlying, bool isolate_session_state,

View File

@ -50,6 +50,8 @@ class RemoteDevice : public Device {
bool IsLocal() const override { return false; }
bool IsRemoteCallAllowed() const override { return true; }
private:
const string local_dev_name_;

View File

@ -54,4 +54,22 @@ DeviceAttributes Device::BuildDeviceAttributes(
return da;
}
bool Device::IsRemoteCallAllowed() const {
auto &type = parsed_name_.type;
if (type == "TPU") {
return true;
}
if (type == "TPU_SYSTEM") {
return true;
}
if (type == "CPU") {
return true;
}
if (type == "GPU") {
return true;
}
return false;
}
} // namespace tensorflow

View File

@ -178,6 +178,9 @@ class Device : public DeviceBase {
virtual bool IsLocal() const { return true; }
// Informs if this Device can be used as a caller in RemoteCall operation.
virtual bool IsRemoteCallAllowed() const;
protected:
void DeleteResourceMgr() {
delete rmgr_;